Compare commits


8 Commits

Author     SHA1        Message                                   Date
Wing Lian  9cb05283b2  use v2 branch                             2025-03-10 19:46:19 -04:00
Wing Lian  aafa6245f4  try with deepspeed import                 2025-03-10 19:39:55 -04:00
Wing Lian  3001e6d93c  use commit sha for previous release dev   2025-03-10 18:41:15 -04:00
Wing Lian  ed0456557d  use revised branch                        2025-03-10 16:53:57 -04:00
Wing Lian  09e4393a6a  use branch again                          2025-03-10 16:48:33 -04:00
Wing Lian  31a81106dd  revert to previous known good commit      2025-03-10 16:36:33 -04:00
Wing Lian  93c20cc0d5  test branch                               2025-03-10 16:35:17 -04:00
Wing Lian  3f5e2d6cc9  bump axolotl-contribs-lgpl                2025-03-10 16:35:17 -04:00
19 changed files with 42 additions and 252 deletions

View File

@@ -88,11 +88,6 @@ jobs:
             pytorch: 2.5.1
             axolotl_extras:
             is_latest: true
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.6.0
-            axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout

View File

@@ -80,11 +80,6 @@ jobs:
             python_version: "3.11"
             pytorch: 2.5.1
             axolotl_extras:
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.6.0
-            axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout

View File

@@ -62,5 +62,5 @@ antlr4-python3-runtime==4.13.2
 torchao==0.7.0
 schedulefree==1.3.0
-axolotl-contribs-lgpl==0.0.6
+axolotl-contribs-lgpl @ git+https://github.com/axolotl-ai-cloud/axolotl-contribs-lgpl.git@import-issues-v2
 axolotl-contribs-mit==0.0.3

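Note: this swaps the PyPI pin axolotl-contribs-lgpl==0.0.6 for a PEP 508 direct reference (of the form "name @ git+<repo-url>@<branch>"), so pip installs the import-issues-v2 branch straight from GitHub. Judging by the commit messages ("use revised branch", "use v2 branch"), this is presumably a stopgap until the fix ships in a tagged release.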
View File

@@ -1,7 +1,6 @@
 """CLI to run training on a model."""
 import logging
-import os
 from pathlib import Path
 from typing import Union
@@ -35,8 +34,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
     """
     print_axolotl_text_art()
     check_accelerate_default_config()
-    if int(os.getenv("LOCAL_RANK", "0")) == 0:
-        check_user_token()
+    check_user_token()
     if cfg.rl:
         dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

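Note: with the LOCAL_RANK gate reverted, check_user_token() runs in every local process again. For reference, a minimal sketch of the rank-0 gating idiom the removed lines used (helper name is illustrative, not from the repo):

    import os

    def is_local_main_process() -> bool:
        # torchrun/accelerate export LOCAL_RANK per process; the default covers single-process runs
        return int(os.getenv("LOCAL_RANK", "0")) == 0

    if is_local_main_process():
        check_user_token()  # avoid one HF Hub token check per GPU process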
View File

@@ -43,7 +43,7 @@ class TokenizedChatDataset(Dataset):
         process_or_cpu_count: int = (
             process_count or os.cpu_count()  # type: ignore[assignment]
         )
-        num_proc = min(32, process_or_cpu_count)
+        num_proc = min(64, process_or_cpu_count)
         features = data.features.keys()
         tokenized_data = data.map(
             map_fn,

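Note: this restores the higher cap of 64 tokenization workers. A hedged sketch of how such a cap feeds datasets.map (toy dataset and map function, not from the repo):

    import os
    from datasets import Dataset

    data = Dataset.from_dict({"text": ["hello", "world"] * 1000})  # toy data

    def count_chars(example):
        # stand-in for the real tokenization map_fn
        return {"n_chars": len(example["text"])}

    num_proc = min(64, os.cpu_count() or 1)  # same cap as the hunk above
    tokenized = data.map(count_chars, num_proc=num_proc)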
View File

@@ -751,12 +751,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.kd_ce_alpha is not None:
             training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
-        if self.cfg.kd_ce_alpha_end is not None:
-            training_arguments_kwargs["kd_ce_alpha_end"] = self.cfg.kd_ce_alpha_end
         if self.cfg.kd_alpha is not None:
             training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
-        if self.cfg.kd_alpha_end is not None:
-            training_arguments_kwargs["kd_alpha_end"] = self.cfg.kd_alpha_end
         if self.cfg.kd_temperature is not None:
             training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
         if self.cfg.kd_zscore_base_temp is not None:

View File

@@ -34,12 +34,3 @@ class KDPlugin(BasePlugin):
             return AxolotlKDTrainer
         return None
-
-    def add_callbacks_post_trainer(self, cfg, trainer):
-        callbacks = []
-        if cfg.kd_trainer:
-            from .callbacks import KDAlphaSchedulerCallback
-
-            callbacks.append(KDAlphaSchedulerCallback())
-        return callbacks

View File

@@ -30,8 +30,6 @@ class KDArgs(BaseModel):
         float
     ] = None  # loss coefficient for cross-entropy loss during KD
     kd_alpha: Optional[float] = None  # loss coefficient for KD loss
-    kd_ce_alpha_end: Optional[float] = None  # end value for kd_ce_alpha
-    kd_alpha_end: Optional[float] = None  # end value for kd_alpha
     kd_temperature: Optional[float] = None  # temperature for sampling during KD
     kd_zscore_base_temp: Optional[float] = None  # base temperature for zscore scaling
     kd_top_k_before_softmax: Optional[

View File

@@ -1,28 +0,0 @@
-from transformers import TrainerCallback
-
-
-class KDAlphaSchedulerCallback(TrainerCallback):
-    """Callback to for scheduling KD alpha during training."""
-
-    def on_epoch_begin(
-        self, args, state, control, **kwargs  # pylint: disable=unused-argument
-    ):
-        if int(state.epoch) == 0:
-            state.kd_alpha = args.kd_alpha
-            state.kd_ce_alpha = args.kd_ce_alpha
-        elif int(state.epoch) == state.num_train_epochs - 1:
-            if args.kd_alpha_end is not None:
-                control.kd_alpha = args.kd_alpha_end
-            if args.kd_ce_alpha_end is not None:
-                control.kd_ce_alpha = args.kd_ce_alpha_end
-        else:
-            epoch_steps = state.num_train_epochs - 1
-            scale = int(state.epoch) / epoch_steps
-            if args.kd_alpha_end is not None:
-                control.kd_alpha = (
-                    args.kd_alpha + (args.kd_alpha_end - args.kd_alpha) * scale
-                )
-            if args.kd_ce_alpha_end is not None:
-                control.kd_ce_alpha = (
-                    args.kd_ce_alpha + (args.kd_ce_alpha_end - args.kd_ce_alpha) * scale
-                )

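Note: the deleted callback interpolated the KD mixing weights linearly from their start values to kd_alpha_end / kd_ce_alpha_end across epochs. A standalone sketch of that schedule (function name hypothetical):

    def scheduled_alpha(alpha_start: float, alpha_end: float, epoch: int, num_epochs: int) -> float:
        # epoch 0 -> alpha_start; final epoch -> alpha_end; linear in between,
        # mirroring the deleted on_epoch_begin logic
        if num_epochs <= 1:
            return alpha_start
        scale = epoch / (num_epochs - 1)
        return alpha_start + (alpha_end - alpha_start) * scale

For example, alpha_start=1.0 and alpha_end=0.5 over 5 epochs yields 1.0, 0.875, 0.75, 0.625, 0.5.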
View File

@@ -62,16 +62,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         Transform logprobs to target format for KD training
         """
-        if "target_logprobs" in sample.keys() and "target_token_ids" in sample.keys():
-            logprobs = sample.pop("target_logprobs")
-            token_ids = sample.pop("target_token_ids")
-        else:
-            logprobs = sample.pop(self.logprobs_field)
-            token_ids = [None] * len(logprobs)
+        logprobs = sample.pop(self.logprobs_field)
         target_seq_len = len(logprobs)
         input_seq_len = len(sample["input_ids"])
-        target_padding_len = input_seq_len - target_seq_len
+        input_padding_len = input_seq_len - target_seq_len
         # get non-zero top-k (prune None logprobs from vllm data step)
         top_k_vals = [
             len(logprobs[i])
@@ -88,11 +82,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         target_token_ids = []
         target_mask = []
-        if target_padding_len < 0:
+        if input_padding_len < 0:
             # logprobs is longer than target_seq_len,
             # so we need to slice from the left/beginning of logprobs
             logprobs = logprobs[:-input_seq_len]
-            target_padding_len = 0
+            input_padding_len = 0
             # target_seq_len = input_seq_len
         # truncate the second dimension of the logprobs to top_k
@@ -104,37 +98,33 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         # for causal models, if we start the range at 1, then we don't need to shift in the trainer
         # otherwise, we need to shift in the trainer
         shift = 0
-        for _ in range(shift, target_padding_len):
+        for _ in range(shift, input_padding_len):
             target_logprobs.append([-float("inf")] * top_k)
             target_token_ids.append(list(range(top_k)))
             target_mask.append([0] * top_k)
-        for position in range(target_padding_len, input_seq_len):
+        for position in range(input_padding_len, input_seq_len):
             if sample["labels"][position] == -100:
                 target_mask.append([0] * top_k)
             else:
                 target_mask.append([1] * top_k)
-        for token_pos_logprobs, token_pos_token_ids in zip(logprobs, token_ids):
+        for _, token_pos_logprobs in enumerate(logprobs):
             # Initialize collections for logprobs and token_ids
             position_logprobs = []
             position_token_ids = []
             # Process each token probability entry
-            if token_pos_token_ids is None:
-                for entry in token_pos_logprobs:
-                    # Extract logprob value
-                    logprob = entry["logprob"]
-                    # Parse token_id from the "token_id:###" format
-                    token_id = int(entry["token"].split(":")[1])
-                    # Append to our collections
-                    position_logprobs.append(logprob)
-                    position_token_ids.append(token_id)
-            else:
-                position_logprobs = token_pos_logprobs
-                position_token_ids = token_pos_token_ids
+            for entry in token_pos_logprobs:
+                # Extract logprob value
+                logprob = entry["logprob"]
+                # Parse token_id from the "token_id:###" format
+                token_id = int(entry["token"].split(":")[1])
+                # Append to our collections
+                position_logprobs.append(logprob)
+                position_token_ids.append(token_id)
             # Convert to a tensor for easier manipulation
             position_logprobs_tensor = torch.tensor(
@@ -153,7 +143,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
             teacher_probs_t2 = teacher_probs_t1**exponent
         else:
             teacher_probs_t2 = teacher_probs_t1
-
         # Re-normalize
         teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
             dim=0, keepdim=True

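Note: the surviving tail of this hunk reshapes the teacher's top-k distribution by exponentiation and re-normalization. A self-contained sketch, assuming exponent = 1/temperature (the visible hunk does not show how exponent is computed):

    import torch

    def reshape_teacher_probs(topk_logprobs: torch.Tensor, temperature: float) -> torch.Tensor:
        probs = torch.softmax(topk_logprobs, dim=0)  # teacher_probs_t1
        if temperature != 1.0:
            probs = probs ** (1.0 / temperature)  # teacher_probs_t2 (assumed exponent)
        return probs / probs.sum(dim=0, keepdim=True)  # re-normalize, as in the hunk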
View File

@@ -16,35 +16,17 @@
 KD trainer
 """
-from transformers import TrainerControl
-
 from axolotl.core.trainers.base import AxolotlTrainer
 from .topk_logprob.forward_kl import loss as topk_kd_loss
 from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
-
-class AxolotlKDTrainerControl(TrainerControl):
-    kd_alpha: float = 1.0
-    kd_ce_alpha: float = 0.0
-
-    def state(self) -> dict:
-        state_val = super().state()
-        state_val["args"]["kd_alpha"] = self.kd_alpha
-        state_val["args"]["kd_ce_alpha"] = self.kd_ce_alpha
-
 class AxolotlKDTrainer(AxolotlTrainer):
     """
     Custom trainer subclass for Knowledge Distillation (KD)
     """
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.kd_alpha = self.args.kd_alpha
-        self.kd_ce_alpha = self.args.kd_ce_alpha
-        self.control = AxolotlKDTrainerControl()
-
     def _set_signature_columns_if_needed(self):
         super()._set_signature_columns_if_needed()
         columns_to_add = []
@@ -113,8 +95,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
                 top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
             )
-            if self.kd_ce_alpha > 0:
-                loss = self.kd_ce_alpha * outputs["loss"] + self.kd_alpha * loss_kd
+            if self.args.kd_ce_alpha > 0:
+                kd_alpha = self.args.kd_alpha
+                loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
             else:
                 loss = loss_kd
             # Save past state if it exists

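Note: after the revert, compute_loss reads the mixing weights from self.args each step instead of from the deleted control object. The blend itself, pulled out as a sketch (helper name hypothetical):

    import torch

    def blend_kd_loss(
        ce_loss: torch.Tensor, kd_loss: torch.Tensor, kd_ce_alpha: float, kd_alpha: float
    ) -> torch.Tensor:
        # Matches the branch above: mix cross-entropy in only when its weight is set
        if kd_ce_alpha > 0:
            return kd_ce_alpha * ce_loss + kd_alpha * kd_loss
        return kd_loss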
View File

@@ -7,7 +7,7 @@ import signal
 import sys
 import weakref
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any
 import torch
 import transformers.modelcard
@@ -20,7 +20,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.trainer import Trainer
 from axolotl.common.datasets import TrainDatasetMeta
-from axolotl.contribs.lgpl import (  # pylint: disable = no-name-in-module
+from axolotl.contribs.lgpl.unsloth import (  # pylint: disable = no-name-in-module
     fix_untrained_tokens,
 )
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -382,23 +382,21 @@ def handle_untrained_tokens_fix(
     if not cfg.fix_untrained_tokens:
         return
-    is_ds_zero3: bool = False
-    if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
-        is_ds_zero3 = True
-
     # Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
     sig = inspect.signature(fix_untrained_tokens)
-    fix_kwargs: Dict[str, Any] = {}
     # If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
     if "token_ids_to_fix" in sig.parameters and isinstance(
         cfg.fix_untrained_tokens, list
     ):
-        fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
-    if "is_ds_zero3" in sig.parameters:
-        fix_kwargs["is_ds_zero3"] = is_ds_zero3
-
-    fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
+        fix_untrained_tokens(
+            model,
+            tokenizer,
+            train_dataset,
+            token_ids_to_fix=cfg.fix_untrained_tokens,
+        )
+    else:
+        fix_untrained_tokens(model, tokenizer, train_dataset)
     if cfg.local_rank == 0:
         model.save_pretrained(

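Note: the retained else-branch keeps compatibility with versions of fix_untrained_tokens that lack the token_ids_to_fix parameter, probed via inspect.signature. The general pattern, as a hypothetical helper:

    import inspect

    def call_with_supported_kwargs(fn, *args, **optional_kwargs):
        # Forward only the keyword arguments fn actually declares
        # (the hunk above inlines this check for token_ids_to_fix)
        sig = inspect.signature(fn)
        supported = {k: v for k, v in optional_kwargs.items() if k in sig.parameters}
        return fn(*args, **supported)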
View File

@@ -813,15 +813,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
                 )
             except (FileNotFoundError, ConnectionError) as err:
                 LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
-            # TODO if using deepspeed and it's a file, save deepspeed config too
-            if args.deepspeed and os.path.isfile(args.deepspeed):
-                LOG.info(f"DeepSpeed config has been saved to the WandB run.")
-                artifact = wandb.Artifact(
-                    f"deepspeed-{wandb.run.id}", type="deepspeed-config"
-                )
-                artifact.add_file(args.deepspeed)
-                wandb.log_artifact(artifact)
-                wandb.save(args.deepspeed)
         return control

View File

@@ -173,16 +173,10 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                     ]
                     out_features[i][feature] = np.concatenate(arrays)
                 else:
-                    try:
-                        arrays = [
-                            np.array(item[feature])
-                            for item in features_
-                            if feature in item
-                        ]
-                        if arrays[0].dtype != "object":
-                            out_features[i][feature] = np.concatenate(arrays)
-                    except ValueError:
-                        pass
+                    arrays = [
+                        np.array(item[feature]) for item in features_ if feature in item
+                    ]
+                    out_features[i][feature] = np.concatenate(arrays)
         return super().__call__(out_features, return_tensors=return_tensors)

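Note: the revert drops the object-dtype guard and the try/except ValueError around np.concatenate. For context, a small demonstration of the failure mode that guard absorbed (toy arrays, not from the repo):

    import numpy as np

    ok = [np.array([1, 2, 3]), np.array([4, 5])]
    print(np.concatenate(ok))  # 1-D arrays of unequal length concatenate fine

    bad = [np.array([[1, 2], [3, 4]]), np.array([5, 6])]
    try:
        np.concatenate(bad)  # mixing 2-D and 1-D inputs raises ValueError
    except ValueError as err:
        print(f"guarded case: {err}")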
View File

@@ -728,7 +728,7 @@ class AxolotlInputConfig(
         default=None,
         json_schema_extra={"description": "streaming dataset to use for pretraining"},
     )
-    dataset_processes: Optional[int] = Field(default=min(32, os.cpu_count()))  # type: ignore[type-var]
+    dataset_processes: Optional[int] = Field(default=os.cpu_count())
     dataset_exact_deduplication: Optional[bool] = None
     dataset_keep_in_memory: Optional[bool] = None
     dataloader_pin_memory: Optional[bool] = None

View File

@@ -24,6 +24,7 @@ from peft import (
PeftModelForCausalLM, PeftModelForCausalLM,
prepare_model_for_kbit_training, prepare_model_for_kbit_training,
) )
from peft.tuners.lora import QuantLinear
from torch import nn from torch import nn
from transformers import ( # noqa: F401 from transformers import ( # noqa: F401
AddedToken, AddedToken,
@@ -1359,7 +1360,7 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model): def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear) cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
lora_module_names = set() lora_module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
if ( if (

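Note: find_all_linear_names now also counts GPTQ-style QuantLinear layers as LoRA target modules. A condensed sketch of the named_modules walk (simplified; the real helper applies further filtering):

    from torch import nn

    def find_linear_module_names(model: nn.Module, linear_classes: tuple) -> set:
        # Collect the leaf names of every "linear-like" module,
        # e.g. linear_classes = (nn.Linear, QuantLinear) after this change
        names = set()
        for name, module in model.named_modules():
            if isinstance(module, linear_classes):
                names.add(name.split(".")[-1])
        return names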
View File

@@ -25,8 +25,8 @@ def fixture_cfg():
         "optimizer": "adamw_torch_fused",
         "sequence_len": 2048,
         "rl": True,
-        "adam_beta1": 0.91,
-        "adam_beta2": 0.998,
+        "adam_beta1": 0.998,
+        "adam_beta2": 0.9,
         "adam_epsilon": 0.00001,
         "dataloader_num_workers": 1,
         "dataloader_pin_memory": True,
@@ -60,8 +60,8 @@ class TestHFRLTrainerBuilder:
     def test_build_training_arguments(self, cfg, model, tokenizer):
         builder = HFRLTrainerBuilder(cfg, model, tokenizer)
         training_arguments = builder.build_training_arguments(100)
-        assert training_arguments.adam_beta1 == 0.91
-        assert training_arguments.adam_beta2 == 0.998
+        assert training_arguments.adam_beta1 == 0.998
+        assert training_arguments.adam_beta2 == 0.9
         assert training_arguments.adam_epsilon == 0.00001
         assert training_arguments.dataloader_num_workers == 1
         assert training_arguments.dataloader_pin_memory is True

View File

@@ -750,66 +750,3 @@ class TestMultiGPULlama:
         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
         )
-
-    def test_fix_untrained_tokens(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "HuggingFaceTB/SmolLM2-135M",
-                "fix_untrained_tokens": True,
-                "sequence_len": 512,
-                "val_set_size": 0.0,
-                "special_tokens": {
-                    "pad_token": "<|endoftext|>",
-                    "bos_token": "<|custom_im_start|>",
-                    "eos_token": "<|custom_im_end|>",
-                },
-                "datasets": [
-                    {
-                        "chat_template": "jinja",
-                        "chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
-                        "path": "mlabonne/FineTome-100k",
-                        "type": "chat_template",
-                        "split": "train[:10%]",
-                        "field_messages": "conversations",
-                        "message_field_role": "from",
-                        "message_field_content": "value",
-                    },
-                ],
-                "num_epochs": 1,
-                "max_steps": 5,
-                "micro_batch_size": 1,
-                "gradient_accumulation_steps": 1,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_torch_fused",
-                "lr_scheduler": "cosine",
-                "flash_attention": True,
-                "sample_packing": True,
-                "bf16": True,
-                "save_safetensors": True,
-                "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
-                "use_tensorboard": True,
-            }
-        )
-        # write cfg to yaml file
-        Path(temp_dir).mkdir(parents=True, exist_ok=True)
-        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
-            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
-
-        execute_subprocess_async(
-            [
-                "axolotl",
-                "train",
-                str(Path(temp_dir) / "config.yaml"),
-                "--num-processes",
-                "2",
-                "--main-process-port",
-                f"{get_torch_dist_unique_port()}",
-            ]
-        )
-
-        check_tensorboard(
-            temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
-        )

View File

@@ -66,54 +66,6 @@ class TestLlama:
         check_model_output_exists(temp_dir, cfg)

     def test_fix_untrained_tokens(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "HuggingFaceTB/SmolLM2-135M",
-                "fix_untrained_tokens": True,
-                "sequence_len": 512,
-                "val_set_size": 0.0,
-                "special_tokens": {
-                    "pad_token": "<|endoftext|>",
-                    "bos_token": "<|custom_im_start|>",
-                    "eos_token": "<|custom_im_end|>",
-                },
-                "datasets": [
-                    {
-                        "chat_template": "jinja",
-                        "chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
-                        "path": "mlabonne/FineTome-100k",
-                        "type": "chat_template",
-                        "split": "train[:10%]",
-                        "field_messages": "conversations",
-                        "message_field_role": "from",
-                        "message_field_content": "value",
-                    },
-                ],
-                "num_epochs": 1,
-                "max_steps": 5,
-                "micro_batch_size": 1,
-                "gradient_accumulation_steps": 1,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_8bit",
-                "lr_scheduler": "cosine",
-                "flash_attention": True,
-                "sample_packing": True,
-                "bf16": True,
-                "save_safetensors": True,
-            }
-        )
-        cfg = validate_config(cfg)
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
-
-    def test_fix_untrained_tokens_already_trained(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {