diff --git a/docs/config.qmd b/docs/config.qmd
index 9946b5865..f166f8050 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -513,7 +513,6 @@ lr_div_factor: # Learning rate div factor
 # in the examples/ for your model and fine-tuning use case.
 #
 # Valid values for 'optimizer' include:
-# - adamw_hf
 # - adamw_torch
 # - adamw_torch_fused
 # - adamw_torch_xla
diff --git a/examples/gemma3/qlora.yml b/examples/gemma3/qlora.yml
new file mode 100644
index 000000000..50045cc8a
--- /dev/null
+++ b/examples/gemma3/qlora.yml
@@ -0,0 +1,74 @@
+base_model: google/gemma-3-1b-it
+# optionally might have model_type or tokenizer_type
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+# huggingface repo
+chat_template: gemma3_text
+datasets:
+  - path: cgato/SlimOrcaDedupCleaned
+    type: chat_template
+    field_messages: conversations
+    message_property_mappings:
+      role: from
+      content: value
+
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: qlora
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+
+sequence_len: 2048
+sample_packing: true
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch:
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
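Illustrative sketch, not part of the patch: the new example config can be loaded and sanity-checked offline with the same helpers the e2e test files in this patch use (validate_config and normalize_config). Assumes PyYAML and a local axolotl install; the file path is the one added above.

    import yaml

    from axolotl.utils.config import normalize_config, validate_config
    from axolotl.utils.dict import DictDefault

    # Parse the example YAML into axolotl's config wrapper
    with open("examples/gemma3/qlora.yml", encoding="utf-8") as f:
        cfg = DictDefault(yaml.safe_load(f))

    # Schema validation plus derived-field normalization, mirroring the e2e tests
    cfg = validate_config(cfg)
    normalize_config(cfg)
    print(cfg.base_model, cfg.adapter)  # -> google/gemma-3-1b-it qlora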
diff --git a/requirements.txt b/requirements.txt
index c8465d23f..93618ba00 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,7 +12,7 @@
 liger-kernel==0.5.3
 packaging==23.2
 peft==0.15.0
-transformers==4.49.0
+transformers==4.50.0
 tokenizers>=0.21.1
 accelerate==1.5.2
 datasets==3.4.1
diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py
index b8e1fac52..327a05138 100644
--- a/src/axolotl/integrations/liger/__init__.py
+++ b/src/axolotl/integrations/liger/__init__.py
@@ -114,3 +114,5 @@ class LigerPlugin(BasePlugin):
                 modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
             if cfg.liger_fused_linear_cross_entropy:
                 modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
+        elif cfg.model_config_type in ["gemma3_text", "deepseek_v3"]:
+            raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index f2a87192a..d6d209db5 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -22,6 +22,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "phi3",
     "gemma",
     "gemma2",
+    "gemma3_text",
     "gemmoe",
     "starcoder2",
     "deepseek_v2",
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index d3c88334b..7dbeda462 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -22,6 +22,7 @@ _CHAT_TEMPLATES = {
     "mistral_v3_tekken": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # V3-Tekken: Nemo, Pixtral...
     "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
     "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
+    "gemma3_text": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n    {%- if messages[0]['content'] is string -%}\n        {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n    {%- else -%}\n        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n    {%- endif -%}\n    {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n    {%- set first_user_prefix = \"\" -%}\n    {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n        {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n    {%- endif -%}\n    {%- if (message['role'] == 'assistant') -%}\n        {%- set role = \"model\" -%}\n    {%- else -%}\n        {%- set role = message['role'] -%}\n    {%- endif -%}\n    {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n    {%- if message['content'] is string -%}\n        {{ message['content'] | trim }}\n    {%- elif message['content'] is iterable -%}\n        {%- for item in message['content'] -%}\n            {%- if item['type'] == 'image' -%}\n                {{ '<start_of_image>' }}\n            {%- elif item['type'] == 'text' -%}\n                {{ item['text'] | trim }}\n            {%- endif -%}\n        {%- endfor -%}\n    {%- else -%}\n        {{ raise_exception(\"Invalid content type\") }}\n    {%- endif -%}\n    {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n    {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
     "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
     "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
     "llama3_2_vision": '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == "" %}\n {{- raise_exception("Prompting with images is incompatible with system messages.") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n {%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n {%- endif %}\n {{- "Cutting Knowledge Date: December 2023\\n" }}\n {{- "Today Date: " + date_string + "\\n\\n" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- "<|eot_id|>" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\' }}\n {%- if message[\'content\'] is string %}\n {{- message[\'content\'] }}\n {%- else %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {{- \'<|image|>\' }}\n {%- elif content[\'type\'] == \'text\' %}\n {{- content[\'text\'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n',
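Illustrative sketch, not part of the patch: the new gemma3_text template can be rendered standalone with jinja2 to eyeball the turn formatting. _CHAT_TEMPLATES is the module-level dict this hunk modifies (a private name, fine for a quick local check); raise_exception is only referenced on malformed conversations, so a well-formed message list renders without defining it.

    from jinja2 import BaseLoader, Environment

    from axolotl.utils.chat_templates import _CHAT_TEMPLATES

    # Compile the template string exactly as axolotl stores it
    template = Environment(loader=BaseLoader()).from_string(_CHAT_TEMPLATES["gemma3_text"])
    print(
        template.render(
            bos_token="<bos>",
            messages=[
                {"role": "system", "content": "Be terse."},
                {"role": "user", "content": "Hello!"},
            ],
            add_generation_prompt=True,
        )
    )
    # Expected shape: the system text is folded into the first user turn:
    # <bos><start_of_turn>user
    # Be terse.
    #
    # Hello!<end_of_turn>
    # <start_of_turn>model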
diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py
index f376aca5f..a0c6df710 100644
--- a/src/axolotl/utils/schemas/enums.py
+++ b/src/axolotl/utils/schemas/enums.py
@@ -23,6 +23,7 @@ class ChatTemplate(str, Enum):
     mistral_v2v3 = "mistral_v2v3"  # pylint: disable=invalid-name
     mistral_v3_tekken = "mistral_v3_tekken"  # pylint: disable=invalid-name
     gemma = "gemma"  # pylint: disable=invalid-name
+    gemma3_text = "gemma3_text"  # pylint: disable=invalid-name
     cohere = "cohere"  # pylint: disable=invalid-name
     llama3 = "llama3"  # pylint: disable=invalid-name
     llama3_2_vision = "llama3_2_vision"  # pylint: disable=invalid-name
diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
index 89f80951b..ce7a2bf0f 100644
--- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
+++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
 def test_geglu_model_integration():
     """Test GeGLU activation with Gemma model."""
     model = AutoModelForCausalLM.from_pretrained(
-        "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda"
+        "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
     )
     peft_config = get_peft_config(
         {
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
     """Test LoRA kernel patches across different model architectures."""
     # Load model with appropriate dtype
     model = AutoModelForCausalLM.from_pretrained(
-        model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda"
+        model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
     )
 
     # Apply LoRA configuration
diff --git a/tests/e2e/test_deepseekv3.py b/tests/e2e/test_deepseekv3.py
index de8513078..f8c3d429a 100644
--- a/tests/e2e/test_deepseekv3.py
+++ b/tests/e2e/test_deepseekv3.py
@@ -1,5 +1,5 @@
 """
-E2E tests for lora llama
+E2E tests for deepseekv3
 """
 
 import logging
diff --git a/tests/e2e/test_gemma2.py b/tests/e2e/test_gemma2.py
new file mode 100644
index 000000000..df777b709
--- /dev/null
+++ b/tests/e2e/test_gemma2.py
@@ -0,0 +1,133 @@
+"""
+E2E tests for gemma2
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import pytest
+
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
+from axolotl.train import train
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestGemma2:
+    """
+    Test case for Gemma2 models
+    """
+
+    @pytest.mark.parametrize(
+        "sample_packing",
+        [True, False],
+    )
+    def test_lora_gemma2(self, temp_dir, sample_packing):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "axolotl-ai-co/gemma-2-33M",
+                "trust_remote_code": True,
+                "sample_packing": sample_packing,
+                "flash_attention": True,
+                "sequence_len": 2048,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0,
+                "datasets": [
+                    {
+                        "path": "mlabonne/FineTome-100k",
+                        "type": "chat_template",
+                        "field_messages": "conversations",
+                        "message_property_mappings": {
+                            "role": "from",
+                            "content": "value",
+                        },
+                        "drop_system_message": True,
+                        "split": "train[:1%]",
+                    },
+                ],
+                "special_tokens": {
+                    "bos_token": "<bos>",
+                    "eos_token": "<eos>",
+                },
+                "chat_template": "gemma",  # gemma2's template is same as gemma
+                "num_epochs": 1,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 4,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 5,
+                "save_safetensors": True,
+                "bf16": True,
+            }
+        )
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.safetensors").exists()
+
+    @pytest.mark.parametrize(
+        "sample_packing",
+        [True, False],
+    )
+    def test_fft_gemma2(self, temp_dir, sample_packing):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "axolotl-ai-co/gemma-2-33M",
+                "trust_remote_code": True,
+                "sample_packing": sample_packing,
+                "flash_attention": True,
+                "sequence_len": 2048,
+                "val_set_size": 0,
+                "datasets": [
+                    {
+                        "path": "mlabonne/FineTome-100k",
+                        "type": "chat_template",
+                        "field_messages": "conversations",
+                        "message_property_mappings": {
+                            "role": "from",
+                            "content": "value",
+                        },
+                        "split": "train[:1%]",
+                        "drop_system_message": True,
+                    },
+                ],
+                "chat_template": "gemma",  # gemma2's template is same as gemma
+                "special_tokens": {
+                    "bos_token": "<bos>",
+                    "eos_token": "<eos>",
+                },
+                "num_epochs": 1,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 4,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 5,
+                "save_safetensors": True,
+                "bf16": True,
+            }
+        )
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()
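Illustrative sketch, not part of the patch: what message_property_mappings does for the ShareGPT-style rows these tests consume, reduced to the key renaming. The dict comprehension below is a simplification for illustration; role-name normalization (e.g. human to user, gpt to assistant) is assumed to happen separately inside axolotl's chat_template prompt strategy.

    # One FineTome-style row: turns keyed by "from"/"value"
    row = {
        "conversations": [
            {"from": "human", "value": "Hi"},
            {"from": "gpt", "value": "Hey"},
        ]
    }
    # The mapping from the test configs: destination key -> source key
    mapping = {"role": "from", "content": "value"}

    messages = [
        {dest: turn[src] for dest, src in mapping.items()}
        for turn in row["conversations"]
    ]
    print(messages)
    # [{'role': 'human', 'content': 'Hi'}, {'role': 'gpt', 'content': 'Hey'}]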
"axolotl-ai-co/gemma-2-33M", + "trust_remote_code": True, + "sample_packing": sample_packing, + "flash_attention": True, + "sequence_len": 2048, + "val_set_size": 0, + "datasets": [ + { + "path": "mlabonne/FineTome-100k", + "type": "chat_template", + "field_messages": "conversations", + "message_property_mappings": { + "role": "from", + "content": "value", + }, + "split": "train[:1%]", + "drop_system_message": True, + }, + ], + "chat_template": "gemma", # gemma2's template is same as gemma + "special_tokens": { + "bos_token": "", + "eos_token": "", + }, + "num_epochs": 1, + "micro_batch_size": 1, + "gradient_accumulation_steps": 4, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 5, + "save_safetensors": True, + "bf16": True, + } + ) + cfg = validate_config(cfg) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists() diff --git a/tests/e2e/test_gemma3_text.py b/tests/e2e/test_gemma3_text.py new file mode 100644 index 000000000..6ce360f68 --- /dev/null +++ b/tests/e2e/test_gemma3_text.py @@ -0,0 +1,131 @@ +""" +E2E tests for gemma3_text +""" + +import logging +import os +from pathlib import Path + +import pytest + +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestGemma3Text: + """ + Test case for Gemma3Text models + """ + + @pytest.mark.parametrize( + "sample_packing", + [True, False], + ) + def test_lora_gemma3_text(self, temp_dir, sample_packing): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "axolotl-ai-co/gemma-3-34M", + "trust_remote_code": True, + "sample_packing": sample_packing, + "flash_attention": True, + "sequence_len": 2048, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0, + "datasets": [ + { + "path": "mlabonne/FineTome-100k", + "type": "chat_template", + "field_messages": "conversations", + "message_property_mappings": { + "role": "from", + "content": "value", + }, + "split": "train[:1%]", + }, + ], + "special_tokens": { + "bos_token": "", + "eos_token": "", + }, + "chat_template": "gemma3_text", + "num_epochs": 1, + "micro_batch_size": 1, + "gradient_accumulation_steps": 4, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 5, + "save_safetensors": True, + "bf16": True, + } + ) + cfg = validate_config(cfg) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.safetensors").exists() + + @pytest.mark.parametrize( + "sample_packing", + [True, False], + ) + def test_fft_gemma3_text(self, temp_dir, sample_packing): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "axolotl-ai-co/gemma-3-34M", + "trust_remote_code": True, + "sample_packing": sample_packing, + "flash_attention": True, + "sequence_len": 2048, + "val_set_size": 0, + "datasets": [ + { + "path": "mlabonne/FineTome-100k", + "type": 
"chat_template", + "field_messages": "conversations", + "message_property_mappings": { + "role": "from", + "content": "value", + }, + "split": "train[:1%]", + }, + ], + "chat_template": "gemma3_text", + "special_tokens": { + "bos_token": "", + "eos_token": "", + }, + "num_epochs": 1, + "micro_batch_size": 1, + "gradient_accumulation_steps": 4, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 5, + "save_safetensors": True, + "bf16": True, + } + ) + cfg = validate_config(cfg) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists() diff --git a/tests/e2e/test_schedulers.py b/tests/e2e/test_schedulers.py index c492fdccf..2d5040ae3 100644 --- a/tests/e2e/test_schedulers.py +++ b/tests/e2e/test_schedulers.py @@ -54,7 +54,7 @@ class TestCustomSchedulers(unittest.TestCase): "gradient_accumulation_steps": 1, "output_dir": temp_dir, "learning_rate": 0.00001, - "optimizer": "adamw_hf", + "optimizer": "adamw_torch_fused", "max_steps": 20, "lr_scheduler": "rex", "warmup_steps": 5,