Compare commits

...

6 Commits

Author SHA1 Message Date
Wing Lian
8fc4c420a4 Add kd coefficient scheduler 2025-03-18 09:01:58 -04:00
Wing Lian
4f5eb42a73 remove reference to deprecated import (#2407) 2025-03-15 08:49:41 -04:00
Wing Lian
fbe54be6b8 only validate hf user token on rank 0 (#2408) 2025-03-13 23:29:06 -04:00
Wing Lian
04f6324833 build cloud images with torch 2.6.0 (#2413)
* build cloud images with torch 2.6.0

* nightlies too
2025-03-13 23:28:51 -04:00
Wing Lian
f0072f3b9d use max of 32 dataset processes if not explicit (#2403)
* use max of 32 dataset processes if not explicit

* change alternate min val for consistency
2025-03-11 12:02:58 -04:00
Wing Lian
59899b9817 pass additional info for fix untrained tokens when using distributed + offloading (#2388)
* pass additional info for fix untrained tokens when using distributed + offloading

* use latest version of vendored lib

* use v0.0.5 of contribs lgpl

* fix for no bad tokens and add tests

* use release

* add multigpu test too

* make sure the multigpu zero3 test actually uses zero3
2025-03-11 12:02:43 -04:00
19 changed files with 252 additions and 42 deletions

View File

@@ -88,6 +88,11 @@ jobs:
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras:
is_latest: true is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -80,6 +80,11 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -62,5 +62,5 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0 torchao==0.7.0
schedulefree==1.3.0 schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.3 axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3 axolotl-contribs-mit==0.0.3

View File

@@ -1,6 +1,7 @@
"""CLI to run training on a model.""" """CLI to run training on a model."""
import logging import logging
import os
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@@ -34,7 +35,8 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
""" """
print_axolotl_text_art() print_axolotl_text_art()
check_accelerate_default_config() check_accelerate_default_config()
check_user_token() if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
if cfg.rl: if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -43,7 +43,7 @@ class TokenizedChatDataset(Dataset):
process_or_cpu_count: int = ( process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment] process_count or os.cpu_count() # type: ignore[assignment]
) )
num_proc = min(64, process_or_cpu_count) num_proc = min(32, process_or_cpu_count)
features = data.features.keys() features = data.features.keys()
tokenized_data = data.map( tokenized_data = data.map(
map_fn, map_fn,

View File

@@ -751,8 +751,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.kd_ce_alpha is not None: if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_ce_alpha_end is not None:
training_arguments_kwargs["kd_ce_alpha_end"] = self.cfg.kd_ce_alpha_end
if self.cfg.kd_alpha is not None: if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_alpha_end is not None:
training_arguments_kwargs["kd_alpha_end"] = self.cfg.kd_alpha_end
if self.cfg.kd_temperature is not None: if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None: if self.cfg.kd_zscore_base_temp is not None:

View File

@@ -34,3 +34,12 @@ class KDPlugin(BasePlugin):
return AxolotlKDTrainer return AxolotlKDTrainer
return None return None
def add_callbacks_post_trainer(self, cfg, trainer):
callbacks = []
if cfg.kd_trainer:
from .callbacks import KDAlphaSchedulerCallback
callbacks.append(KDAlphaSchedulerCallback())
return callbacks

View File

@@ -30,6 +30,8 @@ class KDArgs(BaseModel):
float float
] = None # loss coefficient for cross-entropy loss during KD ] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_ce_alpha_end: Optional[float] = None # end value for kd_ce_alpha
kd_alpha_end: Optional[float] = None # end value for kd_alpha
kd_temperature: Optional[float] = None # temperature for sampling during KD kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[ kd_top_k_before_softmax: Optional[

View File

@@ -0,0 +1,28 @@
from transformers import TrainerCallback


class KDAlphaSchedulerCallback(TrainerCallback):
    """Callback for linearly scheduling the KD loss coefficients during training.

    Interpolates ``kd_alpha`` (and ``kd_ce_alpha``) from their configured start
    values toward ``kd_alpha_end`` (``kd_ce_alpha_end``) across epochs. The
    current values are written onto the ``control`` object, which the KD
    trainer reads when combining the CE and KD losses.
    """

    def on_epoch_begin(
        self, args, state, control, **kwargs  # pylint: disable=unused-argument
    ):
        current_epoch = int(state.epoch)
        if current_epoch == 0:
            # First epoch: use the configured starting coefficients.
            # Fix: the original assigned these onto ``state`` while every
            # other branch assigned onto ``control``; unified on ``control``
            # so the trainer reads the scheduled values consistently.
            control.kd_alpha = args.kd_alpha
            control.kd_ce_alpha = args.kd_ce_alpha
        elif current_epoch == state.num_train_epochs - 1:
            # Final epoch: pin to the end values when provided.
            if args.kd_alpha_end is not None:
                control.kd_alpha = args.kd_alpha_end
            if args.kd_ce_alpha_end is not None:
                control.kd_ce_alpha = args.kd_ce_alpha_end
        else:
            # Intermediate epochs: linear interpolation between the start and
            # end values, scaled by fraction of (num_train_epochs - 1) done.
            epoch_steps = state.num_train_epochs - 1
            scale = current_epoch / epoch_steps
            if args.kd_alpha_end is not None:
                control.kd_alpha = (
                    args.kd_alpha + (args.kd_alpha_end - args.kd_alpha) * scale
                )
            if args.kd_ce_alpha_end is not None:
                control.kd_ce_alpha = (
                    args.kd_ce_alpha
                    + (args.kd_ce_alpha_end - args.kd_ce_alpha) * scale
                )

View File

@@ -62,10 +62,16 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
Transform logprobs to target format for KD training Transform logprobs to target format for KD training
""" """
logprobs = sample.pop(self.logprobs_field) if "target_logprobs" in sample.keys() and "target_token_ids" in sample.keys():
logprobs = sample.pop("target_logprobs")
token_ids = sample.pop("target_token_ids")
else:
logprobs = sample.pop(self.logprobs_field)
token_ids = [None] * len(logprobs)
target_seq_len = len(logprobs) target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"]) input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len target_padding_len = input_seq_len - target_seq_len
# get non-zero top-k (prune None logprobs from vllm data step) # get non-zero top-k (prune None logprobs from vllm data step)
top_k_vals = [ top_k_vals = [
len(logprobs[i]) len(logprobs[i])
@@ -82,11 +88,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_token_ids = [] target_token_ids = []
target_mask = [] target_mask = []
if input_padding_len < 0: if target_padding_len < 0:
# logprobs is longer than target_seq_len, # logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs # so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len] logprobs = logprobs[:-input_seq_len]
input_padding_len = 0 target_padding_len = 0
# target_seq_len = input_seq_len # target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k # truncate the second dimension of the logprobs to top_k
@@ -98,33 +104,37 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# for causal models, if we start the range at 1, then we don't need to shift in the trainer # for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer # otherwise, we need to shift in the trainer
shift = 0 shift = 0
for _ in range(shift, input_padding_len): for _ in range(shift, target_padding_len):
target_logprobs.append([-float("inf")] * top_k) target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k))) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k) target_mask.append([0] * top_k)
for position in range(input_padding_len, input_seq_len): for position in range(target_padding_len, input_seq_len):
if sample["labels"][position] == -100: if sample["labels"][position] == -100:
target_mask.append([0] * top_k) target_mask.append([0] * top_k)
else: else:
target_mask.append([1] * top_k) target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs): for token_pos_logprobs, token_pos_token_ids in zip(logprobs, token_ids):
# Initialize collections for logprobs and token_ids # Initialize collections for logprobs and token_ids
position_logprobs = [] position_logprobs = []
position_token_ids = [] position_token_ids = []
# Process each token probability entry # Process each token probability entry
for entry in token_pos_logprobs: if token_pos_token_ids is None:
# Extract logprob value for entry in token_pos_logprobs:
logprob = entry["logprob"] # Extract logprob value
logprob = entry["logprob"]
# Parse token_id from the "token_id:###" format # Parse token_id from the "token_id:###" format
token_id = int(entry["token"].split(":")[1]) token_id = int(entry["token"].split(":")[1])
# Append to our collections # Append to our collections
position_logprobs.append(logprob) position_logprobs.append(logprob)
position_token_ids.append(token_id) position_token_ids.append(token_id)
else:
position_logprobs = token_pos_logprobs
position_token_ids = token_pos_token_ids
# Convert to a tensor for easier manipulation # Convert to a tensor for easier manipulation
position_logprobs_tensor = torch.tensor( position_logprobs_tensor = torch.tensor(
@@ -143,6 +153,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
teacher_probs_t2 = teacher_probs_t1**exponent teacher_probs_t2 = teacher_probs_t1**exponent
else: else:
teacher_probs_t2 = teacher_probs_t1 teacher_probs_t2 = teacher_probs_t1
# Re-normalize # Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum( teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True dim=0, keepdim=True

View File

@@ -16,17 +16,35 @@
KD trainer KD trainer
""" """
from transformers import TrainerControl
from axolotl.core.trainers.base import AxolotlTrainer from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
class AxolotlKDTrainerControl(TrainerControl):
kd_alpha: float = 1.0
kd_ce_alpha: float = 0.0
def state(self) -> dict:
state_val = super().state()
state_val["args"]["kd_alpha"] = self.kd_alpha
state_val["args"]["kd_ce_alpha"] = self.kd_ce_alpha
class AxolotlKDTrainer(AxolotlTrainer): class AxolotlKDTrainer(AxolotlTrainer):
""" """
Custom trainer subclass for Knowledge Distillation (KD) Custom trainer subclass for Knowledge Distillation (KD)
""" """
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.kd_alpha = self.args.kd_alpha
self.kd_ce_alpha = self.args.kd_ce_alpha
self.control = AxolotlKDTrainerControl()
def _set_signature_columns_if_needed(self): def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed() super()._set_signature_columns_if_needed()
columns_to_add = [] columns_to_add = []
@@ -95,9 +113,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0, top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
) )
if self.args.kd_ce_alpha > 0: if self.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha loss = self.kd_ce_alpha * outputs["loss"] + self.kd_alpha * loss_kd
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else: else:
loss = loss_kd loss = loss_kd
# Save past state if it exists # Save past state if it exists

View File

@@ -7,7 +7,7 @@ import signal
import sys import sys
import weakref import weakref
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Dict
import torch import torch
import transformers.modelcard import transformers.modelcard
@@ -20,7 +20,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer from transformers.trainer import Trainer
from axolotl.common.datasets import TrainDatasetMeta from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens, fix_untrained_tokens,
) )
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -382,21 +382,23 @@ def handle_untrained_tokens_fix(
if not cfg.fix_untrained_tokens: if not cfg.fix_untrained_tokens:
return return
is_ds_zero3: bool = False
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
is_ds_zero3 = True
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args # Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens) sig = inspect.signature(fix_untrained_tokens)
fix_kwargs: Dict[str, Any] = {}
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list # If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance( if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list cfg.fix_untrained_tokens, list
): ):
fix_untrained_tokens( fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
model, if "is_ds_zero3" in sig.parameters:
tokenizer, fix_kwargs["is_ds_zero3"] = is_ds_zero3
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens, fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0: if cfg.local_rank == 0:
model.save_pretrained( model.save_pretrained(

View File

@@ -813,6 +813,15 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
) )
except (FileNotFoundError, ConnectionError) as err: except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}") LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
# TODO if using deepspeed and it's a file, save deepspeed config too
if args.deepspeed and os.path.isfile(args.deepspeed):
LOG.info(f"DeepSpeed config has been saved to the WandB run.")
artifact = wandb.Artifact(
f"deepspeed-{wandb.run.id}", type="deepspeed-config"
)
artifact.add_file(args.deepspeed)
wandb.log_artifact(artifact)
wandb.save(args.deepspeed)
return control return control

