pernicious Fire CLI bugfix
This commit is contained in:
@@ -23,7 +23,7 @@ from axolotl.utils.dict import DictDefault
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
"""
|
||||
Trains a `transformers` model by first loading the dataset(s) specified in the
|
||||
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
||||
@@ -44,16 +44,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
del model, tokenizer, trainer
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
|
||||
del model
|
||||
del tokenizer
|
||||
del trainer
|
||||
|
||||
plugin_manager.post_train_unload(cfg)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
"""
|
||||
Parses `axolotl` config, CLI args, and calls `do_train`.
|
||||
|
||||
|
||||
@@ -245,7 +245,7 @@ class AxolotlInputConfig(
|
||||
|
||||
val_set_size: float | None = Field(default=0.0)
|
||||
|
||||
sequence_parallel_degree: int | None = 1
|
||||
sequence_parallel_degree: int | None = None
|
||||
|
||||
special_tokens: SpecialTokensConfig | None = None
|
||||
tokens: list[str] | None = None
|
||||
@@ -1107,6 +1107,9 @@ class AxolotlInputConfig(
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_sequence_parallel_config(cls, data):
|
||||
if data.get("sequence_parallel_degree") is None:
|
||||
data["sequence_parallel_degree"] = 1
|
||||
|
||||
if data.get("sequence_parallel_degree") > 1:
|
||||
if not data.get("flash_attention"):
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user