From 608a2f3180eceebb39c73154332bb209d41469fe Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 23 Jul 2024 13:21:03 -0400 Subject: [PATCH 01/13] bump transformers for updated llama 3.1 (#1778) * bump transformers for updated llama 3.1 * bump for patch fix --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b2aac0dd0..a54a42ad9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.11.1 -transformers @ git+https://github.com/huggingface/transformers.git@0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf +transformers==4.43.1 tokenizers==0.19.1 bitsandbytes==0.43.1 accelerate==0.32.0 From e6b299dd79f75537f3e247a3d22dfed5d3885bfb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 23 Jul 2024 19:54:15 -0400 Subject: [PATCH 02/13] bump flash attention to 2.6.2 (#1781) [skip ci] --- requirements.txt | 2 +- setup.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index a54a42ad9..ec571570b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ fire PyYAML>=6.0 requests datasets==2.19.1 -flash-attn==2.6.1 +flash-attn==2.6.2 sentencepiece wandb einops diff --git a/setup.py b/setup.py index 9e6f34ad8..ceba63669 100644 --- a/setup.py +++ b/setup.py @@ -80,10 +80,10 @@ setup( dependency_links=dependency_links, extras_require={ "flash-attn": [ - "flash-attn==2.6.1", + "flash-attn==2.6.2", ], "fused-dense-lib": [ - "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib", + "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib", ], "deepspeed": [ "deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b", From fe250ada78ff3d5404e053f2ae050d66f3943248 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 23 Jul 2024 19:54:28 -0400 Subject: [PATCH 03/13] fix fsdp loading of models, esp 70b (#1780) --- src/axolotl/utils/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 339195df7..436b31fef 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -619,7 +619,7 @@ def load_model( and not cfg.trust_remote_code and not cfg.gptq ): - if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: skip_move_to_device = True if "device_map" in model_kwargs: del model_kwargs["device_map"] @@ -701,7 +701,7 @@ def load_model( **model_kwargs, ) else: - if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: # disabling either of these two still leads to VRAM spike before setting back down skip_move_to_device = True if "device_map" in model_kwargs: From 6a9cfec2227935393bcfc0fbe324ef6232c520ec Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 23 Jul 2024 21:22:16 -0400 Subject: [PATCH 04/13] add support for simpo via cpo trainer (#1772) * add support for simpo via cpo trainer * add cpo_alpha / sft_weight from the paper * make sure to use the right builder for simpo --- src/axolotl/core/trainer_builder.py | 47 +++++++++++++++++-- .../config/models/input/v0_4_1/__init__.py | 3 ++ src/axolotl/utils/trainer.py | 2 +- 3 files changed, 47 insertions(+), 5 deletions(-) 
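A minimal sketch of how the SimPO options wired up below might appear in a training config, assuming the field names added in this patch (rl, rl_beta, simpo_gamma, cpo_alpha); the values shown are illustrative placeholders, not recommendations:

rl: simpo
rl_beta: 2.5      # passed through to the CPO config as beta
simpo_gamma: 1.4  # target reward margin for the SimPO loss
cpo_alpha: 1.0    # SFT loss weight (the paper's sft_weight)

With rl set to simpo, setup_trainer routes through HFRLTrainerBuilder and the new AxolotlCPOTrainer, per the trainer_builder.py and trainer.py changes below.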
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 616b1d4eb..9a12c5a06 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -30,7 +30,16 @@ from transformers import ( ) from transformers.trainer_utils import seed_worker from transformers.utils import is_sagemaker_mp_enabled -from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer +from trl import ( + CPOConfig, + CPOTrainer, + DPOConfig, + DPOTrainer, + KTOConfig, + KTOTrainer, + ORPOConfig, + ORPOTrainer, +) from trl.trainer.utils import pad_to_length from axolotl.loraplus import create_loraplus_optimizer @@ -265,6 +274,18 @@ class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig): """ +@dataclass +class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig): + """ + CPO config for CPO training + """ + + simpo_gamma: Optional[float] = field( + default=None, + metadata={"help": "simpo gamma parameter"}, + ) + + class AxolotlTrainer(Trainer): """ Extend the base Trainer for axolotl helpers @@ -985,6 +1006,14 @@ class AxolotlKTOTrainer(KTOTrainer): tag_names = ["axolotl", "kto"] +class AxolotlCPOTrainer(CPOTrainer): + """ + Extend the base CPOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "cpo"] + + class TrainerBuilderBase(abc.ABC): """ Base class for trainer builder @@ -1707,6 +1736,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): # default to saving each epoch if not defined training_args_kwargs["save_strategy"] = "epoch" + if self.cfg.rl_beta: + training_args_kwargs["beta"] = self.cfg.rl_beta if self.cfg.orpo_alpha: # trl does some odd mapping of alpha to beta to reuse the beta parameter ??? training_args_kwargs["beta"] = self.cfg.orpo_alpha @@ -1715,9 +1746,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_cls = AxolotlDPOConfig if self.cfg.rpo_alpha is not None: training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha + + if self.cfg.rl == "simpo": + training_args_cls = AxolotlCPOConfig + training_args_kwargs["loss_type"] = "simpo" + training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma + if self.cfg.cpo_alpha is not None: + training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha + if self.cfg.rl == "orpo": training_args_cls = AxolotlORPOConfig - training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["max_length"] = self.cfg.sequence_len if self.cfg.max_prompt_len: training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len @@ -1725,7 +1763,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.rl == "kto": training_args_cls = AxolotlKTOConfig - training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1 training_args_kwargs["desirable_weight"] = ( self.cfg.kto_desirable_weight or 1.0 ) @@ -1771,7 +1808,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase): ] = self.cfg.precompute_ref_log_probs if self.cfg.rl in ["dpo", "ipo"]: trainer_cls = AxolotlDPOTrainer - dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1 trainer_cls_args = [self.model, self.model_ref] # these aren't used for the ORPO trainer @@ -1785,6 +1821,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase): elif self.cfg.rl in ["kto"]: trainer_cls = AxolotlKTOTrainer trainer_cls_args = [self.model] + elif self.cfg.rl in ["simpo"]: + trainer_cls = AxolotlCPOTrainer + trainer_cls_args = [self.model] else: raise ValueError(f"Unsupported RL: {self.cfg.rl}") dpo_trainer = trainer_cls( diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py 
b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 7f30283af..7397c7c73 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -172,6 +172,7 @@ class RLType(str, Enum): ipo = "ipo" # pylint: disable=invalid-name orpo = "orpo" # pylint: disable=invalid-name kto = "kto" # pylint: disable=invalid-name + simpo = "simpo" # pylint: disable=invalid-name class ChatTemplate(str, Enum): @@ -644,6 +645,8 @@ class AxolotlInputConfig( orpo_alpha: Optional[float] = None rpo_alpha: Optional[float] = None + simpo_gamma: Optional[float] = None + cpo_alpha: Optional[float] = None kto_desirable_weight: Optional[float] = None kto_undesirable_weight: Optional[float] = None diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 65c2d424e..c5a71e689 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -425,7 +425,7 @@ def prepare_optim_env(cfg): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): - if cfg.rl in ["dpo", "ipo", "orpo", "kto"]: + if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]: trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder.model_ref = model[1] trainer_builder.peft_config = model[2] From 22680913f3ac4bf3410855210648f396cfd5c7d5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 27 Jul 2024 10:24:11 -0400 Subject: [PATCH 05/13] Bump deepspeed 20240727 (#1790) * pin deepspeed to 0.14.4 otherwise it doesn't play nice with trl * Add test to import to try to trigger import dependencies --- requirements.txt | 2 +- setup.py | 2 +- tests/e2e/test_imports.py | 20 ++++++++++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) create mode 100644 tests/e2e/test_imports.py diff --git a/requirements.txt b/requirements.txt index ec571570b..981a62558 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ transformers==4.43.1 tokenizers==0.19.1 bitsandbytes==0.43.1 accelerate==0.32.0 -deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b +deepspeed==0.14.4 pydantic==2.6.3 addict fire diff --git a/setup.py b/setup.py index ceba63669..1d164e0a1 100644 --- a/setup.py +++ b/setup.py @@ -86,7 +86,7 @@ setup( "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib", ], "deepspeed": [ - "deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b", + "deepspeed==0.14.4", "deepspeed-kernels", ], "mamba-ssm": [ diff --git a/tests/e2e/test_imports.py b/tests/e2e/test_imports.py new file mode 100644 index 000000000..f186eaac4 --- /dev/null +++ b/tests/e2e/test_imports.py @@ -0,0 +1,20 @@ +""" +test module to import various submodules that have historically broken due to dependency issues +""" +import unittest + + +class TestImports(unittest.TestCase): + """ + Test class to import various submodules that have historically broken due to dependency issues + """ + + def test_import_causal_trainer(self): + from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401 + HFCausalTrainerBuilder, + ) + + def test_import_rl_trainer(self): + from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401 + HFRLTrainerBuilder, + ) From 94ba93259f421d438e53117267b17097c48cdd65 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 28 Jul 2024 07:25:54 -0400 Subject: [PATCH 06/13] various batch of fixes (#1785) * various 
batch of fixes * more tweaks * fix autoawq requirement for torch flexibility * simplify conditionals * multi-node fixes wip * bump transformers and include 405b qlora+fsdp yaml --- examples/llama-3/qlora-fsdp-405b.yaml | 62 +++++++ requirements.txt | 3 +- src/axolotl/cli/preprocess.py | 9 +- src/axolotl/common/architectures.py | 14 ++ src/axolotl/core/trainer_builder.py | 154 ++++++++++-------- .../prompt_strategies/dpo/chat_template.py | 4 +- src/axolotl/train.py | 27 ++- src/axolotl/utils/data/sft.py | 9 +- src/axolotl/utils/distributed.py | 4 + src/axolotl/utils/models.py | 42 ++++- src/axolotl/utils/trainer.py | 24 ++- 11 files changed, 253 insertions(+), 99 deletions(-) create mode 100644 examples/llama-3/qlora-fsdp-405b.yaml create mode 100644 src/axolotl/common/architectures.py diff --git a/examples/llama-3/qlora-fsdp-405b.yaml b/examples/llama-3/qlora-fsdp-405b.yaml new file mode 100644 index 000000000..385b7f91d --- /dev/null +++ b/examples/llama-3/qlora-fsdp-405b.yaml @@ -0,0 +1,62 @@ +base_model: meta-llama/Meta-Llama-3.1-405B +tokenizer_type: AutoTokenizer + +load_in_4bit: true +strict: false + +datasets: + - path: tatsu-lab/alpaca + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out/qlora-llama3_1-405b + +adapter: qlora + +sequence_len: 1024 +sample_packing: true +pad_to_sequence_len: true + +lora_r: 16 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: +lora_target_linear: true + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 0.00001 + +train_on_inputs: false +group_by_length: false +bf16: true +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: true +logging_steps: 1 +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: 4 +saves_per_epoch: 1 +weight_decay: 0.0 +fsdp: + - full_shard + - auto_wrap +fsdp_config: + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: true + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD +special_tokens: + pad_token: <|finetune_right_pad_id|> diff --git a/requirements.txt b/requirements.txt index 981a62558..5825ee190 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.11.1 -transformers==4.43.1 +transformers==4.43.3 tokenizers==0.19.1 bitsandbytes==0.43.1 accelerate==0.32.0 @@ -32,6 +32,7 @@ fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e59 gradio==3.50.2 tensorboard python-dotenv==1.0.1 +autoawq>=0.2.5 mamba-ssm==1.2.0.post1 diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 5ec279d4b..e0dd7c2dc 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -2,6 +2,7 @@ CLI to run training on a model """ import logging +import warnings from pathlib import Path from typing import Union @@ -76,8 +77,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): if parsed_cli_args.download: model_name = parsed_cfg.base_model - with init_empty_weights(): - AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + with warnings.catch_warnings(): + # there are a bunch of useless UserWarnings about + # "copying from a 
non-meta parameter in the checkpoint to a meta parameter in the current model" + warnings.simplefilter("ignore") + with init_empty_weights(include_buffers=True): + AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) LOG.info( Fore.GREEN diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py new file mode 100644 index 000000000..7610b335a --- /dev/null +++ b/src/axolotl/common/architectures.py @@ -0,0 +1,14 @@ +""" +Common architecture specific constants +""" + +MOE_ARCH_BLOCK = { + "dbrx": "DbrxFFN", + "jamba": "JambaSparseMoeBlock", + "jetmoe": [ + "JetMoeMoA", + "JetMoeMoE", + ], + "mixtral": "MixtralSparseMoeBlock", + "qwen2_moe": "Qwen2MoeSparseMoeBlock", +} diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 9a12c5a06..ff4804b10 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -8,6 +8,7 @@ import importlib import importlib.util import logging import math +import os import sys from abc import abstractmethod from collections import defaultdict @@ -28,7 +29,7 @@ from transformers import ( TrainerCallback, TrainingArguments, ) -from transformers.trainer_utils import seed_worker +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker from transformers.utils import is_sagemaker_mp_enabled from trl import ( CPOConfig, @@ -286,7 +287,77 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig): ) -class AxolotlTrainer(Trainer): +class SchedulerMixin(Trainer): + """ + Mixin class for scheduler setup in CausalTrainer. + """ + + args = None # type: AxolotlTrainingArguments + + def create_scheduler( + self, num_training_steps: int, optimizer: torch.optim.Optimizer = None + ): + """ + Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or + passed as an argument. + + Args: + num_training_steps (int): The number of training steps to do. + optimizer (torch.optim.Optimizer): The training optimizer + """ + use_cosine_quadratic = ( + self.args.lr_scheduler_type == "cosine" + and self.args.lr_quadratic_warmup is True + ) + + use_cosine_min_lr = ( + self.args.lr_scheduler_type == "cosine" + and self.args.cosine_min_lr_ratio is not None + ) + + # fmt: off + if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition + # fmt: on + if use_cosine_quadratic: + if use_cosine_min_lr: + LOG.warning("Both cosine quadratic warmup and min lr detected. 
Using quadratic warmup.") + + self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + ) + elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0" + self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + min_lr_ratio=self.args.cosine_min_lr_ratio, + constant_lr_ratio=self.args.cosine_constant_lr_ratio, + ) + elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + min_lr_ratio=self.args.cosine_min_lr_ratio, + ) + else: + return super().create_scheduler(num_training_steps, optimizer) + else: + if use_cosine_quadratic: + LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") + + if use_cosine_min_lr: + LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") + + return self.lr_scheduler + + +class AxolotlTrainer(SchedulerMixin, Trainer): """ Extend the base Trainer for axolotl helpers """ @@ -404,68 +475,6 @@ class AxolotlTrainer(Trainer): return self.optimizer - def create_scheduler( - self, num_training_steps: int, optimizer: torch.optim.Optimizer = None - ): - """ - Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or - passed as an argument. - - Args: - num_training_steps (int): The number of training steps to do. - optimizer (torch.optim.Optimizer): The training optimizer - """ - use_cosine_quadratic = ( - self.args.lr_scheduler_type == "cosine" - and self.args.lr_quadratic_warmup is True - ) - - use_cosine_min_lr = ( - self.args.lr_scheduler_type == "cosine" - and self.args.cosine_min_lr_ratio is not None - ) - - # fmt: off - if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition - # fmt: on - if use_cosine_quadratic: - if use_cosine_min_lr: - LOG.warning("Both cosine quadratic warmup and min lr detected. 
Using quadratic warmup.") - - self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - ) - elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr: - assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0" - self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - min_lr_ratio=self.args.cosine_min_lr_ratio, - constant_lr_ratio=self.args.cosine_constant_lr_ratio, - ) - elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: - assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - min_lr_ratio=self.args.cosine_min_lr_ratio, - ) - else: - return super().create_scheduler(num_training_steps, optimizer) - else: - if use_cosine_quadratic: - LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") - - if use_cosine_min_lr: - LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") - - return self.lr_scheduler - def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: if self.args.sample_packing and not self.args.pretraining: if self.args.multipack_real_batches: @@ -830,6 +839,14 @@ class AxolotlTrainer(Trainer): for key, value in metrics.items(): self._stored_metrics[train_eval][key].append(value) + def _save_checkpoint(self, model, trial, metrics=None): + # make sure the checkpoint dir exists, since trainer is flakey + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + os.makedirs(output_dir, exist_ok=True) + return super()._save_checkpoint(model, trial, metrics=metrics) + class AxolotlMambaTrainer(AxolotlTrainer): """ @@ -929,7 +946,7 @@ class ReLoRATrainer(AxolotlTrainer): return self.lr_scheduler -class AxolotlDPOTrainer(DPOTrainer): +class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): """ Extend the base DPOTrainer for axolotl helpers """ @@ -990,7 +1007,7 @@ class AxolotlDPOTrainer(DPOTrainer): return res -class AxolotlORPOTrainer(ORPOTrainer): +class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): """ Extend the base ORPOTrainer for axolotl helpers """ @@ -998,7 +1015,7 @@ class AxolotlORPOTrainer(ORPOTrainer): tag_names = ["axolotl", "orpo"] -class AxolotlKTOTrainer(KTOTrainer): +class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): """ Extend the base KTOTrainer for axolotl helpers """ @@ -1006,7 +1023,7 @@ class AxolotlKTOTrainer(KTOTrainer): tag_names = ["axolotl", "kto"] -class AxolotlCPOTrainer(CPOTrainer): +class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): """ Extend the base CPOTrainer for axolotl helpers """ @@ -1750,6 +1767,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.rl == "simpo": training_args_cls = AxolotlCPOConfig 
training_args_kwargs["loss_type"] = "simpo" + training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma if self.cfg.cpo_alpha is not None: training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py index 4f2f14098..e0e5eb129 100644 --- a/src/axolotl/prompt_strategies/dpo/chat_template.py +++ b/src/axolotl/prompt_strategies/dpo/chat_template.py @@ -62,7 +62,7 @@ def default( tokenize=False, ) chosen_strip_index = result["chosen"].find(chosen["content"]) - result["chosen"] = result["chosen"][chosen_strip_index:] + result["chosen"] = result["chosen"][chosen_strip_index:].rstrip() result["rejected"] = tokenizer.apply_chat_template( [rejected], @@ -71,7 +71,7 @@ def default( tokenize=False, ) rejected_strip_index = result["rejected"].find(rejected["content"]) - result["rejected"] = result["rejected"][rejected_strip_index:] + result["rejected"] = result["rejected"][rejected_strip_index:].rstrip() return result diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 5ba5aed56..b8890d4f7 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -212,26 +212,23 @@ def train( elif cfg.deepspeed and is_deepspeed_zero3_enabled(): # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading trainer.accelerator.wait_for_everyone() - unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped) + trainer.save_model(cfg.output_dir) # the trainer saved a model.safetensors file in the output directory, - # but it is a proxy model and should be deleted - if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")): + # but it is most likely a proxy model and if so, should be deleted + maybe_proxy = os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")) + maybe_sharded = os.path.exists( + os.path.join(cfg.output_dir, "model.safetensors.index.json") + ) + + if maybe_proxy and maybe_sharded: LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}") LOG.info("This is a proxy model and should be deleted") - os.remove(os.path.join(cfg.output_dir, "model.safetensors")) + try: + os.remove(os.path.join(cfg.output_dir, "model.safetensors")) + except FileNotFoundError: + pass - # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if - # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or - # `zero3_save_16bit_model` is True in DeepSpeed Plugin. - # For Zero Stages 1 and 2, models are saved as usual in the output directory. 
- # The model name saved is `pytorch_model.bin` - unwrapped_model.save_pretrained( - cfg.output_dir, - is_main_process=trainer.accelerator.is_main_process, - save_function=trainer.accelerator.save, - state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped), - ) elif cfg.local_rank == 0: if cfg.flash_optimum and BetterTransformer: model = BetterTransformer.reverse(model) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index bbea1987f..2e923057d 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -42,7 +42,7 @@ from axolotl.prompters import ( from axolotl.utils.data.pretraining import wrap_pretraining_dataset from axolotl.utils.data.utils import md5 from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import is_main_process, zero_first +from axolotl.utils.distributed import is_local_main_process, zero_first from axolotl.utils.trainer import ( calculate_total_num_steps, process_datasets_for_packing, @@ -54,7 +54,7 @@ LOG = logging.getLogger("axolotl") def prepare_dataset(cfg, tokenizer): prompters = [] if not cfg.pretraining_dataset: - with zero_first(is_main_process()): + with zero_first(is_local_main_process()): if cfg.test_datasets: train_dataset, _, prompters = load_prepare_datasets( tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train" @@ -170,6 +170,7 @@ def load_tokenized_prepared_datasets( # pylint: disable=duplicate-code if dataset: + # This is for the case where we already loaded a pretokenized dataset from the hub ... elif ( cfg.dataset_prepared_path @@ -198,6 +199,8 @@ def load_tokenized_prepared_datasets( def for_d_in_datasets(dataset_configs): for dataset in dataset_configs: if dataset.name and isinstance(dataset.name, list): + # load_dataset doesn't properly handle multiple named configurations + # at the same time for a given dataset for name in dataset.name: yield DictDefault({**dataset, "name": name}) else: @@ -208,6 +211,8 @@ def load_tokenized_prepared_datasets( ds: Optional[Union[Dataset, DatasetDict]] = None ds_from_hub = False try: + # this is just a basic check to see if the path is a + # valid HF dataset that's loadable load_dataset( config_dataset.path, name=config_dataset.name, diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index ecb1bcc9e..4444a20c9 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -44,6 +44,10 @@ def is_main_process(): return dist.get_rank() == 0 +def is_local_main_process(): + return PartialState().is_main_process + + def get_world_size(): return int(os.getenv("WORLD_SIZE", "1")) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 436b31fef..8a50631ef 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -29,6 +29,7 @@ from transformers import ( # noqa: F401 AutoConfig, AutoModelForCausalLM, AutoTokenizer, + AwqConfig, BitsAndBytesConfig, GPTQConfig, PreTrainedModel, @@ -36,6 +37,7 @@ from transformers import ( # noqa: F401 ) from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +from axolotl.common.architectures import MOE_ARCH_BLOCK from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, @@ -510,7 +512,25 @@ def load_model( model_kwargs["quantization_config"] = GPTQConfig( **model_config.quantization_config ) - if cfg.adapter == "qlora" and cfg.load_in_4bit: + if ( + cfg.adapter in ["qlora", "lora"] + and hasattr(model_config, 
"quantization_config") + and model_config.quantization_config["quant_method"] + in ["gptq", "awq", "bitsandbytes"] + ): + if model_config.quantization_config["quant_method"] == "gptq": + model_kwargs["quantization_config"] = GPTQConfig( + **model_config.quantization_config + ) + elif model_config.quantization_config["quant_method"] == "awq": + model_kwargs["quantization_config"] = AwqConfig( + **model_config.quantization_config + ) + elif model_config.quantization_config["quant_method"] == "bitsandbytes": + model_kwargs["quantization_config"] = BitsAndBytesConfig( + **model_config.quantization_config + ) + elif cfg.adapter == "qlora" and cfg.load_in_4bit: bnb_config = { "load_in_4bit": True, "llm_int8_threshold": 6.0, @@ -785,12 +805,14 @@ def load_model( set_z3_leaf_modules, ) - if cfg.model_config_type == "mixtral": - moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock") - set_z3_leaf_modules(model, [moe_block]) - elif cfg.model_config_type == "dbrx": - moe_block = get_module_class_from_name(model, "DbrxFFN") - set_z3_leaf_modules(model, [moe_block]) + if cfg.model_config_type in MOE_ARCH_BLOCK: + set_z3_leaf_modules( + model, + [ + get_module_class_from_name(model, module_name) + for module_name in MOE_ARCH_BLOCK[cfg.model_config_type] + ], + ) if cfg.model_config_type == "qwen" and cfg.adapter == "lora": # Qwen doesn't play nicely with LoRA if this is enabled @@ -804,6 +826,9 @@ def load_model( # make sure everything is in the same dtype skip_prepare_model_for_kbit_training = True + if is_deepspeed_zero3_enabled(): + skip_prepare_model_for_kbit_training = True + if cfg.adapter in ["lora", "qlora"]: if cfg.gradient_checkpointing: model.gradient_checkpointing_enable( @@ -838,6 +863,9 @@ def load_model( else: model, lora_config = load_adapter(model, cfg, cfg.adapter) + if is_deepspeed_zero3_enabled(): + skip_move_to_device = True + if ( cfg.ddp and not load_in_8bit diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index c5a71e689..bb9624051 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -1,4 +1,5 @@ """Module containing the Trainer class and related functions""" +import json import math import os import random @@ -389,6 +390,19 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): return total_num_steps +def setup_deepspeed_env(cfg, stage=None): + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed + if cfg.bf16: + os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16" + elif cfg.fp16: + os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16" + if stage: + os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage) + if stage == 3: + os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true" + + def setup_fsdp_envs(cfg): os.environ["ACCELERATE_USE_FSDP"] = "true" if cfg.fsdp_config.fsdp_activation_checkpointing: @@ -415,8 +429,14 @@ def prepare_optim_env(cfg): if cfg.fsdp: setup_fsdp_envs(cfg) elif cfg.deepspeed: - os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" - os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed + stage = None + # check if the cfg.deepspeed is a file + if os.path.isfile(cfg.deepspeed): + # parse with json + with open(cfg.deepspeed, "r", encoding="utf-8") as fin: + deepspeed_config = json.load(fin) + stage = deepspeed_config.get("zero_optimization", {}).get("stage", None) + setup_deepspeed_env(cfg, stage=stage) if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True: os.environ["ACCELERATE_MIXED_PRECISION"] = 
"bf16" From 55cc214c767741e83ee7b346e5e13e6c03b7b9fa Mon Sep 17 00:00:00 2001 From: Adam Brusselback Date: Sun, 28 Jul 2024 21:48:57 -0400 Subject: [PATCH 07/13] Add flexible configuration options for `chat_template` dataset training (#1756) * Add flexible configuration options for chat dataset training - Introduce roles_to_train parameter to set training labels by role - Add train_on_eos option to configure training on end-of-sequence tokens - Implement per-message training configuration in dataset - Allow fine-grained control over training specific portions of messages - Add message_field_training and message_field_training_detail settings - Implement mapping between dataset character offsets and tokenized prompt - Enhance test suite to cover new functionality * Fix missing field inits, things weren't working from yaml. * Add flexible configuration options for chat dataset training - Introduce roles_to_train parameter to set training labels by role - Add train_on_eos option to configure training on end-of-sequence tokens - Implement per-message training configuration in dataset - Allow fine-grained control over training specific portions of messages - Add message_field_training and message_field_training_detail settings - Implement mapping between dataset character offsets and tokenized prompt - Enhance test suite to cover new functionality * Fix missing field inits, things weren't working from yaml. * chore: lint * Revert test repo back to NousResearch after opening PR to fix the tokenizer_config.json. --------- Co-authored-by: Wing Lian --- .../prompt_strategies/chat_template.py | 313 ++++++- .../config/models/input/v0_4_1/__init__.py | 4 + .../prompt_strategies/test_chat_templates.py | 884 +++++++++++++++++- tests/prompt_strategies/test_sharegpt.py | 2 + 4 files changed, 1111 insertions(+), 92 deletions(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 8c7a8dd4f..f9fa71f21 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -6,14 +6,16 @@ import logging from typing import Any, Dict, List, Optional from axolotl.prompt_tokenizers import PromptTokenizingStrategy -from axolotl.prompters import Prompter +from axolotl.prompters import IGNORE_TOKEN_ID, Prompter from axolotl.utils.chat_templates import chat_templates +# Configure the logger +logging.basicConfig(level=logging.DEBUG) LOG = logging.getLogger("axolotl") class ChatTemplatePrompter(Prompter): - """prompter for HF chat templates""" + """Prompter for HF chat templates""" def __init__( self, @@ -22,6 +24,8 @@ class ChatTemplatePrompter(Prompter): max_length=2048, message_field_role: str = "from", message_field_content: str = "value", + message_field_training: str = "train", + message_field_training_detail: str = "train_detail", roles: Optional[Dict[str, List[str]]] = None, drop_system_message: bool = False, ): @@ -37,6 +41,8 @@ class ChatTemplatePrompter(Prompter): } self.message_field_role = message_field_role self.message_field_content = message_field_content + self.message_field_training = message_field_training + self.message_field_training_detail = message_field_training_detail self.tokenizer = tokenizer self.chat_template = chat_template self.max_length = max_length @@ -47,6 +53,7 @@ class ChatTemplatePrompter(Prompter): { "role": self.roles[t[self.message_field_role]], "content": t[self.message_field_content], + "training": t.get(self.message_field_training, None), } for t in conversation ] @@ 
-62,6 +69,108 @@ class ChatTemplatePrompter(Prompter): chat_template=self.chat_template, ) + def get_offsets_for_train_detail( + self, text: str, train_details: List[Dict], mask_untrainable: bool = True + ) -> List[int]: + tokenized_output = self.tokenizer( + text, return_offsets_mapping=True, add_special_tokens=False + ) + tokens = tokenized_output.tokens() + token_offsets = tokenized_output["offset_mapping"] + + LOG.debug(f"Tokenizing text: {text}") + LOG.debug(f"Tokens: {tokens}") + # Adjust the end offsets. For some reason by default they are set to the same value as the start offsets. + for i in range(len(token_offsets) - 1): + token_offsets[i] = (token_offsets[i][0], token_offsets[i + 1][0] - 1) + # Ensure the last token's end offset is set correctly + token_offsets[-1] = (token_offsets[-1][0], len(text) - 1) + LOG.debug(f"Token offsets: {token_offsets}") + + # Initialize all offsets as IGNORE_TOKEN_ID (not trained) + result = [IGNORE_TOKEN_ID] * len(token_offsets) + + # Adjust train_details to align with token boundaries + adjusted_train_details = self.adjust_train_details(train_details, token_offsets) + + for idx, (start, end) in enumerate(token_offsets): + for detail in adjusted_train_details: + # Check if the token is completely within the detail's range + if start >= detail["begin_offset"] and end <= detail["end_offset"]: + if detail["train"] or not mask_untrainable: + result[idx] = start + LOG.debug(f"Token {idx} ({tokens[idx]}) marked for training") + else: + LOG.debug( + f"Token {idx} ({tokens[idx]}) marked as non-trainable" + ) + elif start < detail["end_offset"] and end > detail["begin_offset"]: + # Token partially overlaps with detail, always mark as non-trainable + LOG.debug( + f"Token {idx} ({tokens[idx]}) partially overlaps detail, marked as non-trainable" + ) + + LOG.debug(f"Final result: {result}") + return result + + def adjust_train_details( + self, train_details: List[Dict], token_offsets: List[tuple] + ) -> List[Dict]: + adjusted_details = [] + for detail in train_details: + begin_offset = detail["begin_offset"] + end_offset = detail["end_offset"] + + # Find the first token that starts after or at the begin_offset + begin_token = next( + ( + i + for i, (t_start, t_end) in enumerate(token_offsets) + if t_start >= begin_offset + ), + len(token_offsets), + ) + if begin_token > 0 and token_offsets[begin_token - 1][1] > begin_offset: + begin_token -= 1 + + # Find the last token that ends before or at the end_offset + end_token = next( + ( + i + for i in range(len(token_offsets) - 1, -1, -1) + if token_offsets[i][1] <= end_offset + ), + -1, + ) + if ( + end_token < len(token_offsets) - 1 + and token_offsets[end_token + 1][0] < end_offset + ): + end_token += 1 + + if begin_token <= end_token: + adjusted_begin = token_offsets[begin_token][0] + adjusted_end = token_offsets[end_token][1] + + if adjusted_begin != begin_offset or adjusted_end != end_offset: + LOG.warning( + f"Adjusting detail offsets: ({begin_offset}, {end_offset}) -> ({adjusted_begin}, {adjusted_end})" + ) + + adjusted_details.append( + { + "begin_offset": adjusted_begin, + "end_offset": adjusted_end, + "train": detail["train"], + } + ) + else: + LOG.warning( + f"Could not adjust detail offsets: ({begin_offset}, {end_offset}). Skipping this detail." 
+ ) + + return adjusted_details + class ChatTemplateStrategy(PromptTokenizingStrategy): """ @@ -70,6 +179,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): _messages = "conversations" + def __init__( + self, + prompter, + tokenizer, + train_on_inputs, + sequence_len, + roles_to_train=None, + train_on_eos="last", + ): + super().__init__(prompter, tokenizer, train_on_inputs, sequence_len) + self.roles_to_train = roles_to_train if roles_to_train is not None else [] + self.train_on_eos = train_on_eos + @property def messages(self): return self._messages @@ -79,62 +201,169 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): self._messages = messages def tokenize_prompt(self, prompt): - turns = self.get_conversation_thread(prompt) - prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True) + turns = prompt[self.messages] input_ids = self.prompter.build_prompt(turns) + labels = [IGNORE_TOKEN_ID] * len(input_ids) - if not self.train_on_inputs: - user_prompt_len = len(prompt_ids) - labels = [-100] * user_prompt_len + input_ids[user_prompt_len:] - else: - labels = input_ids + last_eos_idx = -1 + for index, turn in enumerate(turns): + role = turn.get(self.prompter.message_field_role) + content = turn.get(self.prompter.message_field_content) + train_turn = turn.get(self.prompter.message_field_training) + train_detail = turn.get(self.prompter.message_field_training_detail) - tokenized_prompt = { + LOG.debug( + f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}" + ) + + should_train = ( + train_turn + if train_turn is not None + else bool(train_detail is not None) + if train_detail is not None + else self.train_on_inputs or role in self.roles_to_train + ) + + LOG.debug(f"Should train: {should_train}") + + turn_start_idx, turn_end_idx = self.find_turn( + conversation_ids=input_ids, turn=index, turn_content=turn + ) + + LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}") + + if should_train and turn_start_idx != -1 and turn_end_idx != -1: + if train_detail: + token_offsets = self.prompter.get_offsets_for_train_detail( + content, train_detail + ) + LOG.debug(f"Token offsets: {token_offsets}") + for i, offset in enumerate(token_offsets): + if offset != IGNORE_TOKEN_ID and turn_start_idx + i < len( + input_ids + ): + labels[turn_start_idx + i] = input_ids[turn_start_idx + i] + LOG.debug( + f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}" + ) + else: + labels[turn_start_idx:turn_end_idx] = input_ids[ + turn_start_idx:turn_end_idx + ] + LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}") + + LOG.debug(f"Labels after processing turn {index}: {labels}") + + # Handle EOS token + eos_idx = self.find_eos_token(input_ids, turn_end_idx) + if eos_idx == turn_end_idx: + last_eos_idx = eos_idx + if self.train_on_eos == "all" or ( + self.train_on_eos == "turn" and should_train + ): + labels[eos_idx] = input_ids[eos_idx] + LOG.debug(f"EOS token set for training at index {eos_idx}") + else: + LOG.debug( + f"EOS token missing after turn {turn}. 
eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}" + ) + + # Handle 'last' option for train_on_eos + if self.train_on_eos == "last" and last_eos_idx != -1: + labels[last_eos_idx] = input_ids[last_eos_idx] + LOG.debug(f"Last EOS token set for training at index {last_eos_idx}") + + LOG.debug(f"Final labels: {labels}") + + return { "input_ids": input_ids, "labels": labels, "attention_mask": [1] * len(input_ids), } - return tokenized_prompt + def find_eos_token(self, input_ids, start_idx): + eos_token_id = self.tokenizer.eos_token_id + for i in range(start_idx, len(input_ids)): + if input_ids[i] == eos_token_id: + return i + return -1 + + def find_turn(self, conversation_ids, turn, turn_content): + """ + Locate the starting and ending indices of the specified turn in a conversation. + + Args: + conversation_ids (list[int]): Token IDs representing the conversation. + turn (int): The turn number to locate (based on EOS tokens). + turn_content (str): String containing the content of the turn. + + Returns: + tuple: (start_idx, end_idx) indices of the start and end of the turn content. + Returns (-1, -1) if the turn content is not found. + """ + content = turn_content.get(self.prompter.message_field_content, "") + content_ids = self.tokenizer.encode(content, add_special_tokens=False) + + eos_token_id = self.tokenizer.eos_token_id + eos_count = 0 + start_search_idx = 0 + + # Locate the starting index after the specified number of EOS tokens + for i, token_id in enumerate(conversation_ids): + if token_id == eos_token_id: + eos_count += 1 + if eos_count == turn: + start_search_idx = ( + i + 1 + ) # Start searching after the specified turn's EOS token + break + + # Find the start index of the content within the conversation + start_idx = -1 + for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1): + if conversation_ids[i : i + len(content_ids)] == content_ids: + start_idx = i + break + + if start_idx != -1: + end_idx = start_idx + len(content_ids) + else: + end_idx = -1 + + return start_idx, end_idx def get_conversation_thread(self, prompt): return prompt[self.messages] def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): - chat_template = ( - ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml" - ) - message_field_role = ( - ds_cfg["message_field_role"] - if ds_cfg and "message_field_role" in ds_cfg - else "from" - ) - message_field_content = ( - ds_cfg["message_field_content"] - if ds_cfg and "message_field_content" in ds_cfg - else "value" - ) - roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None - drop_system_message = ( - ds_cfg["drop_system_message"] - if ds_cfg and "drop_system_message" in ds_cfg - else False - ) + ds_cfg = ds_cfg or {} + + prompter_params = { + "tokenizer": tokenizer, + "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")), + "message_field_role": ds_cfg.get("message_field_role", "from"), + "message_field_content": ds_cfg.get("message_field_content", "value"), + "message_field_training": ds_cfg.get("message_field_training", "training"), + "message_field_training_detail": ds_cfg.get( + "message_field_training_detail", "train_detail" + ), + "roles": ds_cfg.get("roles"), + "drop_system_message": ds_cfg.get("drop_system_message", False), + } + + strategy_params = { + "train_on_inputs": cfg.train_on_inputs, + "sequence_len": cfg.sequence_len, + "roles_to_train": ds_cfg.get("roles_to_train"), + "train_on_eos": ds_cfg.get("train_on_eos", "last"), + } strategy = ChatTemplateStrategy( - 
ChatTemplatePrompter( - tokenizer, - chat_templates(chat_template), - message_field_role=message_field_role, - message_field_content=message_field_content, - roles=roles, - drop_system_message=drop_system_message, - ), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, + ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params ) - if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"): + + if "field_messages" in ds_cfg and hasattr(strategy, "messages"): strategy.messages = ds_cfg["field_messages"] + return strategy diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 7397c7c73..e92c79485 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -116,6 +116,10 @@ class SFTDataset(BaseModel): field_messages: Optional[str] = None message_field_role: Optional[str] = None message_field_content: Optional[str] = None + message_field_training: Optional[str] = None + message_field_training_detail: Optional[str] = None + roles_to_train: Optional[List[str]] = None + train_on_eos: Optional[str] = None roles: Optional[Dict[str, List[str]]] = None drop_system_message: Optional[bool] = None diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index 7b58a1236..e2fc0f6a5 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -2,6 +2,7 @@ tests for chat_template prompt strategy """ +import logging import unittest import pytest @@ -13,33 +14,24 @@ from axolotl.prompt_strategies.chat_template import ( ChatTemplateStrategy, load, ) +from axolotl.prompters import IGNORE_TOKEN_ID from axolotl.utils.chat_templates import chat_templates from axolotl.utils.dict import DictDefault +logging.basicConfig(level=logging.DEBUG) +LOG = logging.getLogger("axolotl") + @pytest.fixture(name="assistant_dataset") def fixture_assistant_dataset(): - # pylint: disable=duplicate-code return Dataset.from_list( [ { "messages": [ - { - "role": "user", - "content": "hello", - }, - { - "role": "assistant", - "content": "hello", - }, - { - "role": "user", - "content": "goodbye", - }, - { - "role": "assistant", - "content": "goodbye", - }, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hello"}, + {"role": "user", "content": "goodbye"}, + {"role": "assistant", "content": "goodbye"}, ] } ] @@ -53,22 +45,28 @@ def fixture_sharegpt_dataset(): [ { "conversations": [ - { - "from": "human", - "value": "hello", - }, - { - "from": "gpt", - "value": "hello", - }, - { - "from": "human", - "value": "goodbye", - }, - { - "from": "gpt", - "value": "goodbye", - }, + {"from": "human", "value": "hello"}, + {"from": "gpt", "value": "hello"}, + {"from": "human", "value": "goodbye"}, + {"from": "gpt", "value": "goodbye"}, + ] + } + ] + ) + + +@pytest.fixture(name="basic_dataset") +def fixture_basic_dataset(): + # pylint: disable=duplicate-code + return Dataset.from_list( + [ + { + "conversations": [ + {"from": "system", "value": "You are an AI assistant."}, + {"from": "human", "value": "Hello"}, + {"from": "assistant", "value": "Hi there!"}, + {"from": "human", "value": "How are you?"}, + {"from": "assistant", "value": "I'm doing well, thank you!"}, ] } ] @@ -77,19 +75,611 @@ def fixture_sharegpt_dataset(): @pytest.fixture(name="llama3_tokenizer") def fixture_llama3_tokenizer(): - tokenizer = 
AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B") - tokenizer.eos_token = "<|eot_id|>" + tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct") return tokenizer +class TestChatTemplateConfigurations: + """ + Test class for various configurations of ChatTemplateStrategy. + """ + + @staticmethod + def find_sublist(full_list, sub_list): + token_count = len(sub_list) + for index in range(len(full_list) - token_count + 1): + if full_list[index : index + token_count] == sub_list: + return index + return -1 + + def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_inputs=True") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=True, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + # Check the behavior of human inputs + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + labeled = all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(input_ids)] + ) + LOG.debug( + f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}" + ) + + LOG.debug("Full labels: %s", labels) + LOG.debug("Full input_ids: %s", input_ids) + + def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_inputs=False") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that only assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + # Verify that human inputs are not labeled + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + 
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + LOG.debug( + f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{input_text}' in input_ids" + assert all( + label == IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(input_ids)] + ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}" + + def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing roles_to_train with assistant only") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that only assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing roles_to_train with all roles") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=True, + sequence_len=512, + roles_to_train=["human", "assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that all responses are labeled (except for special tokens) + all_responses = [ + "Hello", + "Hi there!", + "How are you?", + "I'm doing well, thank you!", + ] + for response in all_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with empty roles_to_train") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=[], + train_on_eos="none", # Add this line + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + + # Verify that no labels are set when roles_to_train is empty + LOG.debug("Full labels: %s", labels) + assert all( + label == IGNORE_TOKEN_ID for label in labels + ), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty" + + def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='all'") + strategy = 
ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="all", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + for eos_idx in eos_indices: + assert ( + labels[eos_idx] != IGNORE_TOKEN_ID + ), f"Expected EOS token at index {eos_idx} to be labeled" + + def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='turn'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="turn", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + + eos_idx = start_idx + len(response_ids) + while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: + eos_idx += 1 + + assert eos_idx < len( + input_ids + ), f"Could not find EOS token after '{response}'" + assert ( + labels[eos_idx] != IGNORE_TOKEN_ID + ), f"Expected EOS token after assistant response '{response}' to be labeled" + + # Check that EOS tokens after human inputs are not labeled + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + assert start_idx != -1, f"Could not find '{input_text}' in input_ids" + + eos_idx = start_idx + len(input_ids) + while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: + eos_idx += 1 + + assert ( + labels[eos_idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token after human input '{input_text}' to not be labeled" + + def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='last'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="last", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + last_eos_idx = eos_indices[-1] + + # Check that only the last EOS token is labeled + for idx in eos_indices[:-1]: + assert ( + labels[idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token at index {idx} to not be labeled" + assert ( + labels[last_eos_idx] != IGNORE_TOKEN_ID + ), f"Expected last EOS token at index {last_eos_idx} to be labeled" + + def test_train_on_eos_none(self, llama3_tokenizer, 
basic_dataset): + LOG.info("Testing with train_on_eos='none'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="none", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + for eos_idx in eos_indices: + assert ( + labels[eos_idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token at index {eos_idx} to not be labeled" + + def test_drop_system_message(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with drop_system_message=True") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, chat_templates("llama3"), drop_system_message=True + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + input_ids = res["input_ids"] + + # Check if system message is not present in input_ids + system_message = "You are an AI assistant." + system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False) + assert ( + self.find_sublist(input_ids, system_ids) == -1 + ), "Expected system message to be dropped" + + def test_custom_roles(self, llama3_tokenizer): + LOG.info("Testing with custom roles mapping") + custom_roles = { + "user": ["human", "user"], + "assistant": ["ai", "assistant"], + "system": ["context"], + } + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, chat_templates("llama3"), roles=custom_roles + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["ai"], + ) + + # Create a new dataset with modified role names + modified_conversations = [ + {"from": "context", "value": "You are an AI assistant."}, + {"from": "human", "value": "Hello"}, + {"from": "ai", "value": "Hi there!"}, + {"from": "human", "value": "How are you?"}, + {"from": "ai", "value": "I'm doing well, thank you!"}, + ] + + modified_dataset = Dataset.from_dict( + {"conversations": [modified_conversations]} + ) + + res = strategy.tokenize_prompt(modified_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Check if AI responses are labeled correctly + ai_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in ai_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + assert start_idx != -1, f"Could not find response '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for AI response '{response}' to be set" + + # Check if human messages are not labeled + human_messages = ["Hello", "How are you?"] + for message in human_messages: + message_ids = llama3_tokenizer.encode(message, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, message_ids) + assert start_idx != -1, f"Could not find message '{message}' in input_ids" + assert all( + label == IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(message_ids)] + ), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID" + + def 
test_message_field_training(self, llama3_tokenizer): + LOG.info("Testing with message_field_training") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, + chat_templates("llama3"), + message_field_training="train", + message_field_training_detail="train_detail", + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=[], + ) + + # Create a new dataset with the train and train_detail fields + modified_conversation = [ + {"from": "system", "value": "You are an AI assistant.", "train": False}, + {"from": "human", "value": "Hello", "train": False}, + {"from": "assistant", "value": "Hello", "train": True}, + {"from": "human", "value": "How are you?", "train": True}, + { + "from": "assistant", + "value": "I'm doing very well, thank you!", + "train_detail": [ + {"begin_offset": 0, "end_offset": 8, "train": False}, + {"begin_offset": 9, "end_offset": 18, "train": True}, + {"begin_offset": 19, "end_offset": 30, "train": False}, + ], + }, + { + "from": "human", + "value": "I'm doing very well, thank you!", + "train": False, + }, + {"from": "assistant", "value": "Hi there!", "train": True}, + ] + + modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]}) + + res = strategy.tokenize_prompt(modified_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Function to find all occurrences of a sublist + def find_all_sublists(full_list, sub_list): + indices = [] + for index in range(len(full_list) - len(sub_list) + 1): + if full_list[index : index + len(sub_list)] == sub_list: + indices.append(index) + return indices + + # Keep track of which occurrences we've processed + processed_occurrences = {} + # Check if messages are labeled correctly based on train or train_detail + for i, turn in enumerate(modified_conversation): + turn_tokens = llama3_tokenizer.encode( + turn["value"], add_special_tokens=False + ) + occurrences = find_all_sublists(input_ids, turn_tokens) + turn_key = turn["value"] + if turn_key not in processed_occurrences: + processed_occurrences[turn_key] = 0 + current_occurrence = processed_occurrences[turn_key] + + if current_occurrence >= len(occurrences): + assert ( + False + ), f"Not enough occurrences found for message: {turn['value']}" + + start_idx = occurrences[current_occurrence] + processed_occurrences[turn_key] += 1 + end_idx = start_idx + len(turn_tokens) + + LOG.debug( + f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}" + ) + + if "train_detail" in turn: + # Get token offsets + tokenized_output = llama3_tokenizer( + turn["value"], return_offsets_mapping=True, add_special_tokens=False + ) + token_offsets = tokenized_output["offset_mapping"] + + # Adjust token offsets as done in the implementation + for i in range(len(token_offsets) - 1): + token_offsets[i] = ( + token_offsets[i][0], + token_offsets[i + 1][0] - 1, + ) + token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1) + + # Adjust train_details + adjusted_train_details = strategy.prompter.adjust_train_details( + turn["train_detail"], token_offsets + ) + + LOG.debug(f"Original train_details: {turn['train_detail']}") + LOG.debug(f"Adjusted train_details: {adjusted_train_details}") + + # Handle train_detail + token_offsets = strategy.prompter.get_offsets_for_train_detail( + text=turn["value"], + train_details=adjusted_train_details, + mask_untrainable=False, + ) + token_offsets_masked = strategy.prompter.get_offsets_for_train_detail( + 
text=turn["value"], + train_details=adjusted_train_details, + mask_untrainable=True, + ) + LOG.debug(f"Token offsets: {token_offsets_masked}") + + expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens) + for i, offset in enumerate(token_offsets_masked): + if offset != IGNORE_TOKEN_ID: + expected_labels[i] = turn_tokens[i] + actual_labels = labels[ + start_idx : start_idx + len(token_offsets_masked) + ] + assert ( + actual_labels == expected_labels + ), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}" + + for detail in adjusted_train_details: + # Find the token indices that correspond to the character offsets + detail_start = start_idx + next( + i + for i, offset in enumerate(token_offsets) + if offset >= detail["begin_offset"] + ) + detail_end = start_idx + next( + ( + i + for i, offset in enumerate(token_offsets) + if offset > detail["end_offset"] + ), + len(token_offsets), + ) + + detail_text = turn["value"][ + detail["begin_offset"] : detail["end_offset"] + 1 + ] + detail_labels = labels[detail_start:detail_end] + detail_input_ids = input_ids[detail_start:detail_end] + + LOG.debug( + f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}" + ) + LOG.debug(f"Detail input_ids: {detail_input_ids}") + LOG.debug(f"Detail labels: {detail_labels}") + LOG.debug( + f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}" + ) + LOG.debug( + f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}" + ) + + if detail["train"]: + assert all( + label != IGNORE_TOKEN_ID for label in detail_labels + ), ( + f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. " + f"Labels({detail_start}:{detail_end}): {detail_labels}, " + f"InputIDs: {detail_input_ids}, " + f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" + ) + else: + assert all( + label == IGNORE_TOKEN_ID for label in detail_labels + ), ( + f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. 
" + f"Labels({detail_start}:{detail_end}): {detail_labels}, " + f"InputIDs: {detail_input_ids}, " + f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" + ) + else: + should_train = turn.get("train", False) + turn_labels = labels[start_idx:end_idx] + + LOG.debug(f"Should train: {should_train}") + LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}") + LOG.debug(f"Turn labels: {turn_labels}") + LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}") + LOG.debug( + f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}" + ) + + if should_train: + assert all(label != IGNORE_TOKEN_ID for label in turn_labels), ( + f"Expected all labels for '{turn['value']}' to be set\n" + f"Labels({start_idx}:{end_idx}): {turn_labels}, " + f"InputIDs: {input_ids[start_idx:end_idx]}, " + f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" + ) + else: + assert all(label == IGNORE_TOKEN_ID for label in turn_labels), ( + f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n" + f"Labels({start_idx}:{end_idx}): {turn_labels}, " + f"InputIDs: {input_ids[start_idx:end_idx]}, " + f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" + ) + + LOG.debug( + f"Processed turn: {turn['from']}, content: '{turn['value']}', " + f"start_idx: {start_idx}, end_idx: {end_idx}, " + f"labels: {labels[start_idx:end_idx]}" + ) + + LOG.debug(f"Final labels: {labels}") + LOG.debug(f"Final input_ids: {input_ids}") + + class TestAssistantChatTemplateLlama3: """ Test class for assistant style datasets with llama-3 prompts using the chat_template strategy. """ def test_llama3_load(self, llama3_tokenizer, assistant_dataset): - # pylint: disable=duplicate-code + LOG.info("Loading llama-3 tokenizer with assistant dataset") strategy = load( llama3_tokenizer, DictDefault( @@ -115,21 +705,26 @@ class TestAssistantChatTemplateLlama3: res = strategy.tokenize_prompt(assistant_dataset[0]) input_ids = res["input_ids"] # fmt: off - assert input_ids == [ + expected_input_ids = [ 128000, # bos 128006, 882, 128007, # user header 271, 15339, 128009, # user prompt eot 128006, 78191, 128007, # assistant header - 271, 15339, 128009, # assistant response eot + 271, 15339, 128009, # assistant response eot 128006, 882, 128007, 271, 19045, 29474, 128009, 128006, 78191, 128007, 271, 19045, 29474, 128009, ] # fmt: on + LOG.debug(f"Expected input_ids: {expected_input_ids}") + LOG.debug(f"Actual input_ids: {input_ids}") + assert ( + input_ids == expected_input_ids + ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" def test_llama3(self, llama3_tokenizer, assistant_dataset): - # pylint: disable=duplicate-code + LOG.info("Testing llama-3 with assistant dataset") strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, @@ -142,15 +737,16 @@ class TestAssistantChatTemplateLlama3: "system": ["system"], }, ), - llama3_tokenizer, - False, - 512, + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], ) strategy.messages = "messages" res = strategy.tokenize_prompt(assistant_dataset[0]) input_ids = res["input_ids"] # fmt: off - assert input_ids == [ + expected_input_ids = [ 128000, # bos 128006, 882, 128007, # user header 271, 15339, 128009, # user prompt eot @@ -162,6 +758,64 @@ class TestAssistantChatTemplateLlama3: 271, 19045, 29474, 128009, ] # fmt: on + LOG.debug(f"Expected input_ids: {expected_input_ids}") + LOG.debug(f"Actual input_ids: {input_ids}") + assert ( + input_ids == expected_input_ids + ), f"Input IDs 
mismatch: {input_ids} != {expected_input_ids}" + + def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset): + LOG.info("Testing llama-3 with assistant dataset including training data") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, + chat_templates("llama3"), + message_field_role="role", + message_field_content="content", + message_field_training="training", + roles={ + "user": ["user"], + "assistant": ["assistant"], + "system": ["system"], + }, + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + train_on_eos="none", + sequence_len=512, + roles_to_train=["assistant"], + ) + strategy.messages = "messages" + prompt_tokens = strategy.prompter.build_prompt( + assistant_dataset[0]["messages"], False + ) + prompt = llama3_tokenizer.decode(prompt_tokens, skip_special_tokens=False) + LOG.debug(f"Generated prompt: {prompt}") + res = strategy.tokenize_prompt(assistant_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + # fmt: off + expected_labels = [ + IGNORE_TOKEN_ID, # bos + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header + IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID, + ] + # fmt: on + + LOG.debug(f"Expected labels: {expected_labels}") + LOG.debug(f"Actual labels: {labels}") + assert labels == expected_labels, ( + f"Labels mismatch:\n" + f"Expected: {expected_labels}\n" + f"Actual: {labels}\n" + f"Input IDs: {input_ids}\n" + ) class TestSharegptChatTemplateLlama3: @@ -169,30 +823,160 @@ class TestSharegptChatTemplateLlama3: Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy. 
""" - def test_llama3(self, llama3_tokenizer, sharegpt_dataset): - # pylint: disable=duplicate-code + def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset): + LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts") strategy = ChatTemplateStrategy( ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - llama3_tokenizer, - False, - 512, + tokenizer=llama3_tokenizer, + train_on_inputs=False, + train_on_eos="none", + sequence_len=512, + roles_to_train=["gpt"], ) res = strategy.tokenize_prompt(sharegpt_dataset[0]) input_ids = res["input_ids"] + labels = res["labels"] # fmt: off - assert input_ids == [ + expected_input_ids = [ 128000, # bos 128006, 882, 128007, # user header 271, 15339, 128009, # user prompt eot 128006, 78191, 128007, # assistant header - 271, 15339, 128009, # assistant response eot + 271, 15339, 128009, # assistant response eot 128006, 882, 128007, 271, 19045, 29474, 128009, 128006, 78191, 128007, 271, 19045, 29474, 128009, ] + expected_labels = [ + IGNORE_TOKEN_ID, # bos + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header + IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID, + ] # fmt: on + LOG.debug(f"Expected input_ids: {expected_input_ids}") + LOG.debug(f"Actual input_ids: {input_ids}") + LOG.debug(f"Expected labels: {expected_labels}") + LOG.debug(f"Actual labels: {labels}") + + assert ( + input_ids == expected_input_ids + ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + assert ( + labels == expected_labels + ), f"Labels mismatch: {labels} != {expected_labels}" + + def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset): + LOG.info("Testing ShareGPT style datasets with llama-3 human prompts") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + train_on_eos="none", + sequence_len=512, + roles_to_train=["human"], + ) + res = strategy.tokenize_prompt(sharegpt_dataset[0]) + input_ids = res["input_ids"] + labels = res["labels"] + # fmt: off + expected_input_ids = [ + 128000, # bos + 128006, 882, 128007, # user header + 271, 15339, 128009, # user prompt eot + 128006, 78191, 128007, # assistant header + 271, 15339, 128009, # assistant response eot + 128006, 882, 128007, + 271, 19045, 29474, 128009, + 128006, 78191, 128007, + 271, 19045, 29474, 128009, + ] + expected_labels = [ + IGNORE_TOKEN_ID, # bos + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header + IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # user prompt eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + ] + # fmt: on + + LOG.debug(f"Expected input_ids: {expected_input_ids}") + LOG.debug(f"Actual input_ids: {input_ids}") + LOG.debug(f"Expected labels: {expected_labels}") + LOG.debug(f"Actual labels: {labels}") + + assert ( 
+ input_ids == expected_input_ids + ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + assert ( + labels == expected_labels + ), f"Labels mismatch: {labels} != {expected_labels}" + + def test_llama3_system_human(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + train_on_eos="none", + sequence_len=512, + roles_to_train=["system", "human"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + input_ids = res["input_ids"] + labels = res["labels"] + # fmt: off + expected_input_ids = [ + 128000, # bos + 128006, 9125, 128007, + 271, 2675, 527, 459, 15592, 18328, 13, 128009, + 128006, 882, 128007, # user header + 271, 9906, 128009, # user prompt eot + 128006, 78191, 128007, # assistant header + 271, 13347, 1070, 0, 128009, # assistant response eot + 128006, 882, 128007, + 271, 4438, 527, 499, 30, 128009, + 128006, 78191, 128007, + 271, 40, 2846, 3815, 1664, 11, 9901, 499, 0, 128009, + ] + expected_labels = [ + IGNORE_TOKEN_ID, # bos + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header + IGNORE_TOKEN_ID, 2675, 527, 459, 15592, 18328, 13, IGNORE_TOKEN_ID, # system prompt eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header + IGNORE_TOKEN_ID, 9906, IGNORE_TOKEN_ID, # user prompt eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, 4438, 527, 499, 30, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + ] + # fmt: on + + LOG.debug(f"Expected input_ids: {expected_input_ids}") + LOG.debug(f"Actual input_ids: {input_ids}") + LOG.debug(f"Expected labels: {expected_labels}") + LOG.debug(f"Actual labels: {labels}") + + assert ( + input_ids == expected_input_ids + ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + assert ( + labels == expected_labels + ), f"Labels mismatch: {labels} != {expected_labels}" + if __name__ == "__main__": unittest.main() diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py index aba53cd5f..e7a73a0de 100644 --- a/tests/prompt_strategies/test_sharegpt.py +++ b/tests/prompt_strategies/test_sharegpt.py @@ -192,6 +192,7 @@ class TestSharegptLlama3: input_ids = dataset_wrapper[0]["input_ids"] # fmt: off + # pylint: disable=duplicate-code assert input_ids == [ 128000, # bos 128006, 9125, 128007, # system header @@ -228,6 +229,7 @@ class TestSharegptLlama3: input_ids = dataset_wrapper[0]["input_ids"] # fmt: off + # pylint: disable=duplicate-code assert input_ids == [ 128000, # bos 128006, 9125, 128007, # system header From 3bc8e64557c00f1adecf4a98f6c39f7ca77d3927 Mon Sep 17 00:00:00 2001 From: mhenrichsen Date: Tue, 30 Jul 2024 07:59:53 +0200 Subject: [PATCH 08/13] Update README.md (#1792) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index fd293bd04..a626635dc 100644 --- a/README.md +++ b/README.md @@ -334,7 +334,7 @@ For further and fine-grained use cases, please refer to the official [dstack doc Axolotl 
supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field. -See [these docs](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats. +See [the documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats. ### Config From d8d1788ffc4dd1d58bd5813a83abf6f1f8fad60f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 30 Jul 2024 08:06:11 -0400 Subject: [PATCH 09/13] move to supporting mostly 12.1 w 2.3.1 and add new 12.4 with 2.4.0 (#1793) --- .github/workflows/base.yml | 33 +++++++++------------------------ 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index f8eaff270..4019a5baf 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -12,36 +12,21 @@ jobs: fail-fast: false matrix: include: - - cuda: "118" - cuda_version: 11.8.0 + - cuda: "121" + cuda_version: 12.1.1 python_version: "3.10" - pytorch: 2.1.2 + pytorch: 2.3.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - cuda: "121" - cuda_version: 12.1.0 - python_version: "3.10" - pytorch: 2.1.2 - torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - - cuda: "121" - cuda_version: 12.1.0 - python_version: "3.11" - pytorch: 2.1.2 - torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - - cuda: "121" - cuda_version: 12.1.0 - python_version: "3.11" - pytorch: 2.2.2 - torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - - cuda: "121" - cuda_version: 12.1.0 - python_version: "3.11" - pytorch: 2.3.0 - torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - - cuda: "121" - cuda_version: 12.1.0 + cuda_version: 12.1.1 python_version: "3.11" pytorch: 2.3.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" + - cuda: "124" + cuda_version: 12.4.0 + python_version: "3.11" + pytorch: 2.4.0 + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" steps: - name: Checkout uses: actions/checkout@v3 From d4f6a6b1032b87061c760f558093a9168dfe77c7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 30 Jul 2024 08:34:37 -0400 Subject: [PATCH 10/13] fix dockerfile and base builder (#1795) [skip-ci] --- .github/workflows/base.yml | 4 ++++ docker/Dockerfile-base | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 4019a5baf..3a0c143df 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -14,16 +14,19 @@ jobs: include: - cuda: "121" cuda_version: 12.1.1 + cudnn_version: 8 python_version: "3.10" pytorch: 2.3.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - cuda: "121" cuda_version: 12.1.1 + cudnn_version: 8 python_version: "3.11" pytorch: 2.3.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - cuda: "124" cuda_version: 12.4.0 + cudnn_version: "" python_version: "3.11" pytorch: 2.4.0 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" @@ -52,6 +55,7 @@ jobs: labels: ${{ steps.metadata.outputs.labels }} build-args: | CUDA_VERSION=${{ matrix.cuda_version }} + CUDNN_VERSION=${{ matrix.cudnn_version }} CUDA=${{ matrix.cuda }} PYTHON_VERSION=${{ matrix.python_version }} PYTORCH_VERSION=${{ matrix.pytorch }} diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index 1de5537da..3f13bba30 
100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -3,7 +3,7 @@ ARG CUDNN_VERSION="8" ARG UBUNTU_VERSION="22.04" ARG MAX_JOBS=4 -FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder +FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder ENV PATH="/root/miniconda3/bin:${PATH}" From c5587b45accdc50e0be7b2ec9ed3a66879d68156 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 30 Jul 2024 08:50:23 -0400 Subject: [PATCH 11/13] use 12.4.1 instead of 12.4 [skip-ci] (#1796) --- .github/workflows/base.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 3a0c143df..9101fc2be 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -25,7 +25,7 @@ jobs: pytorch: 2.3.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - cuda: "124" - cuda_version: 12.4.0 + cuda_version: 12.4.1 cudnn_version: "" python_version: "3.11" pytorch: 2.4.0 From 9a638845977e269ed878de7eb25a313f094718ea Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 30 Jul 2024 12:37:40 -0400 Subject: [PATCH 12/13] update test and main/nightly builds (#1797) * update test and main/nightly builds * don't install mamba-ssm on 2.4.0 since it has no wheels yet --- .github/workflows/main.yml | 53 +++++++++++++-------------------- .github/workflows/nightlies.yml | 47 +++++++++++------------------ .github/workflows/tests.yml | 28 ++++++++--------- cicd/Dockerfile.jinja | 4 +-- docker/Dockerfile | 4 +-- 5 files changed, 56 insertions(+), 80 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4969de75d..263af9788 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -13,28 +13,22 @@ jobs: fail-fast: false matrix: include: - - cuda: 118 - cuda_version: 11.8.0 + - cuda: 121 + cuda_version: 12.1.1 python_version: "3.10" - pytorch: 2.1.2 - axolotl_extras: - axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118" + pytorch: 2.3.1 + axolotl_extras: mamba-ssm - cuda: 121 - cuda_version: 12.1.0 - python_version: "3.10" - pytorch: 2.1.2 - axolotl_extras: - - cuda: 121 - cuda_version: 12.1.0 - python_version: "3.11" - pytorch: 2.2.2 - axolotl_extras: - - cuda: 121 - cuda_version: 12.1.0 + cuda_version: 12.1.1 python_version: "3.11" pytorch: 2.3.1 - axolotl_extras: + axolotl_extras: mamba-ssm is_latest: true + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.4.0 + axolotl_extras: runs-on: axolotl-gpu-runner steps: - name: Checkout @@ -75,27 +69,22 @@ jobs: strategy: matrix: include: - - cuda: 118 - cuda_version: 11.8.0 + - cuda: 121 + cuda_version: 12.1.1 python_version: "3.10" - pytorch: 2.1.2 + pytorch: 2.3.1 axolotl_extras: - cuda: 121 - cuda_version: 12.1.0 - python_version: "3.10" - pytorch: 2.1.2 - axolotl_extras: - - cuda: 121 - cuda_version: 12.1.0 - python_version: "3.11" - pytorch: 2.2.2 - axolotl_extras: - - cuda: 121 - cuda_version: 12.1.0 + cuda_version: 12.1.1 python_version: "3.11" pytorch: 2.3.1 axolotl_extras: is_latest: true + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.4.0 + axolotl_extras: runs-on: axolotl-gpu-runner steps: - name: Checkout @@ -134,7 +123,7 @@ jobs: matrix: include: - cuda: 121 - cuda_version: 12.1.0 + cuda_version: 12.1.1 python_version: "3.11" pytorch: 2.3.1 axolotl_extras: diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml index 770954b85..1d95a0983 100644 --- 
a/.github/workflows/nightlies.yml +++ b/.github/workflows/nightlies.yml @@ -12,28 +12,22 @@ jobs: fail-fast: false matrix: include: - - cuda: 118 - cuda_version: 11.8.0 + - cuda: 121 + cuda_version: 12.1.1 python_version: "3.10" - pytorch: 2.1.2 - axolotl_extras: - axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118" - - cuda: 121 - cuda_version: 12.1.0 - python_version: "3.10" - pytorch: 2.1.2 + pytorch: 2.3.1 axolotl_extras: - cuda: 121 - cuda_version: 12.1.0 - python_version: "3.11" - pytorch: 2.2.2 - axolotl_extras: - - cuda: 121 - cuda_version: 12.1.0 + cuda_version: 12.1.1 python_version: "3.11" pytorch: 2.3.1 axolotl_extras: is_latest: true + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.4.0 + axolotl_extras: runs-on: axolotl-gpu-runner steps: - name: Checkout @@ -75,27 +69,22 @@ jobs: strategy: matrix: include: - - cuda: 118 - cuda_version: 11.8.0 + - cuda: 121 + cuda_version: 12.1.1 python_version: "3.10" - pytorch: 2.1.2 + pytorch: 2.3.1 axolotl_extras: - cuda: 121 - cuda_version: 12.1.0 - python_version: "3.10" - pytorch: 2.1.2 - axolotl_extras: - - cuda: 121 - cuda_version: 12.1.0 - python_version: "3.11" - pytorch: 2.2.2 - axolotl_extras: - - cuda: 121 - cuda_version: 12.1.0 + cuda_version: 12.1.1 python_version: "3.11" pytorch: 2.3.1 axolotl_extras: is_latest: true + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.4.0 + axolotl_extras: runs-on: axolotl-gpu-runner steps: - name: Checkout diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1cee8cbcb..384f9d70a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -72,27 +72,24 @@ jobs: fail-fast: false matrix: include: - - cuda: 118 - cuda_version: 11.8.0 + - cuda: 121 + cuda_version: 12.1.1 python_version: "3.10" - pytorch: 2.1.2 - axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118" + pytorch: 2.3.1 num_gpus: 1 + axolotl_extras: mamba-ssm - cuda: 121 - cuda_version: 12.1.0 - python_version: "3.10" - pytorch: 2.1.2 - num_gpus: 1 - - cuda: 121 - cuda_version: 12.1.0 - python_version: "3.11" - pytorch: 2.2.2 - num_gpus: 1 - - cuda: 121 - cuda_version: 12.1.0 + cuda_version: 12.1.1 python_version: "3.11" pytorch: 2.3.1 num_gpus: 1 + axolotl_extras: mamba-ssm + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.4.0 + num_gpus: 1 + axolotl_extras: steps: - name: Checkout uses: actions/checkout@v4 @@ -109,6 +106,7 @@ jobs: echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV + echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV - name: Run tests job on Modal diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index 263f4a661..3a7988366 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \ # If AXOLOTL_EXTRAS is set, append it in brackets RUN pip install causal_conv1d RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ - pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ + pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ - pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \ + pip install -e 
.[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \ fi # So we can test the Docker image diff --git a/docker/Dockerfile b/docker/Dockerfile index be58d0354..2b106f1ed 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl # If AXOLOTL_EXTRAS is set, append it in brackets RUN pip install causal_conv1d RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ - pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ + pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ - pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \ + pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \ fi # So we can test the Docker image From dbf8fb549e25ad69557aaba96a8f107055e4c3bf Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 30 Jul 2024 13:36:19 -0400 Subject: [PATCH 13/13] publish axolotl images without extras in the tag name (#1798) --- .github/workflows/main.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 263af9788..5a972f5f0 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -59,6 +59,7 @@ jobs: push: ${{ github.event_name != 'pull_request' }} tags: | ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} + ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} labels: ${{ steps.metadata.outputs.labels }}
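
For reference, a minimal sketch (not part of any patch above) of how the duplicated tags entry in .github/workflows/main.yml expands for a single matrix build. The repository name and the "main" metadata tag below are hypothetical placeholders used only for illustration; the actual values come from the docker/metadata-action step.

    def expand_tags(metadata_tag, python_version, cuda, pytorch, extras):
        """Mirror the two tag templates from the workflow's `tags:` block."""
        suffix = f"-{extras}" if extras else ""
        tagged_with_extras = f"{metadata_tag}-py{python_version}-cu{cuda}-{pytorch}{suffix}"
        # The second template is what the final patch adds: the same image is
        # now also tagged without the extras suffix.
        tagged_plain = f"{metadata_tag}-py{python_version}-cu{cuda}-{pytorch}"
        return [tagged_with_extras, tagged_plain]

    # Hypothetical example for the cu121 / PyTorch 2.3.1 / mamba-ssm build:
    print(expand_tags("winglian/axolotl:main", "3.11", "121", "2.3.1", "mamba-ssm"))
    # -> ['winglian/axolotl:main-py3.11-cu121-2.3.1-mamba-ssm',
    #     'winglian/axolotl:main-py3.11-cu121-2.3.1']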