diff --git a/examples/glm4/qlora-32b.yaml b/examples/glm4/qlora-32b.yaml new file mode 100644 index 000000000..86d9b43f8 --- /dev/null +++ b/examples/glm4/qlora-32b.yaml @@ -0,0 +1,62 @@ +base_model: THUDM/GLM-4-32B-0414 +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_4bit: true + +datasets: + - path: teknium/GPT4-LLM-Cleaned + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0 +output_dir: ./outputs/qlora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true +eval_sample_packing: true +pad_to_sequence_len: true + +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.05 +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +loss_watchdog_threshold: 5.0 +loss_watchdog_patience: 3 + +warmup_steps: 10 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index e18d7df06..724e0688d 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -47,6 +47,8 @@ cut_cross_entropy: true - qwen2 - cohere - cohere2 +- glm +- glm4 ## Citation diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py new file mode 100644 index 000000000..3df909f88 --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py @@ -0,0 +1,57 @@ +"""GLM 4 patch. GLM family inherits from Llama.""" + +from types import MethodType + +import transformers +from cut_cross_entropy.transformers.utils import ( + PatchOptions, + TransformersModelT, +) + + +def patch_glm( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + + # Set the _PATCH_OPTS in the llama patch file + import cut_cross_entropy.transformers.llama as llama_patch + + llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access + + from cut_cross_entropy.transformers.llama import cce_forward + from transformers.models.glm import modeling_glm + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_glm.GlmForCausalLM + ), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}." + maybe_model.forward = MethodType(cce_forward, maybe_model) + return maybe_model + + modeling_glm.GlmForCausalLM.forward = cce_forward + return None + + +def patch_glm4( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + + # Set the _PATCH_OPTS in the llama patch file + import cut_cross_entropy.transformers.llama as llama_patch + + llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access + + from cut_cross_entropy.transformers.llama import cce_forward + from transformers.models.glm4 import modeling_glm4 + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_glm4.Glm4ForCausalLM + ), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}." + maybe_model.forward = MethodType(cce_forward, maybe_model) + return maybe_model + + modeling_glm4.Glm4ForCausalLM.forward = cce_forward + return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py index 5263956ce..9e18c6b0b 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py @@ -20,6 +20,10 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import ( patch_gemma3, patch_gemma3_text, ) +from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import ( + patch_glm, + patch_glm4, +) from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import ( patch_llama4, patch_llama4_text, @@ -45,6 +49,8 @@ CUT_CROSS_ENTROPY_MODEL_MAPPING = { "qwen2": patch_qwen2, "cohere": patch_cohere, "cohere2": patch_cohere2, + "glm": patch_glm, + "glm4": patch_glm4, } diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 2b02699bd..a2459ec5a 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -31,6 +31,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "starcoder2", "deepseek_v2", "deepseek_v3", + "glm", + "glm4", ]