pass additional info for fix untrained tokens when using distributed + offloading (#2388)

* pass additional info for fix untrained tokens when using distributed + offloading

* use latest version of vendored lib

* use v0.0.5 of contribs lgpl

* fix for no bad tokens and add tests

* use release

* add multigpu test too

* make sure the multigpu zero3 test actually uses zero3
This commit is contained in:
Wing Lian
2025-03-11 12:02:43 -04:00
committed by GitHub
parent 4a736986fa
commit 59899b9817
4 changed files with 124 additions and 11 deletions

View File

@@ -62,5 +62,5 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0 torchao==0.7.0
schedulefree==1.3.0 schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.3 axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3 axolotl-contribs-mit==0.0.3

View File

@@ -7,7 +7,7 @@ import signal
import sys import sys
import weakref import weakref
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Dict
import torch import torch
import transformers.modelcard import transformers.modelcard
@@ -20,7 +20,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer from transformers.trainer import Trainer
from axolotl.common.datasets import TrainDatasetMeta from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens, fix_untrained_tokens,
) )
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -382,21 +382,23 @@ def handle_untrained_tokens_fix(
if not cfg.fix_untrained_tokens: if not cfg.fix_untrained_tokens:
return return
is_ds_zero3: bool = False
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
is_ds_zero3 = True
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args # Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens) sig = inspect.signature(fix_untrained_tokens)
fix_kwargs: Dict[str, Any] = {}
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list # If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance( if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list cfg.fix_untrained_tokens, list
): ):
fix_untrained_tokens( fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
model, if "is_ds_zero3" in sig.parameters:
tokenizer, fix_kwargs["is_ds_zero3"] = is_ds_zero3
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens, fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0: if cfg.local_rank == 0:
model.save_pretrained( model.save_pretrained(

View File

@@ -750,3 +750,66 @@ class TestMultiGPULlama:
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
) )
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
"use_tensorboard": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
)

View File

@@ -66,6 +66,54 @@ class TestLlama:
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens(self, temp_dir): def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens_already_trained(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {