diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 662b64896..0358ad4e6 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1481,6 +1481,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): sys.path.append(self.cfg.torchdistx_path) importlib.import_module("torchdistx") + if self.cfg.accelerator_config: + training_arguments_kwargs[ + "accelerator_config" + ] = self.cfg.accelerator_config + training_args = ( AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg **training_arguments_kwargs, diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 708d41972..32bb1f5b6 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -77,6 +77,7 @@ class PretrainingDataset(BaseModel): split: Optional[str] = "train" text_column: Optional[str] = "text" type: Optional[str] = "pretrain" + trust_remote_code: Optional[bool] = False class UserDefinedPrompterType(BaseModel): @@ -118,6 +119,8 @@ class SFTDataset(BaseModel): roles: Optional[Dict[str, List[str]]] = None drop_system_message: Optional[bool] = None + trust_remote_code: Optional[bool] = False + class UserDefinedDPOType(BaseModel): """User defined typing for DPO""" @@ -158,6 +161,7 @@ class KTODataset(BaseModel): split: Optional[str] = None type: Optional[Union[UserDefinedKTOType, str]] = None data_files: Optional[List[str]] = None + trust_remote_code: Optional[bool] = False class RLType(str, Enum): @@ -504,6 +508,8 @@ class AxolotlInputConfig( dataloader_prefetch_factor: Optional[int] = None dataloader_drop_last: Optional[bool] = None + accelerator_config: Optional[Dict[str, Any]] = None + remove_unused_columns: Optional[bool] = None push_dataset_to_hub: Optional[str] = None @@ -702,6 +708,24 @@ class AxolotlInputConfig( ) return data + @model_validator(mode="before") + @classmethod + def check_pretraining_split_batches_accelerate(cls, data): + # alternatively set ACCELERATE_SPLIT_BATCHES=False + if data.get("pretraining_dataset"): + accelerator_config = data.get("accelerator_config", {}) + if not accelerator_config: + data["accelerator_config"] = { + "split_batches": False, + "dispatch_batches": False, + } + else: + if accelerator_config.get("split_batches") is None: + data["accelerator_config"]["split_batches"] = False + if accelerator_config.get("dispatch_batches") is None: + data["accelerator_config"]["dispatch_batches"] = False + return data + @model_validator(mode="before") @classmethod def check_gptq_w_revision(cls, data): diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py new file mode 100644 index 000000000..62fb63c47 --- /dev/null +++ b/tests/e2e/test_llama_pretrain.py @@ -0,0 +1,67 @@ +""" +E2E tests for llama pretrain +""" + +import logging +import os +import unittest +from pathlib import Path + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from .utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestPretrainLlama(unittest.TestCase): + """ + Test case for Llama models w pretraining + """ + + @with_temp_dir + def test_pretrain_w_sample_packing(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "flash_attention": True, + "sequence_len": 1024, + "sample_packing": True, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "pretraining_dataset": [ + { + "path": "allenai/c4", + "name": "en", + "type": "pretrain", + } + ], + "max_steps": 5, + "num_epochs": 1, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "val_set_size": 0.0, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists() diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py index fb623a43d..5d517585f 100644 --- a/tests/test_packed_pretraining.py +++ b/tests/test_packed_pretraining.py @@ -24,7 +24,7 @@ class TestPretrainingPacking(unittest.TestCase): def test_packing_stream_dataset(self): # pylint: disable=duplicate-code dataset = load_dataset( - "c4", + "allenai/c4", "en", streaming=True, )["train"] @@ -33,7 +33,7 @@ class TestPretrainingPacking(unittest.TestCase): { "pretraining_dataset": [ { - "path": "c4", + "path": "allenai/c4", "name": "en", "type": "pretrain", }