Compare commits: llmcompres... → activation (14 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 7610a02881 |  |
|  | b0cd54bcb9 |  |
|  | 54960d4de0 |  |
|  | ed922796b7 |  |
|  | 3dd9c3bf3f |  |
|  | 0ba7d362fa |  |
|  | e4f73bc98e |  |
|  | bcb59c70e2 |  |
|  | 6a3e6f8c53 |  |
|  | fee3c13bb5 |  |
|  | 996fc124e5 |  |
|  | e963990ad7 |  |
|  | c3f2b1c5c2 |  |
|  | 6ba5c0ed2c |  |
.github/workflows/main.yml (vendored, 2 changes)

@@ -30,7 +30,7 @@ jobs:
           cuda_version: 12.6.3
           python_version: "3.11"
           pytorch: 2.7.0
-          axolotl_extras: vllm
+          axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
.github/workflows/preview-docs.yml (vendored, 6 changes)

@@ -4,6 +4,12 @@ on:
   pull_request:
     types: [opened, synchronize, reopened]
 
+    # Run the workflow only when one of these files changes
+    paths:
+      - '**/*.md'    # any Markdown file
+      - '**/*.qmd'   # any Quarto file
+      - '_quarto.yaml'
+
 permissions:
   checks: write
   contents: write
.runpod/tests.json (new file, 90 lines)

@@ -0,0 +1,90 @@
+{
+  "tests": [
+    {
+      "name": "quick_smoke_test_sft",
+      "input": {
+        "user_id": "user",
+        "model_id": "llama-test",
+        "run_id": "llama-test",
+        "credentials": {
+          "wandb_api_key": "",
+          "hf_token": ""
+        },
+        "args": {
+          "base_model": "HuggingFaceTB/SmolLM2-135M",
+          "model_type": "AutoModelForCausalLM",
+          "tokenizer_type": "AutoTokenizer",
+          "load_in_4bit": true,
+          "strict": false,
+          "datasets": [
+            {
+              "path": "mhenrichsen/alpaca_2k_test",
+              "type": "alpaca",
+              "split": "train[:10%]"
+            }
+          ],
+          "val_set_size": 0.02,
+          "output_dir": "./outputs/lora-out",
+          "sequence_len": 4096,
+          "sample_packing": true,
+          "eval_sample_packing": false,
+          "pad_to_sequence_len": true,
+          "adapter": "qlora",
+          "lora_r": 32,
+          "lora_alpha": 64,
+          "lora_dropout": 0.05,
+          "lora_target_linear": true,
+          "lora_modules_to_save": [
+            "embed_tokens",
+            "lm_head"
+          ],
+          "gradient_accumulation_steps": 2,
+          "micro_batch_size": 1,
+          "num_epochs": 1,
+          "optimizer": "adamw_torch_fused",
+          "lr_scheduler": "cosine",
+          "learning_rate": 0.0002,
+          "train_on_inputs": false,
+          "group_by_length": false,
+          "bf16": "auto",
+          "tf32": true,
+          "gradient_checkpointing": true,
+          "logging_steps": 1,
+          "flash_attention": true,
+          "warmup_steps": 1,
+          "evals_per_epoch": 1,
+          "eval_max_new_tokens": 128,
+          "saves_per_epoch": 1,
+          "weight_decay": 0.0,
+          "special_tokens": {
+            "pad_token": "<|endoftext|>"
+          },
+          "max_steps": 20
+        }
+      },
+      "timeout": 100000
+    }
+  ],
+  "config": {
+    "gpuTypeId": "NVIDIA GeForce RTX 4090",
+    "gpuCount": 1,
+    "containerDiskInGb": 200,
+    "env": [
+      {
+        "key": "TOKENIZER",
+        "value": ""
+      },
+      {
+        "key": "DISABLE_LOG_STATS",
+        "value": "true"
+      }
+    ],
+    "allowedCudaVersions": [
+      "12.8",
+      "12.7",
+      "12.6",
+      "12.5",
+      "12.4"
+    ]
+  }
+}
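The `args` block above is an ordinary axolotl QLoRA config embedded in the RunPod test spec. A quick sanity check of the spec needs nothing but the standard library; this snippet is an illustrative sketch, not part of the change:

```python
import json

# Load the RunPod smoke-test spec added above.
with open(".runpod/tests.json", encoding="utf-8") as fh:
    spec = json.load(fh)

test = spec["tests"][0]
assert test["name"] == "quick_smoke_test_sft"
# The embedded training args are a standard axolotl QLoRA config.
assert test["input"]["args"]["adapter"] == "qlora"
assert "12.8" in spec["config"]["allowedCudaVersions"]
```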
@@ -18,7 +18,7 @@ accelerate==1.6.0
 datasets==3.5.0
 deepspeed>=0.15.4
 trl==0.17.0
-hf_xet==1.0.0
+hf_xet==1.1.0
 hqq==0.2.5
 
 optimum==1.16.2
@@ -2,4 +2,7 @@
 
 import os
 
+from axolotl.logging_config import configure_logging
+
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+configure_logging()
@@ -8,9 +8,6 @@ from accelerate.commands.config import config_args
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
 
-from axolotl.logging_config import configure_logging
-
-configure_logging()
 LOG = logging.getLogger(__name__)
 
 
@@ -5,6 +5,7 @@ import logging
 import os
 import tempfile
 from pathlib import Path
+from tempfile import NamedTemporaryFile
 from typing import Union
 from urllib.parse import urlparse
 
@@ -158,7 +159,9 @@ def plugin_set_cfg(cfg: DictDefault):
     plugin_manager.cfg = cfg
 
 
-def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
+def load_cfg(
+    config: str | Path | DictDefault = Path("examples/"), **kwargs
+) -> DictDefault:
     """
     Loads the `axolotl` configuration stored at `config`, validates it, and performs
     various setup.
@@ -170,13 +173,24 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
     Returns:
         `DictDefault` mapping configuration keys to values.
     """
-    config = check_remote_config(config)
-    if Path(config).is_dir():
-        config = choose_config(Path(config))
+    if isinstance(config, (str, Path)):
+        config = check_remote_config(config)
+        if Path(config).is_dir():
+            config = choose_config(Path(config))
 
-    # Load the config from the yaml file
-    with open(config, encoding="utf-8") as file:
-        cfg: DictDefault = DictDefault(yaml.safe_load(file))
+        # Load the config from the yaml file
+        with open(config, encoding="utf-8") as file:
+            cfg: DictDefault = DictDefault(yaml.safe_load(file))
+
+        cfg.axolotl_config_path = config
+    else:
+        cfg = config
+        with NamedTemporaryFile(
+            mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
+        ) as temp_file:
+            temp_file.write(yaml.dump(config.to_dict()))
+            temp_file.close()
+        cfg.axolotl_config_path = temp_file.name
 
     # If there are any options passed in the cli, if it is something that seems valid
     # from the yaml, then overwrite the value
@@ -190,8 +204,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
         else:
             cfg[k] = kwargs[k]
 
-    cfg.axolotl_config_path = config
-
     try:
         device_props = torch.cuda.get_device_properties("cuda")
         gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
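With this change, `load_cfg` accepts an in-memory `DictDefault` in addition to a path; the dict is dumped to a temporary YAML file so `cfg.axolotl_config_path` still points at something readable on disk. A minimal sketch of the new call pattern (assumes an axolotl install with this change applied; the config keys are illustrative):

```python
from axolotl.cli.config import load_cfg
from axolotl.utils.dict import DictDefault

cfg = load_cfg(
    DictDefault(
        {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
        }
    )
)
# The dict was round-tripped through a temp YAML file, so tooling that
# re-reads the config from disk keeps working.
print(cfg.axolotl_config_path)  # e.g. /tmp/axolotl_config_abc123.yml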
@@ -15,7 +15,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
 from axolotl.cli.config import load_cfg
 from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.evaluate import evaluate
-from axolotl.utils import set_pytorch_cuda_alloc_conf
+from axolotl.utils import patch_optimized_env
 from axolotl.utils.dict import DictDefault
 
 LOG = logging.getLogger(__name__)
@@ -32,7 +32,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
         cli_args: CLI arguments.
     """
     # Enable expandable segments for cuda allocation to improve VRAM usage
-    set_pytorch_cuda_alloc_conf()
+    patch_optimized_env()
 
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
@@ -29,7 +29,7 @@ from axolotl.cli.utils import (
     filter_none_kwargs,
 )
 from axolotl.integrations.lm_eval.cli import lm_eval
-from axolotl.utils import set_pytorch_cuda_alloc_conf
+from axolotl.utils import patch_optimized_env
 from axolotl.utils.schemas.config import AxolotlInputConfig
 
 
@@ -55,6 +55,8 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
         kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
             config options.
     """
+    patch_optimized_env()
+
     if cloud:
         from axolotl.cli.cloud import do_cli_preprocess
 
@@ -100,7 +102,7 @@ def train(
             config options.
     """
     # Enable expandable segments for cuda allocation to improve VRAM usage
-    set_pytorch_cuda_alloc_conf()
+    patch_optimized_env()
 
     if "use_ray" in kwargs and kwargs["use_ray"]:
         accelerate = False
@@ -18,7 +18,7 @@ from axolotl.cli.config import load_cfg
 from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.integrations.base import PluginManager
 from axolotl.train import train
-from axolotl.utils import set_pytorch_cuda_alloc_conf
+from axolotl.utils import patch_optimized_env
 from axolotl.utils.config import normalize_config, resolve_dtype
 from axolotl.utils.dict import DictDefault
 
@@ -36,7 +36,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
         cli_args: Training-specific CLI arguments.
     """
     # Enable expandable segments for cuda allocation to improve VRAM usage
-    set_pytorch_cuda_alloc_conf()
+    patch_optimized_env()
 
     print_axolotl_text_art()
     check_accelerate_default_config()
@@ -20,11 +20,9 @@ from transformers import (
     ProcessorMixin,
 )
 
-from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_processor, load_tokenizer
 
-configure_logging()
 LOG = logging.getLogger(__name__)
 
 
@@ -47,7 +47,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
 def load_datasets(
     *,
     cfg: DictDefault,
-    cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
+    cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
 ) -> TrainDatasetMeta:
     """
     Loads one or more training or evaluation datasets, calling
@@ -64,7 +64,8 @@ def load_datasets(
     tokenizer = load_tokenizer(cfg)
     processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
     preprocess_iterable = (
-        hasattr(cli_args, "iterable")
+        cli_args
+        and hasattr(cli_args, "iterable")
         and cli_args.iterable is not None
         and cli_args.iterable
     )
@@ -76,7 +77,7 @@ def load_datasets(
         preprocess_iterable=preprocess_iterable,
     )
 
-    if (
+    if cli_args and (
         cli_args.debug
         or cfg.debug
         or cli_args.debug_text_only
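Since `cli_args` now defaults to `None` and is checked before its attributes are read, programmatic callers can load datasets from a config alone. A hedged sketch (the config file name is hypothetical, and the `TrainDatasetMeta` field access assumes its usual layout):

```python
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets

cfg = load_cfg("my_qlora.yml")  # hypothetical config file
dataset_meta = load_datasets(cfg=cfg)  # no CLI args required anymore
print(len(dataset_meta.train_dataset))  # assumes the usual TrainDatasetMeta fields
```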
@@ -488,7 +488,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         # these are all the "standard" kwargs that are def used
         training_arguments_kwargs["max_steps"] = (
-            total_num_steps if self.cfg.max_steps else -1
+            self.cfg.max_steps if self.cfg.max_steps else -1
         )
         training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
         training_arguments_kwargs["per_device_train_batch_size"] = (
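Worked example of why this one-word change is a fix: previously, whenever `cfg.max_steps` was set, the trainer received `total_num_steps` (the full schedule length) rather than the configured cap. The snippet below is purely illustrative:

```python
cfg_max_steps = 20      # what the user asked for
total_num_steps = 1000  # what the full dataset would yield

old = total_num_steps if cfg_max_steps else -1  # 1000: the cap was ignored
new = cfg_max_steps if cfg_max_steps else -1    # 20: the cap is honored
assert (old, new) == (1000, 20)
```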
@@ -610,3 +610,15 @@ class AxolotlTrainer(
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
         return super()._save_checkpoint(model, trial, **kwargs)
+
+    def compute_loss_context_manager(self):
+        from contextlib import ExitStack
+
+        from torchtune.training import OffloadActivations
+
+        stack = ExitStack()
+
+        stack.enter_context(super().compute_loss_context_manager())
+        stack.enter_context(OffloadActivations())
+
+        return stack
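The new `compute_loss_context_manager` layers torchtune's `OffloadActivations` (which moves saved activations to host memory during the forward pass) on top of the trainer's existing context. `ExitStack` is what makes that composition work; a self-contained illustration of the pattern, with stand-in context managers:

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def trainer_context():  # stands in for the trainer's autocast context
    print("trainer context enter")
    yield
    print("trainer context exit")

@contextmanager
def offload_context():  # stands in for OffloadActivations
    print("activation offloading enter")
    yield
    print("activation offloading exit")

stack = ExitStack()
stack.enter_context(trainer_context())
stack.enter_context(offload_context())
with stack:  # contexts exit in reverse order of entry
    print("loss computation would run here")
```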
@@ -177,12 +177,8 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
         # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
         if res["chosen_input_ids"][0] == processing_class.bos_token_id:
             res["chosen_input_ids"] = res["chosen_input_ids"][1:]
-            res["chosen_labels"] = res["chosen_labels"][1:]
-            res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
         if res["rejected_input_ids"][0] == processing_class.bos_token_id:
             res["rejected_input_ids"] = res["rejected_input_ids"][1:]
-            res["rejected_labels"] = res["rejected_labels"][1:]
-            res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
 
         return res
 
@@ -63,6 +63,7 @@ class GRPOStrategy:
 
         grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
         grpo_args_kwargs["log_completions"] = trl.log_completions
+        grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
 
         if trl.reward_weights:
             grpo_args_kwargs["reward_weights"] = trl.reward_weights
@@ -11,7 +11,6 @@ from accelerate.logging import get_logger
 from datasets import Dataset
 from transformers.trainer import Trainer
 
-from axolotl.logging_config import configure_logging
 from axolotl.train import (
     TrainDatasetMeta,
     setup_model_and_tokenizer,
@@ -24,7 +23,6 @@ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
 
-configure_logging()
 LOG = get_logger(__name__)
 
 
@@ -151,6 +151,30 @@ class LigerPlugin(BasePlugin):
                 rms_norm=cfg.liger_rms_norm,
                 layer_norm=cfg.liger_layer_norm,
             )
+        elif cfg.model_config_type == "qwen3":
+            from axolotl.integrations.liger.models.qwen3 import (
+                apply_liger_kernel_to_qwen3,
+            )
+
+            apply_liger_kernel_to_qwen3(
+                cross_entropy=cfg.liger_cross_entropy,
+                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
+                glu_activation=cfg.liger_glu_activation,
+                rms_norm=cfg.liger_rms_norm,
+                layer_norm=cfg.liger_layer_norm,
+            )
+        elif cfg.model_config_type == "qwen3_moe":
+            from axolotl.integrations.liger.models.qwen3_moe import (
+                apply_liger_kernel_to_qwen3_moe,
+            )
+
+            apply_liger_kernel_to_qwen3_moe(
+                cross_entropy=cfg.liger_cross_entropy,
+                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
+                glu_activation=cfg.liger_glu_activation,
+                rms_norm=cfg.liger_rms_norm,
+                layer_norm=cfg.liger_layer_norm,
+            )
         else:
             logging.warning(
                 f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
src/axolotl/integrations/liger/models/qwen3.py (new file, 160 lines)

@@ -0,0 +1,160 @@
+"""
+Liger FLCE for Qwen3. Based on transformers v4.51.3.
+"""
+
+import sys
+from typing import Optional, Tuple, Union
+
+import torch
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+def lce_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    **kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+    """
+
+    # pylint: disable=duplicate-code
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+    else:  # if in inference mode materialize logits
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+def apply_liger_kernel_to_qwen3(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = False,
+    rms_norm: bool = False,
+    glu_activation: bool = False,
+    layer_norm: bool = False,
+    **kwargs,  # pylint: disable=unused-argument
+) -> None:
+    # pylint: disable=duplicate-code
+    """
+    Apply Liger kernels to replace the original implementation in HuggingFace Qwen3 models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
+        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
+    """
+
+    import transformers.models.qwen3.modeling_qwen3  # noqa: F401 # pylint: disable=unused-import
+    from liger_kernel.transformers.functional import liger_cross_entropy
+    from liger_kernel.transformers.layer_norm import LigerLayerNorm
+    from liger_kernel.transformers.rms_norm import LigerRMSNorm
+    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
+
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    modeling_qwen3 = sys.modules["transformers.models.qwen3.modeling_qwen3"]
+
+    if rms_norm:
+        modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
+
+    if glu_activation:
+        modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
+
+    if layer_norm:
+        modeling_qwen3.nn.LayerNorm = LigerLayerNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        modeling_qwen3.Qwen3ForCausalLM.forward = lce_forward
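A hedged usage sketch of the new patch function: it must run before any `Qwen3ForCausalLM` is instantiated, which is exactly when `LigerPlugin` invokes it for `model_config_type == "qwen3"`. Assumes `liger-kernel` and a transformers release with the qwen3 architecture are installed:

```python
from axolotl.integrations.liger.models.qwen3 import apply_liger_kernel_to_qwen3

# Patch module-level symbols before any Qwen3 model is constructed.
apply_liger_kernel_to_qwen3(
    fused_linear_cross_entropy=True,  # training loss without materialized logits
    rms_norm=True,
    glu_activation=True,
)
# Any Qwen3ForCausalLM loaded after this point picks up the Liger kernels.
```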
src/axolotl/integrations/liger/models/qwen3_moe.py (new file, 191 lines)

@@ -0,0 +1,191 @@
+"""
+Liger FLCE for Qwen3 MoE. Based on transformers v4.51.3.
+"""
+
+import sys
+from copy import deepcopy
+from typing import List, Optional, Union
+
+import torch
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from transformers.modeling_outputs import MoeCausalLMOutputWithPast
+from transformers.models.qwen3_moe.modeling_qwen3_moe import load_balancing_loss_func
+
+
+def lce_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    **kwargs,
+) -> MoeCausalLMOutputWithPast:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+    """
+
+    # pylint: disable=duplicate-code
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_router_logits = (
+        output_router_logits
+        if output_router_logits is not None
+        else self.config.output_router_logits
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+    else:  # if in inference mode materialize logits
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits,
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(
+                loss.device
+            )  # make sure to reside in the same device
+
+    return MoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+def apply_liger_kernel_to_qwen3_moe(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = False,
+    rms_norm: bool = False,
+    glu_activation: bool = False,
+    layer_norm: bool = False,
+    **kwargs,  # pylint: disable=unused-argument
+) -> None:
+    # pylint: disable=duplicate-code
+    """
+    Apply Liger kernels to replace the original implementation in HuggingFace Qwen3 MoE models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
+        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
+    """
+
+    import transformers.models.qwen3_moe.modeling_qwen3_moe  # noqa: F401 # pylint: disable=unused-import
+    from liger_kernel.transformers.functional import liger_cross_entropy
+    from liger_kernel.transformers.layer_norm import LigerLayerNorm
+    from liger_kernel.transformers.rms_norm import LigerRMSNorm
+    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
+
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    modeling_qwen3_moe = sys.modules["transformers.models.qwen3_moe.modeling_qwen3_moe"]
+
+    if rms_norm:
+        modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
+
+    if glu_activation:
+
+        def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
+            "Accepts intermediate_size to pass to LigerSwiGLUMLP"
+            # clone config to avoid modifying the original
+            config = deepcopy(config)
+            if intermediate_size:
+                setattr(config, "intermediate_size", intermediate_size)
+            return LigerSwiGLUMLP(config, **kwargs)
+
+        modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper
+
+    if layer_norm:
+        modeling_qwen3_moe.nn.LayerNorm = LigerLayerNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = lce_forward
@@ -12,10 +12,8 @@ import torch
 import torch.distributed as dist
 from accelerate.logging import get_logger
 
-from axolotl.logging_config import configure_logging
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 
-configure_logging()
 LOG = get_logger(__name__)
 
 
@@ -18,6 +18,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "mixtral",
     "qwen2",
     "qwen2_moe",
+    "qwen3",
+    "qwen3_moe",
     "falcon",
     "phi",
     "phi3",
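With `"qwen3"` and `"qwen3_moe"` in the supported list, multipack sample packing can be enabled for these architectures. A quick check (the import path is assumed; this hunk does not show the file name):

```python
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES  # path assumed

assert "qwen3" in SUPPORTED_MULTIPACK_MODEL_TYPES
assert "qwen3_moe" in SUPPORTED_MULTIPACK_MODEL_TYPES
```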
src/axolotl/monkeypatch/trainer/__init__.py (new empty file, 0 lines)
@@ -30,7 +30,6 @@ from axolotl.core.trainers.mixins.sequence_parallel import (
     SequenceParallelContextManager,
 )
 from axolotl.integrations.base import PluginManager
-from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import cleanup_distributed
 from axolotl.utils.freeze import freeze_layers_except
@@ -42,7 +41,6 @@ try:
 except ImportError:
     BetterTransformer = None
 
-configure_logging()
 LOG = get_logger(__name__)
 
 
@@ -43,3 +43,12 @@ def set_pytorch_cuda_alloc_conf():
     os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
         "expandable_segments:True,roundup_power2_divisions:16"
     )
+
+
+def patch_optimized_env():
+    """
+    Patch environment variables to improve VRAM usage and increase download speed
+    """
+    if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
+        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+    set_pytorch_cuda_alloc_conf()
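Note the guard: an explicit user setting of `HF_HUB_ENABLE_HF_TRANSFER` (even `"0"`) is respected, while `set_pytorch_cuda_alloc_conf()` still runs to configure the CUDA allocator. Illustrative sketch, assuming an install with this change:

```python
import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"  # explicit user opt-out

from axolotl.utils import patch_optimized_env

patch_optimized_env()
# The opt-out survives: the env var is only set when it was previously unset.
assert os.environ["HF_HUB_ENABLE_HF_TRANSFER"] == "0"
# set_pytorch_cuda_alloc_conf() ran as well, configuring the allocator.
```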
@@ -59,7 +59,7 @@ def choose_device(cfg):
 
 def resolve_dtype(cfg):
     if (
-        cfg.bf16 == "auto" and not cfg.use_ray
+        not cfg.fp16 and cfg.bf16 == "auto" and not cfg.use_ray
     ):  # if we use ray we want to defer this check to the worker node
         if is_torch_bf16_gpu_available():
             LOG.debug("bf16 support detected, enabling for this configuration.")
@@ -67,7 +67,7 @@ def resolve_dtype(cfg):
         else:
             LOG.debug("bf16 support not detected, disabling for this configuration.")
             cfg.bf16 = False
-            if cfg.fp16 is None:
+            if cfg.fp16 is None and not cfg.float16:
                 cfg.fp16 = True
 
     if cfg.device == "mps":
@@ -2,6 +2,13 @@
 
 from functools import partial
 
+import torch
+from torch.utils.checkpoint import (
+    CheckpointPolicy,
+    checkpoint,
+    create_selective_checkpoint_contexts,
+)
+
 from axolotl.utils.gradient_checkpointing.unsloth import (
     Unsloth_Offloaded_Gradient_Checkpointer,
 )
@@ -18,3 +25,32 @@ def hf_grad_checkpoint_offload_wrapper(
         ),
         *args,
     )
+
+
+aten = torch.ops.aten
+compute_intensive_ops = [
+    aten.mm.default,
+    aten.bmm.default,
+    aten.addmm.default,
+]
+
+
+def policy_fn(ctx, op, *args, **kwargs):
+    if op in compute_intensive_ops:
+        return CheckpointPolicy.MUST_SAVE
+    else:
+        return CheckpointPolicy.PREFER_RECOMPUTE
+
+
+context_fn = partial(create_selective_checkpoint_contexts, policy_fn)
+
+
+def checkpoint_w_policy(
+    decoder_layer, *args, use_reentrant=None
+):  # pylint: disable=unused-argument
+    return checkpoint(
+        decoder_layer,
+        *args,
+        use_reentrant=use_reentrant,
+        context_fn=context_fn,
+    )
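A self-contained demonstration of the selective checkpointing policy used above: matmul-family ops are saved (recomputing them is expensive), everything else is recomputed on backward. Minimal sketch requiring PyTorch 2.4+:

```python
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

aten = torch.ops.aten
compute_intensive_ops = [aten.mm.default, aten.bmm.default, aten.addmm.default]


def policy_fn(ctx, op, *args, **kwargs):
    # Save matmul outputs; recompute cheap ops (activations, views) on backward.
    if op in compute_intensive_ops:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


context_fn = partial(create_selective_checkpoint_contexts, policy_fn)


def block(x):
    return torch.relu(x @ x.t())


x = torch.randn(64, 64, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()
print(x.grad.shape)  # torch.Size([64, 64])
```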
@@ -190,7 +190,7 @@ class MultipackBatchSampler(BatchSampler):
         self.len_across_ranks = None
 
         if self.sequential and not isinstance(sampler, SequentialSampler):
-            LOG.warn(
+            LOG.warning(
                 "using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
             )
 
@@ -512,10 +512,17 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def hint_sample_packing_padding(cls, data):
-        if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
-            LOG.warning(
-                "`pad_to_sequence_len: true` is recommended when using sample_packing"
-            )
+        if data.get("sample_packing"):
+            pad_to_sequence_len = data.get("pad_to_sequence_len")
+            if pad_to_sequence_len is False:
+                LOG.warning(
+                    "`pad_to_sequence_len: true` is recommended when using sample_packing"
+                )
+            elif pad_to_sequence_len is None:
+                LOG.info(
+                    "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
+                )
+                data["pad_to_sequence_len"] = True
        return data
 
     @model_validator(mode="before")
@@ -67,6 +67,12 @@ class TRLConfig(BaseModel):
         default=False,
         json_schema_extra={"description": "Whether to log completions"},
     )
+    num_completions_to_print: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged."
+        },
+    )
     sync_ref_model: bool | None = Field(
         default=False,
         json_schema_extra={
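The new field flows through `GRPOStrategy` (see the hunk above that forwards it into `grpo_args_kwargs`). A hedged sketch of setting it via the schema; the import path is assumed from this tree's layout:

```python
from axolotl.utils.schemas.trl import TRLConfig  # import path assumed

trl_cfg = TRLConfig(log_completions=True, num_completions_to_print=4)
print(trl_cfg.num_completions_to_print)  # 4
```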
@@ -597,6 +597,8 @@ def prepare_optim_env(cfg):
         os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
     elif cfg.fp16:
         os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
+    else:
+        os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
 
 
 def prepare_opinionated_env(cfg):
@@ -648,7 +648,7 @@ class TestValidation(BaseValidation):
             DictDefault(
                 {
                     "sample_packing": True,
-                    "pad_to_sequence_len": None,
+                    "pad_to_sequence_len": False,
                     "flash_attention": True,
                 }
             )
@@ -662,6 +662,26 @@ class TestValidation(BaseValidation):
             for record in self._caplog.records
         )
 
+    def test_packing_autoset(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "sample_packing": True,
+                    "pad_to_sequence_len": None,
+                    "flash_attention": True,
+                }
+            )
+            | minimal_cfg
+        )
+        with self._caplog.at_level(logging.INFO):
+            cfg = validate_config(cfg)
+            assert any(
+                "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
+                in record.message
+                for record in self._caplog.records
+            )
+            assert cfg.pad_to_sequence_len is True
+
     def test_merge_lora_no_bf16_fail(self, minimal_cfg):
         """
         This is assumed to be run on a CPU machine, so bf16 is not supported.