From 03027cf6bf874ee4a41695681c902de1657e1f2f Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 14 Mar 2025 00:18:39 +0000 Subject: [PATCH] pernicious Fire CLI bugfix --- src/axolotl/cli/train.py | 11 ++++------- src/axolotl/utils/schemas/config.py | 5 ++++- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index e991105e6..6cc7c7701 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -23,7 +23,7 @@ from axolotl.utils.dict import DictDefault LOG = logging.getLogger(__name__) -def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: +def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): """ Trains a `transformers` model by first loading the dataset(s) specified in the `axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin @@ -44,16 +44,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta) + del model, tokenizer, trainer + plugin_manager = PluginManager.get_instance() - - del model - del tokenizer - del trainer - plugin_manager.post_train_unload(cfg) -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): """ Parses `axolotl` config, CLI args, and calls `do_train`. diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 3a8e73a4d..eb9792295 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -245,7 +245,7 @@ class AxolotlInputConfig( val_set_size: float | None = Field(default=0.0) - sequence_parallel_degree: int | None = 1 + sequence_parallel_degree: int | None = None special_tokens: SpecialTokensConfig | None = None tokens: list[str] | None = None @@ -1107,6 +1107,9 @@ class AxolotlInputConfig( @model_validator(mode="before") @classmethod def check_sequence_parallel_config(cls, data): + if data.get("sequence_parallel_degree") is None: + data["sequence_parallel_degree"] = 1 + if data.get("sequence_parallel_degree") > 1: if not data.get("flash_attention"): raise ValueError(