Compare commits

...

52 Commits

Author SHA1 Message Date
Dan Saunders
2daa94080c Merge branch 'main' into diff-transformer 2025-01-27 14:46:17 +00:00
Dan Saunders
0e9bfa6dee small fixes, improvements 2025-01-24 19:53:54 +00:00
Dan Saunders
ef38f10274 merging into main 2025-01-24 18:03:27 +00:00
Wing Lian
887513285d support for custom lr groups for non-embedding modules (#2213)
* support for custom lr groups for non-embedding modules

invert name check for group modules
include lr_groups in training args
additional conditional for creating optimizer
fix regular params as w weight decay
fix lookup and add docs

* address pr feedback
2025-01-24 12:56:28 -05:00
Wing Lian
20620771f1 Pretrain multipack (#2278)
* fix for pretrain with packing

* fix model name and loss expected

* make sure to check with micro batch size for pretraining

* change loss thresholds based on parametrization

* make tests smaller for CI

* fix pretrain packing

* fix pretrain packing test

* address pr feedback
2025-01-24 12:55:20 -05:00
Dan Saunders
66262c3092 moving out all diff attn code to plugin repo 2025-01-24 17:46:11 +00:00
NanoCode012
6086162488 chore(doc): improve explanation for *_steps and *_strategy (#2270) 2025-01-24 10:07:02 -05:00
mashdragon
b2774af66c Take split param from config in all load_dataset instances (#2281) 2025-01-24 10:06:50 -05:00
NanoCode012
74f9782fc3 chore(doc): fix explanation on gcs creds retrieval (#2272) 2025-01-24 10:05:58 -05:00
Wing Lian
8a7a0b07dc support for latest transformers release 4.48.1 (#2256) 2025-01-23 21:17:57 -05:00
Dan Saunders
016ba124e4 README update 2025-01-23 22:11:35 +00:00
Dan Saunders
7145d52d99 moving diff attn code to separate repo 2025-01-23 21:33:53 +00:00
Wing Lian
8fb72cbc0b use the extracted field_messages to parse the role fields (#2265) 2025-01-21 15:39:30 -05:00
Adithya Kamath
bb9d4102c4 Add 5000 line history limit to tmux for docker cloud (#2268) 2025-01-21 15:39:17 -05:00
Dan Saunders
28694219a5 inline comment change 2025-01-14 16:59:43 +00:00
Dan Saunders
fd8ad6fcbf fixing negative component mixing 2025-01-13 19:21:55 +00:00
Dan Saunders
661d71a14b adding diff attn negative component warmup (in progress) 2025-01-10 21:57:31 +00:00
Dan Saunders
6dd47edcb8 fire CLI fixes 2025-01-10 18:24:16 +00:00
Dan Saunders
7aca08ff60 adding guard statements 2025-01-10 16:39:21 +00:00
Dan Saunders
4f804f6d88 adding diff attn callback, adding documentation 2025-01-10 16:28:51 +00:00
Dan Saunders
443327c585 CLI build_command bugfix 2025-01-10 16:28:51 +00:00
Dan Saunders
70c4e6fbe6 updates and cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
2a7f139ad2 pre-commit fix 2025-01-10 16:28:51 +00:00
Dan Saunders
332ce0ae85 fixes and cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
e5fa842ff8 update 2025-01-10 16:28:51 +00:00
Dan Saunders
78e0ec0aa5 changes 2025-01-10 16:28:51 +00:00
Dan Saunders
3bc568eb27 adding registration function 2025-01-10 16:28:51 +00:00
Dan Saunders
eb6611d55f progress on modeling code 2025-01-10 16:28:51 +00:00
Dan Saunders
4ff3328e66 updated custom modeling code 2025-01-10 16:28:51 +00:00
Dan Saunders
a3fd5074a9 fix duplicate-code warnings 2025-01-10 16:28:51 +00:00
Dan Saunders
5b90da0be3 added modeling code; cleanup + refactor 2025-01-10 16:28:51 +00:00
Dan Saunders
fcbfa86373 refactor and fixing test isolation issues 2025-01-10 16:28:51 +00:00
Dan Saunders
0d56582090 adding yaml dumper preserving input config format 2025-01-10 16:28:51 +00:00
Dan Saunders
390cb5742e removing extra pytest xdist args 2025-01-10 16:28:51 +00:00
Dan Saunders
1d935f65c3 moving tests around for flash_attn install 2025-01-10 16:28:51 +00:00
Dan Saunders
66176b3e07 adding split_heads argument for retaining original (Q, K) dimensionality 2025-01-10 16:28:51 +00:00
Dan Saunders
505321ac95 isolating problematic test 2025-01-10 16:28:51 +00:00
Dan Saunders
0b382c88da fixes post-rebase 2025-01-10 16:28:51 +00:00
Dan Saunders
ea07a7086e plugin implementation 2025-01-10 16:28:51 +00:00
Dan Saunders
d22e1136bc convert-differential-transformer test coverage 2025-01-10 16:28:51 +00:00
Dan Saunders
63b8e42c6b duplicate code ignore 2025-01-10 16:28:51 +00:00
Dan Saunders
bda1eed59e differential flash attention 2; cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
41ebd93158 moving monkeypatch 2025-01-10 16:28:51 +00:00
Dan Saunders
4c050ce807 pre-commit fix 2025-01-10 16:28:51 +00:00
Dan Saunders
6665acf63d fix model save / load logic 2025-01-10 16:28:51 +00:00
Dan Saunders
2f9fa4c465 various improvements 2025-01-10 16:28:51 +00:00
Dan Saunders
849bc94112 various improvements 2025-01-10 16:28:51 +00:00
Dan Saunders
e484ec778d training fixes, patching, minor cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
df1504ae14 adding CLI command for convert-diff-transformer 2025-01-10 16:28:51 +00:00
Dan Saunders
7be0d7496c Adding script for doing conversion; fixes and updates 2025-01-10 16:28:51 +00:00
Dan Saunders
13cdffa91f initial diff attn layer / model conversion implementation (support for llama arch) 2025-01-10 16:28:51 +00:00
Dan Saunders
7a4b296f60 Basic evaluate CLI command / codepath (#2188)
* basic evaluate CLI command / codepath

* tests for evaluate CLI command

* fixes and cleanup

* review comments; slightly DRYing up things

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-01-10 16:28:51 +00:00
35 changed files with 509 additions and 461 deletions

3
.gitignore vendored
View File

@@ -186,3 +186,6 @@ out/
 # vim
 *.swp
+
+# symlinked to axolotl-artifacts in docker containers
+outputs

View File

@@ -519,8 +519,8 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
   train_on_split: validation

   # loading from s3 or gcs
-  # s3 creds will be loaded from the system default and gcs only supports public access
-  - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
+  # s3 creds will be loaded from the system default / gcs will attempt to load from gcloud creds, google metadata service, or anon
+  - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above
   ...

 # Loading Data From a Public URL

View File

@@ -4,7 +4,7 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__" python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -1,6 +1,6 @@
""" """
modal application to run axolotl gpu tests in Modal modal application to run axolotl gpu tests in Modal
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import os import os

View File

@@ -20,7 +20,8 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \ printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \ printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \ chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
ENTRYPOINT ["/root/cloud-entrypoint.sh"] ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"] CMD ["sleep", "infinity"]

View File

@@ -360,10 +360,11 @@ warmup_ratio: 0.05 # cannot use with warmup_steps
 learning_rate: 0.00003
 lr_quadratic_warmup:
 logging_steps:
-eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
+eval_steps: # Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps
 evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
+eval_strategy: # Set to `"no"` to skip evaluation, `"epoch"` at end of each epoch, leave empty to infer from `eval_steps`.
-save_strategy: # Set to `"no"` to skip checkpoint saves
-save_steps: # Leave empty to save at each epoch
+save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of each epoch, `"best"` when better result is achieved, leave empty to infer from `save_steps`.
+save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
 saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
 save_total_limit: # Checkpoints saved at a time
 # Maximum number of iterations to train for. It precedes num_epochs which means that

29
docs/lr_groups.qmd Normal file
View File

@@ -0,0 +1,29 @@
---
title: Learning Rate Groups
description: "Setting different learning rates by module name"
---

## Background

Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or group of
modules in a model.

## Example

```yaml
lr_groups:
  - name: o_proj
    modules:
      - self_attn.o_proj.weight
    lr: 1e-6
  - name: q_proj
    modules:
      - model.layers.2.self_attn.q_proj.weight
    lr: 1e-5

learning_rate: 2e-5
```

In this example, we have a default learning rate of 2e-5 across the entire model, a separate learning rate
of 1e-6 for all the self-attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd layer's
self-attention `q_proj` module.
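For intuition, here is a minimal standalone sketch of the matching rule (the toy module classes and grouping code below are illustrative, not axolotl's implementation): a parameter joins a group when any of the group's `modules` strings is a substring of its name, and everything else falls back to the default learning rate.

```python
import torch
from torch import nn

# toy model whose parameter names contain "self_attn.o_proj.weight"
class Attn(nn.Module):
    def __init__(self):
        super().__init__()
        self.o_proj = nn.Linear(8, 8)
        self.q_proj = nn.Linear(8, 8)

class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = Attn()

model = nn.Sequential(Layer(), Layer())
lr_groups = [{"name": "o_proj", "modules": ["self_attn.o_proj.weight"], "lr": 1e-6}]
default_lr = 2e-5

# substring match of parameter names against each group's module patterns
grouped, rest = {g["name"]: [] for g in lr_groups}, []
for name, param in model.named_parameters():
    matches = [g for g in lr_groups if any(m in name for m in g["modules"])]
    (grouped[matches[0]["name"]] if matches else rest).append(param)

param_groups = [{"params": rest, "lr": default_lr}] + [
    {"params": grouped[g["name"]], "lr": g["lr"]} for g in lr_groups
]
optimizer = torch.optim.AdamW(param_groups)
print([(len(pg["params"]), pg["lr"]) for pg in optimizer.param_groups])
```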

View File

@@ -13,9 +13,9 @@ liger-kernel==0.5.2
 packaging==23.2
 peft==0.14.0
-transformers==4.47.1
+transformers==4.48.1
 tokenizers>=0.21.0
-accelerate==1.2.1
+accelerate==1.3.0
 datasets==3.2.0
 deepspeed==0.16.1
 trl==0.13.0

View File

@@ -30,7 +30,7 @@ def parse_dataset(dataset=None, split="train"):
     )
     ds_cfg["field_messages"] = field_messages

-    message_fields = features["conversations"][0].keys()
+    message_fields = features[field_messages][0].keys()
     message_field_role = None
     for key in ["from", "role"]:
         if key in message_fields:

View File

@@ -19,7 +19,7 @@ from axolotl.utils.dict import DictDefault
 LOG = logging.getLogger(__name__)

-def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
+def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]:
     """
     Evaluates a `transformers` model by first loading the dataset(s) specified in the
     `axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
@@ -39,7 +39,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-    evaluate(cfg=cfg, dataset_meta=dataset_meta)
+    return evaluate(cfg=cfg, dataset_meta=dataset_meta)

 def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:

View File

@@ -8,6 +8,7 @@ import click
 import axolotl
 from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
+from axolotl.cli.plugins import setup_plugin_commands
 from axolotl.cli.utils import (
     add_options_from_config,
     add_options_from_dataclass,
@@ -222,6 +223,9 @@ def fetch(directory: str, dest: Optional[str]) -> None:
     fetch_from_github(f"{directory}/", dest)

+
+setup_plugin_commands(cli)
+
 def main():
     cli()

View File

@@ -0,0 +1,36 @@
"""Module for adding click CLI commands from axolotl plugins."""
import logging
import click
from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass
from axolotl.logging_config import configure_logging
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
configure_logging()
LOG = logging.getLogger(__name__)
def setup_plugin_commands(cli: click.core.Group) -> None:
"""
Setup CLI commands for available plugins.
Args:
cli: Click CLI object to add plugin CLI options to.
"""
try:
from axolotl_diff_transformer.convert_diff_transformer import do_cli
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
do_cli(config=config, **kwargs)
except ImportError as exc:
LOG.debug("axolotl-diff-transformer not found: %s", exc)

View File

@@ -157,6 +157,8 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
         if isinstance(value, bool):
             if value:
                 cmd.append(f"--{key}")
+            else:
+                cmd.append(f"--no{key}")
         else:
             cmd.extend([f"--{key}", str(value)])

View File

@@ -243,6 +243,10 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "Scale the learning rate for the embedding layers."},
     )
+    lr_groups: Optional[list[dict]] = field(
+        default=None,
+        metadata={"help": "Specify learning rate groups with different LRs."},
+    )
     embedding_lr: Optional[float] = field(
         default=None,
         metadata={"help": "absolute learning rate for the embedding layers."},
@@ -293,7 +297,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
""" """
Training arguments for Causal trainer Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
so it can't be used as a mixin. so it can't be used as a mixin.
""" """
@@ -461,11 +465,95 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             )
         return super()._wrap_model(model, training=training, dataloader=dataloader)

+    def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
+        decay_parameters = self.get_decay_parameter_names(opt_model)
+        params = {
+            "to_weight_decay": {},  # LayerNorm and bias
+            "embeddings": {},  # lm_head, embed_tokens,
+            "no_weight_decay": {},
+        }
+        lr_groups_lookup = {}
+        lr_groups_learning_rates = {}
+        if self.args.lr_groups:
+            for lr_group in self.args.lr_groups:
+                group_name = lr_group["name"]
+                group_modules = lr_group["modules"]
+                for module in group_modules:
+                    lr_groups_lookup[module] = group_name
+                lr_groups_learning_rates[group_name] = lr_group["lr"]
+                params[f"to_weight_decay_{group_name}"] = {}
+
+        for name, param in opt_model.named_parameters():
+            if not param.requires_grad:
+                continue
+            if name.endswith("modules_to_save.default.weight") or any(
+                embed_name in name for embed_name in ["embed_tokens", "lm_head"]
+            ):
+                params["embeddings"][name] = param
+            elif name in decay_parameters:
+                lr_group_modules = [
+                    group_modules
+                    for group_modules in lr_groups_lookup
+                    if group_modules in name
+                ]
+                if lr_groups_lookup and any(lr_group_modules):
+                    lr_group_module = lr_group_modules[0]
+                    group_name = lr_groups_lookup[lr_group_module]
+                    params[f"to_weight_decay_{group_name}"][name] = param
+                else:
+                    params["to_weight_decay"][name] = param
+            else:
+                params["no_weight_decay"][name] = param
+
+        optimizer_grouped_parameters = []
+        if params["to_weight_decay"]:
+            optimizer_grouped_parameters.append(
+                {
+                    "params": list(params["to_weight_decay"].values()),
+                    "weight_decay": self.args.weight_decay,
+                    "lr": optimizer_kwargs["lr"],
+                }
+            )
+        if params["embeddings"]:
+            lr = optimizer_kwargs["lr"]  # pylint: disable=invalid-name
+            if self.args.embedding_lr_scale:
+                lr *= self.args.embedding_lr_scale  # pylint: disable=invalid-name
+            elif self.args.embedding_lr:
+                lr = self.args.embedding_lr  # pylint: disable=invalid-name
+            optimizer_grouped_parameters.append(
+                {
+                    "params": list(params["embeddings"].values()),
+                    "weight_decay": 0.0,
+                    "lr": lr,
+                }
+            )
+        if params["no_weight_decay"]:
+            optimizer_grouped_parameters.append(
+                {
+                    "params": list(params["no_weight_decay"].values()),
+                    "weight_decay": 0.0,
+                    "lr": optimizer_kwargs["lr"],
+                }
+            )
+        for group_name, group_lr in lr_groups_learning_rates.items():
+            if params[f"to_weight_decay_{group_name}"]:
+                optimizer_grouped_parameters.append(
+                    {
+                        "params": list(
+                            params[f"to_weight_decay_{group_name}"].values()
+                        ),
+                        "weight_decay": self.args.weight_decay,
+                        "lr": group_lr,
+                    }
+                )
+
+        return optimizer_grouped_parameters
+
     def create_optimizer(self):
         if (
             self.args.loraplus_lr_ratio is None
             and self.args.embedding_lr_scale is None
             and self.args.embedding_lr is None
+            and self.args.lr_groups is None
             and self.args.alternate_optimizer
             not in [
                 "optimi_adamw",
@@ -479,59 +567,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if self.optimizer is None:  # pylint: disable=access-member-before-definition
-            decay_parameters = self.get_decay_parameter_names(opt_model)
-            params = {
-                "to_weight_decay": {},  # LayerNorm and bias
-                "embeddings": {},  # lm_head, embed_tokens,
-                "no_weight_decay": {},
-            }
             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                 self.args,
                 opt_model,
             )
-
-            for name, param in opt_model.named_parameters():
-                if not param.requires_grad:
-                    continue
-                if name.endswith("modules_to_save.default.weight") or any(
-                    embed_name in name for embed_name in ["embed_tokens", "lm_head"]
-                ):
-                    params["embeddings"][name] = param
-                elif name in decay_parameters:
-                    params["to_weight_decay"][name] = param
-                else:
-                    params["no_weight_decay"][name] = param
-
-            optimizer_grouped_parameters = []
-            if params["to_weight_decay"]:
-                optimizer_grouped_parameters.append(
-                    {
-                        "params": list(params["to_weight_decay"].values()),
-                        "weight_decay": self.args.weight_decay,
-                        "lr": optimizer_kwargs["lr"],
-                    }
-                )
-            if params["embeddings"]:
-                lr = optimizer_kwargs["lr"]  # pylint: disable=invalid-name
-                if self.args.embedding_lr_scale:
-                    lr *= self.args.embedding_lr_scale  # pylint: disable=invalid-name
-                elif self.args.embedding_lr:
-                    lr = self.args.embedding_lr  # pylint: disable=invalid-name
-                optimizer_grouped_parameters.append(
-                    {
-                        "params": list(params["embeddings"].values()),
-                        "weight_decay": 0.0,
-                        "lr": lr,
-                    }
-                )
-            if params["no_weight_decay"]:
-                optimizer_grouped_parameters.append(
-                    {
-                        "params": list(params["no_weight_decay"].values()),
-                        "weight_decay": 0.0,
-                        "lr": optimizer_kwargs["lr"],
-                    }
-                )
+            optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
+                opt_model, optimizer_kwargs
+            )

             if self.args.loraplus_lr_ratio is not None:
                 loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -548,6 +590,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             elif (
                 self.args.embedding_lr_scale is not None
                 or self.args.embedding_lr is not None
+                or self.args.lr_groups is not None
             ):
                 self.optimizer = (  # pylint: disable=attribute-defined-outside-init
                     optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
@@ -1079,6 +1122,7 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
         super().__init__(*args, **kwargs)
         self.dataset_tags = dataset_tags
         self.optimizer = None
+        self.model_accepts_loss_kwargs = False

     def create_optimizer(self):
         if self.args.loraplus_lr_ratio is None:
@@ -1664,6 +1708,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ] = self.cfg.loraplus_lr_embedding
         training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
         training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
+        training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups

         if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
             training_arguments_kwargs["lr_scheduler_type"] = "cosine"
@@ -1879,6 +1924,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if training_args.pretraining:
             if self.cfg.pretraining_sample_concatenation is False:
                 return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
+            if self.cfg.micro_batch_size > 1:
+                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
             return None

         if self.cfg.model_config_type == "mamba":
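The last hunk means pretraining with `micro_batch_size > 1` now gets a padding collator instead of `None`. A small hedged sketch of why (runnable with `transformers` installed; the model name is borrowed from the tests later in this diff):

```python
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
tok.pad_token = tok.pad_token or tok.eos_token

# two samples of unequal length cannot be stacked into one tensor without padding
collator = DataCollatorForSeq2Seq(tok, return_tensors="pt", padding=True)
batch = collator(
    [
        {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
        {"input_ids": [4, 5], "labels": [4, 5]},
    ]
)
print(batch["input_ids"].shape)  # torch.Size([2, 3]); labels padded with -100
```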

View File

@@ -4,7 +4,7 @@ import csv
 import os
 import sys
 from pathlib import Path
-from typing import Dict, Optional
+from typing import Optional

 import torch
 from accelerate.logging import get_logger
@@ -26,7 +26,7 @@ LOG = get_logger("axolotl.evaluate")

 def evaluate_dataset(
     trainer, dataset, dataset_type: str, flash_optimum: bool = False
-) -> Optional[Dict[str, float]]:
+) -> Optional[dict[str, float]]:
     """Helper function to evaluate a single dataset safely.

     Args:
@@ -61,7 +61,7 @@ def evaluate_dataset(
     return metrics

-def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
+def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> dict[str, float]:
     """
     Evaluate a model on training and validation datasets

View File

@@ -43,10 +43,12 @@ def merge_input_args():
     input_args: List[str] = plugin_manager.get_input_args()
     plugin_classes = []
     dynamic_input = ""
+
     for plugin_args in input_args:
         plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
         dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
         plugin_classes.append(plugin_cls)
+
     if dynamic_input:
         dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n    pass\n"
         dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n    pass\n"
@@ -62,4 +64,5 @@ def merge_input_args():
             "AxolotlConfigWCapabilities"
         ]
         return AxolotlConfigWCapabilities, AxolotlInputConfig
+
     return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase

View File

@@ -1,308 +0,0 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128
"""
import inspect
import logging
from transformers import LlamaForCausalLM, Trainer
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
"""
PATCHED_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
else:
loss = self.compute_loss(model, inputs)
"""
ORIGINAL_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
"""
PATCHED_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
"""
def get_training_step_code() -> str:
training_step = inspect.getsource(
Trainer.training_step # pylint: disable=protected-access
)
return training_step
def check_training_step_is_patchable() -> bool:
training_step = get_training_step_code()
training_step, _ = detab_code(training_step)
return ORIGINAL_CONTEXT_CODE in training_step
def patch_training_step_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
training_step = get_training_step_code()
except OSError:
return
Trainer._original_training_step = training_step # pylint: disable=protected-access
training_step, _ = detab_code(training_step)
if ORIGINAL_CONTEXT_CODE not in training_step:
return
# assert (
# ORIGINAL_CONTEXT_CODE in training_step
# ), "Original training_step code not found"
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
training_step = training_step.replace(
"def training_step(",
"def _fixed_training_step(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_step:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step")
Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
)
def get_model_forward_code() -> str:
forward = inspect.getsource(
LlamaForCausalLM.forward # pylint: disable=protected-access
)
return forward
def check_forward_is_patchable() -> bool:
forward = get_model_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_LLAMA_FCLM_CODE in forward
def patch_forward_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
forward = get_model_forward_code()
except OSError:
return
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
return
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
forward = forward.replace(
"def forward(",
"def _fixed_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward")
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)
ORIGINAL_TRAINER_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
PATCHED_TRAINER_CODE = """
disable_deepspeed_no_sync = (
self.accelerator.distributed_type == DistributedType.DEEPSPEED
# and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
)
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_deepspeed_0_16_x():
"""
monkeypatch for fixing the training loop for deepspeed GA
see https://github.com/huggingface/transformers/pull/35157
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)
def patch_flash_attention_forward():
"""
monkeypatch for fixing the forward pass for flash attention to ignore num_items_in_batch
"""
import transformers.modeling_flash_attention_utils
def proxy_flash_attention_forward(*args, **kwargs):
kwargs.pop("num_items_in_batch", None)
return _flash_attention_forward(*args, **kwargs)
transformers.modeling_flash_attention_utils._flash_attention_forward = ( # pylint: disable=protected-access
proxy_flash_attention_forward
)
transformers.models.llama.modeling_llama._flash_attention_forward = ( # pylint: disable=protected-access
proxy_flash_attention_forward
)

View File

@@ -0,0 +1,67 @@
"""
see https://github.com/huggingface/transformers/pull/35834
"""
import logging
from functools import partial
from typing import Optional
import torch
logger = logging.getLogger(__name__)
def fixed_fa_peft_integration_check(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
target_dtype: Optional[torch.dtype] = None,
preferred_dtype: Optional[torch.dtype] = None,
):
"""
PEFT usually casts the layer norms in float32 for training stability reasons
therefore the input hidden states gets silently casted in float32. Hence, we need
cast them back in float16 / bfloat16 just to be sure everything works as expected.
This might slowdown training & inference so it is recommended to not cast the LayerNorms!
Args:
query (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value (`torch.Tensor`):
Input value states to be passed to Flash Attention API
target_dtype (`torch.dtype`, *optional*):
The dtype to convert the attention tensors to. Conversion can be ignored by
not providing the target dtype.
preferred_dtype (`torch.dtype`, *optional*):
The preferred dtype to convert the attention tensors to regardless of the
target dtype.
"""
if target_dtype is None and preferred_dtype is None:
return query, key, value
if preferred_dtype and target_dtype != preferred_dtype:
target_dtype = preferred_dtype
# check if any of query, key, or value are in float32. If so, cast them back to target dtype.
if any(module.dtype == torch.float32 for module in [query, key, value]):
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query = query.to(target_dtype)
key = key.to(target_dtype)
value = value.to(target_dtype)
return query, key, value
def patch_fa_peft_integration():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils.fa_peft_integration_check = partial(
fixed_fa_peft_integration_check, preferred_dtype=None
)
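To see the guard in action without patching anything, here is a standalone toy version of the same rule (`cast_back` is a hypothetical name; the real entry point is `fixed_fa_peft_integration_check` above):

```python
import torch

def cast_back(query, key, value, target_dtype):
    # mirror of the check above: if PEFT upcast anything to float32, cast it back
    if any(t.dtype == torch.float32 for t in (query, key, value)):
        query, key, value = (t.to(target_dtype) for t in (query, key, value))
    return query, key, value

q = torch.randn(2, 4, dtype=torch.float32)    # silently upcast by a float32 LayerNorm
k = torch.randn(2, 4, dtype=torch.bfloat16)
v = torch.randn(2, 4, dtype=torch.bfloat16)
q, k, v = cast_back(q, k, v, torch.bfloat16)
print(q.dtype, k.dtype, v.dtype)  # torch.bfloat16 for all three
```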

View File

@@ -147,6 +147,14 @@ class UserDefinedPrompterType(BaseModel):
     field: Optional[str] = None

+class LrGroup(BaseModel):
+    """Custom learning rate group configuration"""
+
+    name: str
+    modules: List[str]
+    lr: float
+
 class SFTDataset(BaseModel):
     """SFT configuration subset"""
@@ -475,6 +483,7 @@ class HyperparametersConfig(BaseModel):
     cosine_min_lr_ratio: Optional[float] = None
     cosine_constant_lr_ratio: Optional[float] = None
     lr_div_factor: Optional[float] = None
+    lr_groups: Optional[List[LrGroup]] = None
     adam_epsilon: Optional[float] = None
     adam_beta1: Optional[float] = None
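A quick validation sketch of the new schema, mirroring the `LrGroup` model above so it runs standalone (pydantic v2 assumed):

```python
from typing import List

from pydantic import BaseModel, ValidationError

class LrGroup(BaseModel):
    """Custom learning rate group configuration"""

    name: str
    modules: List[str]
    lr: float

print(LrGroup(name="o_proj", modules=["self_attn.o_proj.weight"], lr=1e-6))

try:
    LrGroup(name="bad", modules="self_attn.o_proj.weight", lr="not-a-float")
except ValidationError as err:
    print(err.error_count(), "validation errors")  # modules and lr both rejected
```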

View File

@@ -191,7 +191,7 @@ def wrap_pretraining_dataset(
         tokenizer,
         return_tensors="pt",
         padding=True,
-        pad_to_multiple_of=max_tokens * batch_size,
+        pad_to_multiple_of=max_tokens,
         multipack_attn=cfg.pretrain_multipack_attn,
     )
     encode = functools.partial(
@@ -201,8 +201,6 @@ def wrap_pretraining_dataset(
         max_seq_length=max_tokens,
         batch_size=batch_size,
         multipack_attn=cfg.pretrain_multipack_attn,
-        group_size=cfg.sample_packing_group_size,
-        bin_size=cfg.sample_packing_bin_size,
     )
     # set this to 1 so downstream data_loader doesn't try to increase the batch again
     cfg.micro_batch_size = 1
@@ -247,9 +245,7 @@ def encode_packed_pretraining(
     examples: Dict[str, List],
     max_seq_length: int = 2048,
     batch_size: int = 4,
-    multipack_attn: Optional[bool] = False,
-    group_size: int = 100000,
-    bin_size: int = 200,
+    multipack_attn: Optional[bool] = True,
 ) -> Dict[str, List]:
     # pylint: disable=duplicate-code
     # tokenize all the examples
@@ -260,6 +256,9 @@ def encode_packed_pretraining(
         train_dataset,
         max_seq_length,
         skip_position_ids=not multipack_attn,
+        # FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
+        # workaround by using the position id logic for now in trainer
+        drop_attention_mask=multipack_attn,
     )

     sampler = MultipackBatchSampler(
@@ -267,8 +266,6 @@ def encode_packed_pretraining(
         lengths=get_dataset_lengths(train_dataset),
         batch_size=1,
         batch_max_len=batch_size * max_seq_length,
-        group_size=group_size,
-        bin_size=bin_size,
         drop_last=True,
     )

View File

@@ -107,6 +107,13 @@ def load_dataset_w_config(config_dataset, auth_token):
     except (FileNotFoundError, ConnectionError):
         pass

+    # gather extra args from the config
+    load_ds_kwargs = {}
+    if config_dataset.split:
+        load_ds_kwargs["split"] = config_dataset.split
+    else:
+        load_ds_kwargs["split"] = None
+
     # prefer local dataset, even if hub exists
     local_path = Path(config_dataset.path)
     if local_path.exists():
@@ -118,7 +125,7 @@
                 name=config_dataset.name,
                 data_files=config_dataset.data_files,
                 streaming=False,
-                split=None,
+                **load_ds_kwargs,
             )
         else:
             try:
@@ -130,7 +137,7 @@
                     config_dataset.path,
                     name=config_dataset.name,
                     streaming=False,
-                    split=None,
+                    **load_ds_kwargs,
                 )
     elif local_path.is_file():
         ds_type = get_ds_type(config_dataset)
@@ -140,16 +147,13 @@
             name=config_dataset.name,
             data_files=config_dataset.path,
             streaming=False,
-            split=None,
+            **load_ds_kwargs,
         )
     else:
         raise ValueError(
             "unhandled dataset load: local path exists, but is neither a directory or a file"
         )
 elif ds_from_hub:
-    load_ds_kwargs = {}
-    if config_dataset.split:
-        load_ds_kwargs["split"] = config_dataset.split
     ds = load_dataset(
         config_dataset.path,
         name=config_dataset.name,
@@ -173,9 +177,9 @@
         name=config_dataset.name,
         data_files=config_dataset.path,
         streaming=False,
-        split=None,
         storage_options=storage_options,
         trust_remote_code=config_dataset.trust_remote_code,
+        **load_ds_kwargs,
     )
 elif config_dataset.path.startswith("https://"):
     ds_type = get_ds_type(config_dataset)
@@ -184,9 +188,9 @@
         name=config_dataset.name,
         data_files=config_dataset.path,
         streaming=False,
-        split=None,
         storage_options=storage_options,
         trust_remote_code=config_dataset.trust_remote_code,
+        **load_ds_kwargs,
     )
 else:
     if isinstance(config_dataset.data_files, str):
@@ -214,7 +218,7 @@
         name=config_dataset.name,
         data_files=fp,
         streaming=False,
-        split=None,
+        **load_ds_kwargs,
     )
 if not ds:
     raise ValueError("unhandled dataset load")

View File

@@ -380,23 +380,19 @@ class ModelLoader:
         plugin_manager = PluginManager.get_instance()
         plugin_manager.pre_model_load(self.cfg)

+        if self.cfg.adapter:
+            from axolotl.monkeypatch.transformers_fa_utils import (
+                patch_fa_peft_integration,
+            )
+
+            patch_fa_peft_integration()
+
         if self.cfg.gradient_checkpointing == "unsloth":
             transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper

         if self.cfg.flash_attention:
             self.patch_attention()

-        if self.cfg.model_config_type == "llama":
-            from axolotl.monkeypatch.trainer_grad_accum import (
-                patch_flash_attention_forward,
-                patch_forward_for_ga,
-                patch_training_step_for_ga,
-            )
-
-            patch_flash_attention_forward()
-            patch_forward_for_ga()
-            patch_training_step_for_ga()
-
         if self.cfg.sample_packing and self.cfg.s2_attention:
             raise ValueError(
                 "Received `sample_packing=true` and `s2_attention=true`; however, \
@@ -816,6 +812,7 @@
         if self.cfg.is_multimodal:
             self.model_config.text_config = self.text_model_config

         self.model = self.AutoModelLoader.from_pretrained(
             self.base_model,
             config=self.model_config,

View File

@@ -310,19 +310,22 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
 def process_pretraining_datasets_for_packing(
-    train_dataset, sequence_len, skip_position_ids=True
+    train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
 ):
     drop_long = partial(drop_long_seq, sequence_len=sequence_len)

     train_dataset = train_dataset.filter(
         drop_long,
         desc="Dropping Long Sequences",
+        load_from_cache_file=False,
     )
-    if skip_position_ids:
+    if not skip_position_ids:
         train_dataset = train_dataset.map(
             add_position_ids,
             desc="Add position_id column (Pretraining Sample Packing)",
         )
+
+    if drop_attention_mask:
+        train_dataset = train_dataset.remove_columns("attention_mask")
+
     return train_dataset
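Two behavior notes on this hunk: the position-id mapping now runs when `skip_position_ids` is false (the previous guard was inverted), and the new `drop_attention_mask` switch removes the column outright, which in the `datasets` API is simply:

```python
from datasets import Dataset

ds = Dataset.from_dict({"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]})
ds = ds.remove_columns("attention_mask")
print(ds.column_names)  # ['input_ids']
```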

157
src/axolotl/utils/yaml.py Normal file
View File

@@ -0,0 +1,157 @@
"""Utilities for YAML files."""
from collections import OrderedDict
from typing import Any, Dict, List, Set, Tuple, Union
import yaml
class YAMLOrderTracker:
"""Tracks the order of keys and section breaks in YAML files."""
def __init__(self, yaml_path: str):
self.yaml_path = yaml_path
self.structure, self.needs_break = self._parse_yaml_structure()
def _get_indentation_level(self, line: str) -> int:
"""Get the indentation level of a line."""
return len(line) - len(line.lstrip())
def _parse_yaml_structure(
self,
) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
"""Parse the YAML file to extract structure and identify section breaks."""
with open(self.yaml_path, "r", encoding="utf-8") as file:
contents = file.readlines()
structure: OrderedDict = OrderedDict()
needs_break = set() # Track which keys should have a break before them
current_path = []
last_indentation = -1
had_empty_line = False
for line in contents:
# Track empty lines and comments
if not line.strip() or line.strip().startswith("#"):
had_empty_line = True
continue
# Get indentation level and content
indentation = self._get_indentation_level(line)
content = line.strip()
# Skip lines that don't define keys
if ":" not in content:
continue
# Extract key
key = content.split(":")[0].strip()
# If this is a top-level key and we had an empty line, mark it
if indentation == 0:
if had_empty_line:
needs_break.add(key)
had_empty_line = False
# Handle indentation changes
if indentation > last_indentation:
current_path.append(key)
elif indentation < last_indentation:
levels_up = (last_indentation - indentation) // 2
current_path = current_path[:-levels_up]
current_path[-1] = key
else:
if current_path:
current_path[-1] = key
# Update structure
current_dict = structure
for path_key in current_path[:-1]:
if path_key not in current_dict:
current_dict[path_key] = OrderedDict()
current_dict = current_dict[path_key]
if current_path:
if current_path[-1] not in current_dict:
current_dict[current_path[-1]] = OrderedDict()
last_indentation = indentation
return structure, needs_break
class OrderedDumper(yaml.SafeDumper):
"""Custom YAML dumper that maintains dictionary order."""
def represent_none(self, _):
"""Represent None values as empty fields."""
return self.represent_scalar("tag:yaml.org,2002:null", "")
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
"""Custom representer for dictionaries that maintains order."""
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
"""Reorder a dictionary based on a reference structure."""
ordered = OrderedDict()
# First add keys that are in the reference order
for key in reference_structure:
if key in data:
if isinstance(reference_structure[key], dict) and isinstance(
data[key], dict
):
ordered[key] = reorder_dict(data[key], reference_structure[key])
else:
ordered[key] = data[key]
# Then add any remaining keys that weren't in the reference
for key in data:
if key not in ordered:
ordered[key] = data[key]
return ordered
def dump_yaml_preserved_order(
data: Dict, reference_yaml_path: str, output_path: str
) -> None:
"""Dump YAML file while preserving nested order and normalized spacing."""
# Get reference structure and spacing
tracker = YAMLOrderTracker(reference_yaml_path)
# Reorder the data
ordered_data = reorder_dict(data, tracker.structure)
# Register the custom representers
OrderedDumper.add_representer(type(None), represent_none)
OrderedDumper.add_representer(dict, ordered_dict_representer)
OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)
# First dump to string
yaml_str = yaml.dump(
ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
)
# Add spacing according to reference
lines = yaml_str.split("\n")
result_lines: List[str] = []
current_line = 0
while current_line < len(lines):
line = lines[current_line]
if line.strip() and ":" in line and not line.startswith(" "): # Top-level key
key = line.split(":")[0].strip()
if key in tracker.needs_break:
# Add single empty line before this key
if result_lines and result_lines[-1] != "":
result_lines.append("")
result_lines.append(line)
current_line += 1
# Write the final result
with open(output_path, "w", encoding="utf-8") as file:
file.write("\n".join(result_lines))
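A hedged usage sketch for the new helper, assuming an axolotl install (the file names are placeholders; the reference file supplies key order and section breaks):

```python
from axolotl.utils.yaml import dump_yaml_preserved_order

# write a reference config whose ordering and spacing we want to preserve
with open("reference.yml", "w", encoding="utf-8") as ref:
    ref.write("base_model: HuggingFaceTB/SmolLM2-135M\n\nlearning_rate: 2e-5\n")

# dump new data in the reference order, keeping the blank line before learning_rate
dump_yaml_preserved_order(
    data={"learning_rate": 1e-5, "base_model": "HuggingFaceTB/SmolLM2-135M"},
    reference_yaml_path="reference.yml",
    output_path="config_out.yml",
)
print(open("config_out.yml", encoding="utf-8").read())
```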

View File

@@ -43,14 +43,12 @@ class BaseCliTest:
         result = cli_runner.invoke(cli, [command, str(config_path)])

         assert mock.called
-        assert mock.call_args.args[0] == [
+        assert mock.call_args.args[0][:5] == [
             "accelerate",
             "launch",
             "-m",
             f"axolotl.cli.{command}",
             str(config_path),
-            "--debug-num-examples",
-            "0",
         ]
         assert mock.call_args.kwargs == {"check": True}
         assert result.exit_code == 0

View File

@@ -23,6 +23,7 @@ def test_build_command():
"--batch-size", "--batch-size",
"8", "8",
"--debug", "--debug",
"--nouse-fp16",
] ]

View File

@@ -63,6 +63,7 @@ class TestMultiGPULlama:
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"flash_attention": True, "flash_attention": True,
"use_tensorboard": True, "use_tensorboard": True,
"bf16": True,
} }
) )
@@ -127,6 +128,7 @@ class TestMultiGPULlama:
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"flash_attention": True, "flash_attention": True,
"use_tensorboard": True, "use_tensorboard": True,
"bf16": True,
} }
) )
@@ -201,6 +203,7 @@ class TestMultiGPULlama:
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"flash_attention": True, "flash_attention": True,
"use_tensorboard": True, "use_tensorboard": True,
"bf16": True,
} }
) )
@@ -223,8 +226,12 @@ class TestMultiGPULlama:
] ]
) )
loss_threshold = 2.3
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs",
"train/train_loss",
loss_threshold,
"Train Loss is too high",
) )
def test_dpo_qlora_ddp(self, temp_dir): def test_dpo_qlora_ddp(self, temp_dir):
@@ -275,6 +282,7 @@ class TestMultiGPULlama:
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"flash_attention": True, "flash_attention": True,
"use_tensorboard": True, "use_tensorboard": True,
"bf16": True,
} }
) )
@@ -297,8 +305,12 @@ class TestMultiGPULlama:
] ]
) )
loss_threshold = 2.3
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs",
"train/train_loss",
loss_threshold,
"Train Loss is too high",
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@@ -102,9 +102,5 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
-        assert (
-            "MixtralFlashAttention2"
-            in model.model.layers[0].self_attn.__class__.__name__
-        )
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

View File

@@ -49,12 +49,7 @@ class TestModelPatches(unittest.TestCase):
         )
         normalize_config(cfg)
         tokenizer = load_tokenizer(cfg)
-        model, _ = load_model(cfg, tokenizer, inference=False)
-
-        assert (
-            "MixtralFlashAttention2"
-            in model.model.layers[0].self_attn.__class__.__name__
-        )
+        load_model(cfg, tokenizer, inference=False)

     @with_temp_dir
     def test_mistral_multipack(self, temp_dir):

View File

@@ -3,8 +3,6 @@ import unittest
 import pytest

-from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable

 @pytest.mark.skip(
     reason="Unsloth integration will be broken going into latest transformers"
@@ -13,6 +11,8 @@
     """Unsloth monkeypatch integration tests."""

     def test_is_self_attn_patchable(self):
+        from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
+
         # ensures the current version of transformers has loss code that matches our patching code
         self.assertTrue(
             check_self_attn_is_patchable(),

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
+from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_model_output_exists
+from .utils import check_model_output_exists, check_tensorboard

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -28,19 +28,25 @@ class TestPretrainLlama:
         "sample_packing",
         [True, False],
     )
-    def test_pretrain(self, temp_dir, sample_packing):
+    @pytest.mark.parametrize(
+        "pretrain_multipack_attn",
+        [True, False],
+    )
+    def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn):
+        if not sample_packing and pretrain_multipack_attn:
+            return
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
                 "flash_attention": True,
                 "sequence_len": 1024,
                 "sample_packing": sample_packing,
+                "pretrain_multipack_attn": pretrain_multipack_attn,
+                "dataset_processes": 1,
                 "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
+                    "pad_token": "<|endoftext|>",
                 },
                 "pretraining_dataset": [
                     {
@@ -51,7 +57,7 @@ class TestPretrainLlama:
                 ],
                 "max_steps": 5,
                 "num_epochs": 1,
-                "micro_batch_size": 1,
+                "micro_batch_size": 2,
                 "gradient_accumulation_steps": 1,
                 "val_set_size": 0.0,
                 "output_dir": temp_dir,
@@ -60,6 +66,7 @@ class TestPretrainLlama:
                 "lr_scheduler": "cosine",
                 "save_safetensors": True,
                 "bf16": "auto",
+                "use_tensorboard": True,
             }
         )
         normalize_config(cfg)
@@ -68,3 +75,12 @@ class TestPretrainLlama:
         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
+
+        loss_threshold = 3.5
+        if sample_packing and not pretrain_multipack_attn:
+            loss_threshold = 6.5
+        check_tensorboard(
+            temp_dir + "/runs",
+            "train/train_loss",
+            loss_threshold,
+            "Train Loss is too high",
+        )

View File

@@ -1,25 +0,0 @@
""""Test module for checking whether the Hugging Face Transformers is working as expected."""
import unittest
from axolotl.monkeypatch.trainer_grad_accum import (
check_forward_is_patchable,
check_training_step_is_patchable,
)
class TestTrainerGAIntegration(unittest.TestCase):
"""llama monkeypatch integration tests."""
def test_train_step_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_training_step_is_patchable(),
"HF transformers Trainer.training_step has changed and isn't patchable",
)
def test_model_forward_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_forward_is_patchable(),
"HF transformers LlamaForCausalLM.forward has changed and isn't patchable",
)

View File

@@ -41,6 +41,7 @@ class TestPretrainingPacking(unittest.TestCase):
                     }
                 ],
                 "sample_packing": True,
+                "pretrain_multipack_attn": True,
                 "pad_to_sequence_len": True,
                 "sequence_len": 2048,
                 "micro_batch_size": 2,
@@ -87,9 +88,11 @@
                 assert data["labels"].shape == torch.Size(
                     [1, original_bsz * cfg.sequence_len]
                 )
-                assert data["attention_mask"].shape == torch.Size(
-                    [1, original_bsz * cfg.sequence_len]
-                )
+                assert "attention_mask" not in data
+                # FIXME add back once we fix packing unpad/pad with attention mask
+                # assert data["attention_mask"].shape == torch.Size(
+                #     [1, original_bsz * cfg.sequence_len]
+                # )
                 idx += 1