pernicious Fire CLI bugfix

Dan Saunders
2025-03-14 00:18:39 +00:00
parent 0ade60d455
commit 03027cf6bf
2 changed files with 8 additions and 8 deletions

@@ -23,7 +23,7 @@ from axolotl.utils.dict import DictDefault
 LOG = logging.getLogger(__name__)
-def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
+def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
     """
     Trains a `transformers` model by first loading the dataset(s) specified in the
     `axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
@@ -44,16 +44,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
     dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
     model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
+    del model, tokenizer, trainer
     plugin_manager = PluginManager.get_instance()
-    del model
-    del tokenizer
-    del trainer
     plugin_manager.post_train_unload(cfg)
-def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
+def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
     """
     Parses `axolotl` config, CLI args, and calls `do_train`.

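For context, here is a minimal sketch (not taken from axolotl's sources) of how a Python Fire entry point typically dispatches a command like `do_cli`. Fire parses CLI arguments into the function's parameters and prints any non-None return value, so command functions are usually kept side-effect-only; the body below is a hypothetical stand-in.

    # fire_cli_sketch.py -- hypothetical wiring; names mirror the diff above,
    # but this is not axolotl's actual entry point.
    from pathlib import Path
    from typing import Union

    import fire


    def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
        # Stand-in for the real do_cli: it would load the axolotl config,
        # merge CLI overrides from **kwargs, and call do_train(cfg, cli_args).
        print(f"would train with config={config}, overrides={kwargs}")
        # Implicitly returns None, so Fire prints nothing after the run.


    if __name__ == "__main__":
        # Example: python fire_cli_sketch.py --config=my_config.yaml --micro_batch_size=2
        fire.Fire(do_cli)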

@@ -245,7 +245,7 @@ class AxolotlInputConfig(
     val_set_size: float | None = Field(default=0.0)
-    sequence_parallel_degree: int | None = 1
+    sequence_parallel_degree: int | None = None
     special_tokens: SpecialTokensConfig | None = None
     tokens: list[str] | None = None
@@ -1107,6 +1107,9 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_sequence_parallel_config(cls, data):
+        if data.get("sequence_parallel_degree") is None:
+            data["sequence_parallel_degree"] = 1
         if data.get("sequence_parallel_degree") > 1:
             if not data.get("flash_attention"):
                 raise ValueError(
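The hunk above adds a default-filling step to the existing `check_sequence_parallel_config` validator. As a standalone illustration of that pattern (assuming pydantic v2), the sketch below mirrors the field and validator names from the diff; the model class and the error message are placeholders, not axolotl's actual `AxolotlInputConfig`.

    # validator_sketch.py -- illustrative only; assumes pydantic v2.
    from pydantic import BaseModel, model_validator


    class SPConfig(BaseModel):
        sequence_parallel_degree: int | None = None
        flash_attention: bool | None = None

        @model_validator(mode="before")
        @classmethod
        def check_sequence_parallel_config(cls, data):
            # Declaring the field as None and filling in 1 here keeps "user
            # explicitly set a degree" distinguishable from "not set at all".
            if data.get("sequence_parallel_degree") is None:
                data["sequence_parallel_degree"] = 1
            if data.get("sequence_parallel_degree") > 1 and not data.get("flash_attention"):
                # Placeholder message, not the one in the truncated diff above.
                raise ValueError("sequence parallelism (degree > 1) requires flash_attention")
            return data


    # SPConfig(sequence_parallel_degree=2, flash_attention=True) validates,
    # SPConfig(sequence_parallel_degree=2) fails validation, and
    # SPConfig() falls back to a degree of 1.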