pernicious Fire CLI bugfix
This commit is contained in:
@@ -23,7 +23,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||||
"""
|
"""
|
||||||
Trains a `transformers` model by first loading the dataset(s) specified in the
|
Trains a `transformers` model by first loading the dataset(s) specified in the
|
||||||
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
||||||
@@ -44,16 +44,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
del model, tokenizer, trainer
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
|
||||||
del model
|
|
||||||
del tokenizer
|
|
||||||
del trainer
|
|
||||||
|
|
||||||
plugin_manager.post_train_unload(cfg)
|
plugin_manager.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
"""
|
"""
|
||||||
Parses `axolotl` config, CLI args, and calls `do_train`.
|
Parses `axolotl` config, CLI args, and calls `do_train`.
|
||||||
|
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
val_set_size: float | None = Field(default=0.0)
|
val_set_size: float | None = Field(default=0.0)
|
||||||
|
|
||||||
sequence_parallel_degree: int | None = 1
|
sequence_parallel_degree: int | None = None
|
||||||
|
|
||||||
special_tokens: SpecialTokensConfig | None = None
|
special_tokens: SpecialTokensConfig | None = None
|
||||||
tokens: list[str] | None = None
|
tokens: list[str] | None = None
|
||||||
@@ -1107,6 +1107,9 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_sequence_parallel_config(cls, data):
|
def check_sequence_parallel_config(cls, data):
|
||||||
|
if data.get("sequence_parallel_degree") is None:
|
||||||
|
data["sequence_parallel_degree"] = 1
|
||||||
|
|
||||||
if data.get("sequence_parallel_degree") > 1:
|
if data.get("sequence_parallel_degree") > 1:
|
||||||
if not data.get("flash_attention"):
|
if not data.get("flash_attention"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
Reference in New Issue
Block a user