diff --git a/examples/gemma3/gemma-3-1b-qlora.yml b/examples/gemma3/gemma-3-1b-qlora.yml
index 669ffacdc..8852b3469 100644
--- a/examples/gemma3/gemma-3-1b-qlora.yml
+++ b/examples/gemma3/gemma-3-1b-qlora.yml
@@ -5,6 +5,9 @@ tokenizer_type: AutoTokenizer
 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name

+# gemma3 doesn't seem to play nice with ddp
+ddp_find_unused_parameters: true
+
 load_in_8bit: false
 load_in_4bit: true
 strict: false
@@ -54,6 +57,8 @@ fp16:
 tf32: true

 gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
 early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
diff --git a/examples/gemma3/gemma-3-4b-lora.yml b/examples/gemma3/gemma-3-4b-lora.yml
index b85392982..0e7422bd4 100644
--- a/examples/gemma3/gemma-3-4b-lora.yml
+++ b/examples/gemma3/gemma-3-4b-lora.yml
@@ -7,6 +7,9 @@ skip_prepare_dataset: true
 remove_unused_columns: false
 sample_packing: false

+# gemma3 doesn't seem to play nice with ddp
+ddp_find_unused_parameters: true
+
 chat_template: gemma3
 datasets:
   - path: HuggingFaceH4/llava-instruct-mix-vsft
@@ -48,6 +51,8 @@ fp16:
 tf32: true

 gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
 local_rank:
 logging_steps: 1
 flash_attention: true
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 2349932ba..7d0df8a45 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -524,9 +524,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             and self.cfg.eval_steps
             and self.cfg.save_steps % self.cfg.eval_steps == 0
         ) or False
+
+        # handle ddp
+        ddp_find_unused_parameters = None
+        if self.cfg.ddp:
+            ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
         training_arguments_kwargs["ddp_find_unused_parameters"] = (
-            False if self.cfg.ddp else None
+            ddp_find_unused_parameters
         )
+
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
         training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
         report_to = []
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index cd819fba4..015743329 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -22,6 +22,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "phi3",
     "gemma",
     "gemma2",
+    "gemma3",
     "gemma3_text",
     "cohere",
     "cohere2",
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 12c8b31d5..33bb4b4cc 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -112,6 +112,7 @@ class DataCollatorForSeq2Seq:
             self.local_world_size = dist.get_world_size(group=sp_group)

     def __call__(self, features, return_tensors=None):
+        has_attn_mask = "attention_mask" in features[0].keys()
         labels = None
         if return_tensors is None:
             return_tensors = self.return_tensors
@@ -164,6 +165,8 @@ class DataCollatorForSeq2Seq:
             pad_to_multiple_of=self.pad_to_multiple_of,
             return_tensors=return_tensors,
         )
+        if not has_attn_mask:
+            del features["attention_mask"]

         # prepare decoder_input_ids
         if (
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 646fb4c87..c370707b6 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -235,7 +235,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):


 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
-    if cfg.model_config_type == "mamba":
+    if cfg.model_config_type in ["mamba", "gemma3"]:
         LOG.info("dropping attention_mask column")
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
diff --git a/tests/e2e/multigpu/test_gemma3.py b/tests/e2e/multigpu/test_gemma3.py
new file mode 100644
index 000000000..9de3ed82f
--- /dev/null
+++ b/tests/e2e/multigpu/test_gemma3.py
@@ -0,0 +1,100 @@
+"""
+E2E tests for multigpu lora gemma3
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import pytest
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+from huggingface_hub import snapshot_download
+from transformers.testing_utils import get_torch_dist_unique_port
+
+from axolotl.utils.dict import DictDefault
+
+from tests.e2e.utils import check_tensorboard
+
+LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
+os.environ["WANDB_DISABLED"] = "true"
+
+AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
+
+
+@pytest.fixture(scope="session", autouse=True)
+def download_model():
+    # download the model
+    snapshot_download("axolotl-mirrors/gemma-3-4b-pt", repo_type="model")
+
+
+class TestMultiGPUGemma3:
+    """
+    Test case for Gemma3 models using LoRA
+    """
+
+    def test_lora_ddp_packed(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "axolotl-mirrors/gemma-3-4b-pt",
+                "sequence_len": 2048,
+                "ddp_find_unused_parameters": True,
+                "sample_packing": True,
+                "eval_sample_packing": False,
+                "pad_to_sequence_len": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.0,
+                "chat_template": "gemma3",
+                "datasets": [
+                    {
+                        "path": "mlabonne/FineTome-100k",
+                        "type": "chat_template",
+                        "split": "train[:10%]",
+                        "field_messages": "conversations",
+                        "message_field_role": "from",
+                        "message_field_content": "value",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 2,
+                "micro_batch_size": 4,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {
+                    "use_reentrant": False,
+                },
+                "gradient_accumulation_steps": 2,
+                "output_dir": temp_dir,
+                "learning_rate": 0.0001,
+                "optimizer": "adamw_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "use_tensorboard": True,
+                "bf16": True,
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
+                "--num-processes",
+                "2",
+                "--main-process-port",
+                f"{get_torch_dist_unique_port()}",
+            ]
+        )
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high"
+        )
diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py
index 432d89b1f..0e228aef0 100644
--- a/tests/e2e/multigpu/test_llama.py
+++ b/tests/e2e/multigpu/test_llama.py
@@ -58,6 +58,7 @@ class TestMultiGPULlama:
                 "max_steps": 2,
                 "micro_batch_size": 4,
                 "gradient_accumulation_steps": 4,
+                "gradient_checkpointing": True,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_8bit",
@@ -121,6 +122,7 @@ class TestMultiGPULlama:
                 "max_steps": 2,
                 "micro_batch_size": 1,
                 "gradient_accumulation_steps": gradient_accumulation_steps,
+                "gradient_checkpointing": True,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_8bit",
@@ -193,6 +195,7 @@ class TestMultiGPULlama:
                 "max_steps": 2,
                 "micro_batch_size": 4,
"gradient_accumulation_steps": 4, + "gradient_checkpointing": True, "output_dir": temp_dir, "warmup_steps": 0, "learning_rate": 0.00001, @@ -270,6 +273,7 @@ class TestMultiGPULlama: "max_steps": 2, "micro_batch_size": 2, "gradient_accumulation_steps": 4, + "gradient_checkpointing": True, "output_dir": temp_dir, "warmup_steps": 0, "learning_rate": 0.00001, @@ -330,6 +334,7 @@ class TestMultiGPULlama: "max_steps": 2, "micro_batch_size": 2, "gradient_accumulation_steps": gradient_accumulation_steps, + "gradient_checkpointing": True, "output_dir": temp_dir, "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", @@ -400,6 +405,7 @@ class TestMultiGPULlama: "max_steps": 2, "micro_batch_size": 4, "gradient_accumulation_steps": 2, + "gradient_checkpointing": True, "output_dir": temp_dir, "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", @@ -479,6 +485,7 @@ class TestMultiGPULlama: "max_steps": 2, "micro_batch_size": 4, "gradient_accumulation_steps": 2, + "gradient_checkpointing": True, "output_dir": temp_dir, "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", @@ -781,6 +788,7 @@ class TestMultiGPULlama: "max_steps": 2, "micro_batch_size": 1, "gradient_accumulation_steps": 1, + "gradient_checkpointing": True, "output_dir": temp_dir, "learning_rate": 0.00001, "optimizer": "adamw_torch_fused",