View File

@@ -173,10 +173,16 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
] ]
out_features[i][feature] = np.concatenate(arrays) out_features[i][feature] = np.concatenate(arrays)
else: else:
arrays = [ try:
np.array(item[feature]) for item in features_ if feature in item arrays = [
] np.array(item[feature])
out_features[i][feature] = np.concatenate(arrays) for item in features_
if feature in item
]
if arrays[0].dtype != "object":
out_features[i][feature] = np.concatenate(arrays)
except ValueError:
pass
return super().__call__(out_features, return_tensors=return_tensors) return super().__call__(out_features, return_tensors=return_tensors)

View File

@@ -728,7 +728,7 @@ class AxolotlInputConfig(
default=None, default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"}, json_schema_extra={"description": "streaming dataset to use for pretraining"},
) )
dataset_processes: Optional[int] = Field(default=os.cpu_count()) dataset_processes: Optional[int] = Field(default=min(32, os.cpu_count())) # type: ignore[type-var]
dataset_exact_deduplication: Optional[bool] = None dataset_exact_deduplication: Optional[bool] = None
dataset_keep_in_memory: Optional[bool] = None dataset_keep_in_memory: Optional[bool] = None
dataloader_pin_memory: Optional[bool] = None dataloader_pin_memory: Optional[bool] = None

