pass additional info for fix untrained tokens when using distributed + offloading (#2388)
* pass additional info for fix untrained tokens when using distributed + offloading * use latest version of vendored lib * use v0.0.5 of contribs lgpl * fix for no bad tokens and add tests * use release * add multigpu test too * make sure the multigpu zero3 test actually uses zero3
This commit is contained in:
@@ -62,5 +62,5 @@ antlr4-python3-runtime==4.13.2
|
||||
torchao==0.7.0
|
||||
schedulefree==1.3.0
|
||||
|
||||
axolotl-contribs-lgpl==0.0.3
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.3
|
||||
|
||||
@@ -7,7 +7,7 @@ import signal
|
||||
import sys
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
@@ -20,7 +20,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
from axolotl.common.datasets import TrainDatasetMeta
|
||||
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||
fix_untrained_tokens,
|
||||
)
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
@@ -382,21 +382,23 @@ def handle_untrained_tokens_fix(
|
||||
if not cfg.fix_untrained_tokens:
|
||||
return
|
||||
|
||||
is_ds_zero3: bool = False
|
||||
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
|
||||
is_ds_zero3 = True
|
||||
|
||||
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
||||
sig = inspect.signature(fix_untrained_tokens)
|
||||
|
||||
fix_kwargs: Dict[str, Any] = {}
|
||||
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
||||
if "token_ids_to_fix" in sig.parameters and isinstance(
|
||||
cfg.fix_untrained_tokens, list
|
||||
):
|
||||
fix_untrained_tokens(
|
||||
model,
|
||||
tokenizer,
|
||||
train_dataset,
|
||||
token_ids_to_fix=cfg.fix_untrained_tokens,
|
||||
)
|
||||
else:
|
||||
fix_untrained_tokens(model, tokenizer, train_dataset)
|
||||
fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
|
||||
if "is_ds_zero3" in sig.parameters:
|
||||
fix_kwargs["is_ds_zero3"] = is_ds_zero3
|
||||
|
||||
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
model.save_pretrained(
|
||||
|
||||
@@ -750,3 +750,66 @@ class TestMultiGPULlama:
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
def test_fix_untrained_tokens(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"fix_untrained_tokens": True,
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.0,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
"bos_token": "<|custom_im_start|>",
|
||||
"eos_token": "<|custom_im_end|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"chat_template": "jinja",
|
||||
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
|
||||
"path": "mlabonne/FineTome-100k",
|
||||
"type": "chat_template",
|
||||
"split": "train[:10%]",
|
||||
"field_messages": "conversations",
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@@ -66,6 +66,54 @@ class TestLlama:
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_fix_untrained_tokens(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"fix_untrained_tokens": True,
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.0,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
"bos_token": "<|custom_im_start|>",
|
||||
"eos_token": "<|custom_im_end|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"chat_template": "jinja",
|
||||
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
|
||||
"path": "mlabonne/FineTome-100k",
|
||||
"type": "chat_template",
|
||||
"split": "train[:10%]",
|
||||
"field_messages": "conversations",
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_fix_untrained_tokens_already_trained(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user