View File

@@ -24,7 +24,6 @@ from peft import (
PeftModelForCausalLM, PeftModelForCausalLM,
prepare_model_for_kbit_training, prepare_model_for_kbit_training,
) )
from peft.tuners.lora import QuantLinear
from torch import nn from torch import nn
from transformers import ( # noqa: F401 from transformers import ( # noqa: F401
AddedToken, AddedToken,
@@ -1360,7 +1359,7 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model): def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear) cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
lora_module_names = set() lora_module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
if ( if (

View File

@@ -25,8 +25,8 @@ def fixture_cfg():
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"sequence_len": 2048, "sequence_len": 2048,
"rl": True, "rl": True,
"adam_beta1": 0.998, "adam_beta1": 0.91,
"adam_beta2": 0.9, "adam_beta2": 0.998,
"adam_epsilon": 0.00001, "adam_epsilon": 0.00001,
"dataloader_num_workers": 1, "dataloader_num_workers": 1,
"dataloader_pin_memory": True, "dataloader_pin_memory": True,
@@ -60,8 +60,8 @@ class TestHFRLTrainerBuilder:
def test_build_training_arguments(self, cfg, model, tokenizer): def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFRLTrainerBuilder(cfg, model, tokenizer) builder = HFRLTrainerBuilder(cfg, model, tokenizer)
training_arguments = builder.build_training_arguments(100) training_arguments = builder.build_training_arguments(100)
assert training_arguments.adam_beta1 == 0.998 assert training_arguments.adam_beta1 == 0.91
assert training_arguments.adam_beta2 == 0.9 assert training_arguments.adam_beta2 == 0.998
assert training_arguments.adam_epsilon == 0.00001 assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.dataloader_num_workers == 1 assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True assert training_arguments.dataloader_pin_memory is True

View File

@@ -750,3 +750,66 @@ class TestMultiGPULlama:
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
) )
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
"use_tensorboard": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
)

View File

@@ -66,6 +66,54 @@ class TestLlama:
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens(self, temp_dir): def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens_already_trained(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {