Compare commits
6 Commits
update-lgp
...
kd-logprob
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8fc4c420a4 | ||
|
|
4f5eb42a73 | ||
|
|
fbe54be6b8 | ||
|
|
04f6324833 | ||
|
|
f0072f3b9d | ||
|
|
59899b9817 |
5
.github/workflows/main.yml
vendored
5
.github/workflows/main.yml
vendored
@@ -88,6 +88,11 @@ jobs:
|
|||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.6.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
5
.github/workflows/nightlies.yml
vendored
5
.github/workflows/nightlies.yml
vendored
@@ -80,6 +80,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.6.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
@@ -62,5 +62,5 @@ antlr4-python3-runtime==4.13.2
|
|||||||
torchao==0.7.0
|
torchao==0.7.0
|
||||||
schedulefree==1.3.0
|
schedulefree==1.3.0
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.3
|
axolotl-contribs-lgpl==0.0.6
|
||||||
axolotl-contribs-mit==0.0.3
|
axolotl-contribs-mit==0.0.3
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""CLI to run training on a model."""
|
"""CLI to run training on a model."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -34,7 +35,8 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
|
check_user_token()
|
||||||
|
|
||||||
if cfg.rl:
|
if cfg.rl:
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class TokenizedChatDataset(Dataset):
|
|||||||
process_or_cpu_count: int = (
|
process_or_cpu_count: int = (
|
||||||
process_count or os.cpu_count() # type: ignore[assignment]
|
process_count or os.cpu_count() # type: ignore[assignment]
|
||||||
)
|
)
|
||||||
num_proc = min(64, process_or_cpu_count)
|
num_proc = min(32, process_or_cpu_count)
|
||||||
features = data.features.keys()
|
features = data.features.keys()
|
||||||
tokenized_data = data.map(
|
tokenized_data = data.map(
|
||||||
map_fn,
|
map_fn,
|
||||||
|
|||||||
@@ -751,8 +751,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.kd_ce_alpha is not None:
|
if self.cfg.kd_ce_alpha is not None:
|
||||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||||
|
if self.cfg.kd_ce_alpha_end is not None:
|
||||||
|
training_arguments_kwargs["kd_ce_alpha_end"] = self.cfg.kd_ce_alpha_end
|
||||||
if self.cfg.kd_alpha is not None:
|
if self.cfg.kd_alpha is not None:
|
||||||
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
|
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
|
||||||
|
if self.cfg.kd_alpha_end is not None:
|
||||||
|
training_arguments_kwargs["kd_alpha_end"] = self.cfg.kd_alpha_end
|
||||||
if self.cfg.kd_temperature is not None:
|
if self.cfg.kd_temperature is not None:
|
||||||
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
||||||
if self.cfg.kd_zscore_base_temp is not None:
|
if self.cfg.kd_zscore_base_temp is not None:
|
||||||
|
|||||||
@@ -34,3 +34,12 @@ class KDPlugin(BasePlugin):
|
|||||||
|
|
||||||
return AxolotlKDTrainer
|
return AxolotlKDTrainer
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||||
|
callbacks = []
|
||||||
|
if cfg.kd_trainer:
|
||||||
|
from .callbacks import KDAlphaSchedulerCallback
|
||||||
|
|
||||||
|
callbacks.append(KDAlphaSchedulerCallback())
|
||||||
|
|
||||||
|
return callbacks
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ class KDArgs(BaseModel):
|
|||||||
float
|
float
|
||||||
] = None # loss coefficient for cross-entropy loss during KD
|
] = None # loss coefficient for cross-entropy loss during KD
|
||||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
||||||
|
kd_ce_alpha_end: Optional[float] = None # end value for kd_ce_alpha
|
||||||
|
kd_alpha_end: Optional[float] = None # end value for kd_alpha
|
||||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
||||||
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
||||||
kd_top_k_before_softmax: Optional[
|
kd_top_k_before_softmax: Optional[
|
||||||
|
|||||||
28
src/axolotl/integrations/kd/callbacks.py
Normal file
28
src/axolotl/integrations/kd/callbacks.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
|
|
||||||
|
class KDAlphaSchedulerCallback(TrainerCallback):
|
||||||
|
"""Callback to for scheduling KD alpha during training."""
|
||||||
|
|
||||||
|
def on_epoch_begin(
|
||||||
|
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
if int(state.epoch) == 0:
|
||||||
|
state.kd_alpha = args.kd_alpha
|
||||||
|
state.kd_ce_alpha = args.kd_ce_alpha
|
||||||
|
elif int(state.epoch) == state.num_train_epochs - 1:
|
||||||
|
if args.kd_alpha_end is not None:
|
||||||
|
control.kd_alpha = args.kd_alpha_end
|
||||||
|
if args.kd_ce_alpha_end is not None:
|
||||||
|
control.kd_ce_alpha = args.kd_ce_alpha_end
|
||||||
|
else:
|
||||||
|
epoch_steps = state.num_train_epochs - 1
|
||||||
|
scale = int(state.epoch) / epoch_steps
|
||||||
|
if args.kd_alpha_end is not None:
|
||||||
|
control.kd_alpha = (
|
||||||
|
args.kd_alpha + (args.kd_alpha_end - args.kd_alpha) * scale
|
||||||
|
)
|
||||||
|
if args.kd_ce_alpha_end is not None:
|
||||||
|
control.kd_ce_alpha = (
|
||||||
|
args.kd_ce_alpha + (args.kd_ce_alpha_end - args.kd_ce_alpha) * scale
|
||||||
|
)
|
||||||
@@ -62,10 +62,16 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
Transform logprobs to target format for KD training
|
Transform logprobs to target format for KD training
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logprobs = sample.pop(self.logprobs_field)
|
if "target_logprobs" in sample.keys() and "target_token_ids" in sample.keys():
|
||||||
|
logprobs = sample.pop("target_logprobs")
|
||||||
|
token_ids = sample.pop("target_token_ids")
|
||||||
|
else:
|
||||||
|
logprobs = sample.pop(self.logprobs_field)
|
||||||
|
token_ids = [None] * len(logprobs)
|
||||||
|
|
||||||
target_seq_len = len(logprobs)
|
target_seq_len = len(logprobs)
|
||||||
input_seq_len = len(sample["input_ids"])
|
input_seq_len = len(sample["input_ids"])
|
||||||
input_padding_len = input_seq_len - target_seq_len
|
target_padding_len = input_seq_len - target_seq_len
|
||||||
# get non-zero top-k (prune None logprobs from vllm data step)
|
# get non-zero top-k (prune None logprobs from vllm data step)
|
||||||
top_k_vals = [
|
top_k_vals = [
|
||||||
len(logprobs[i])
|
len(logprobs[i])
|
||||||
@@ -82,11 +88,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
target_token_ids = []
|
target_token_ids = []
|
||||||
target_mask = []
|
target_mask = []
|
||||||
|
|
||||||
if input_padding_len < 0:
|
if target_padding_len < 0:
|
||||||
# logprobs is longer than target_seq_len,
|
# logprobs is longer than target_seq_len,
|
||||||
# so we need to slice from the left/beginning of logprobs
|
# so we need to slice from the left/beginning of logprobs
|
||||||
logprobs = logprobs[:-input_seq_len]
|
logprobs = logprobs[:-input_seq_len]
|
||||||
input_padding_len = 0
|
target_padding_len = 0
|
||||||
# target_seq_len = input_seq_len
|
# target_seq_len = input_seq_len
|
||||||
|
|
||||||
# truncate the second dimension of the logprobs to top_k
|
# truncate the second dimension of the logprobs to top_k
|
||||||
@@ -98,33 +104,37 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
|
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
|
||||||
# otherwise, we need to shift in the trainer
|
# otherwise, we need to shift in the trainer
|
||||||
shift = 0
|
shift = 0
|
||||||
for _ in range(shift, input_padding_len):
|
for _ in range(shift, target_padding_len):
|
||||||
target_logprobs.append([-float("inf")] * top_k)
|
target_logprobs.append([-float("inf")] * top_k)
|
||||||
target_token_ids.append(list(range(top_k)))
|
target_token_ids.append(list(range(top_k)))
|
||||||
target_mask.append([0] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
|
|
||||||
for position in range(input_padding_len, input_seq_len):
|
for position in range(target_padding_len, input_seq_len):
|
||||||
if sample["labels"][position] == -100:
|
if sample["labels"][position] == -100:
|
||||||
target_mask.append([0] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
else:
|
else:
|
||||||
target_mask.append([1] * top_k)
|
target_mask.append([1] * top_k)
|
||||||
|
|
||||||
for _, token_pos_logprobs in enumerate(logprobs):
|
for token_pos_logprobs, token_pos_token_ids in zip(logprobs, token_ids):
|
||||||
# Initialize collections for logprobs and token_ids
|
# Initialize collections for logprobs and token_ids
|
||||||
position_logprobs = []
|
position_logprobs = []
|
||||||
position_token_ids = []
|
position_token_ids = []
|
||||||
|
|
||||||
# Process each token probability entry
|
# Process each token probability entry
|
||||||
for entry in token_pos_logprobs:
|
if token_pos_token_ids is None:
|
||||||
# Extract logprob value
|
for entry in token_pos_logprobs:
|
||||||
logprob = entry["logprob"]
|
# Extract logprob value
|
||||||
|
logprob = entry["logprob"]
|
||||||
|
|
||||||
# Parse token_id from the "token_id:###" format
|
# Parse token_id from the "token_id:###" format
|
||||||
token_id = int(entry["token"].split(":")[1])
|
token_id = int(entry["token"].split(":")[1])
|
||||||
|
|
||||||
# Append to our collections
|
# Append to our collections
|
||||||
position_logprobs.append(logprob)
|
position_logprobs.append(logprob)
|
||||||
position_token_ids.append(token_id)
|
position_token_ids.append(token_id)
|
||||||
|
else:
|
||||||
|
position_logprobs = token_pos_logprobs
|
||||||
|
position_token_ids = token_pos_token_ids
|
||||||
|
|
||||||
# Convert to a tensor for easier manipulation
|
# Convert to a tensor for easier manipulation
|
||||||
position_logprobs_tensor = torch.tensor(
|
position_logprobs_tensor = torch.tensor(
|
||||||
@@ -143,6 +153,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
teacher_probs_t2 = teacher_probs_t1**exponent
|
teacher_probs_t2 = teacher_probs_t1**exponent
|
||||||
else:
|
else:
|
||||||
teacher_probs_t2 = teacher_probs_t1
|
teacher_probs_t2 = teacher_probs_t1
|
||||||
|
|
||||||
# Re-normalize
|
# Re-normalize
|
||||||
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
||||||
dim=0, keepdim=True
|
dim=0, keepdim=True
|
||||||
|
|||||||
@@ -16,17 +16,35 @@
|
|||||||
KD trainer
|
KD trainer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from transformers import TrainerControl
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
||||||
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlKDTrainerControl(TrainerControl):
|
||||||
|
kd_alpha: float = 1.0
|
||||||
|
kd_ce_alpha: float = 0.0
|
||||||
|
|
||||||
|
def state(self) -> dict:
|
||||||
|
state_val = super().state()
|
||||||
|
state_val["args"]["kd_alpha"] = self.kd_alpha
|
||||||
|
state_val["args"]["kd_ce_alpha"] = self.kd_ce_alpha
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKDTrainer(AxolotlTrainer):
|
class AxolotlKDTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
Custom trainer subclass for Knowledge Distillation (KD)
|
Custom trainer subclass for Knowledge Distillation (KD)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.kd_alpha = self.args.kd_alpha
|
||||||
|
self.kd_ce_alpha = self.args.kd_ce_alpha
|
||||||
|
self.control = AxolotlKDTrainerControl()
|
||||||
|
|
||||||
def _set_signature_columns_if_needed(self):
|
def _set_signature_columns_if_needed(self):
|
||||||
super()._set_signature_columns_if_needed()
|
super()._set_signature_columns_if_needed()
|
||||||
columns_to_add = []
|
columns_to_add = []
|
||||||
@@ -95,9 +113,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.kd_ce_alpha > 0:
|
if self.kd_ce_alpha > 0:
|
||||||
kd_alpha = self.args.kd_alpha
|
loss = self.kd_ce_alpha * outputs["loss"] + self.kd_alpha * loss_kd
|
||||||
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
|
|
||||||
else:
|
else:
|
||||||
loss = loss_kd
|
loss = loss_kd
|
||||||
# Save past state if it exists
|
# Save past state if it exists
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import signal
|
|||||||
import sys
|
import sys
|
||||||
import weakref
|
import weakref
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers.modelcard
|
import transformers.modelcard
|
||||||
@@ -20,7 +20,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
|||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.common.datasets import TrainDatasetMeta
|
from axolotl.common.datasets import TrainDatasetMeta
|
||||||
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
)
|
)
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
@@ -382,21 +382,23 @@ def handle_untrained_tokens_fix(
|
|||||||
if not cfg.fix_untrained_tokens:
|
if not cfg.fix_untrained_tokens:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
is_ds_zero3: bool = False
|
||||||
|
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
|
||||||
|
is_ds_zero3 = True
|
||||||
|
|
||||||
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
||||||
sig = inspect.signature(fix_untrained_tokens)
|
sig = inspect.signature(fix_untrained_tokens)
|
||||||
|
|
||||||
|
fix_kwargs: Dict[str, Any] = {}
|
||||||
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
||||||
if "token_ids_to_fix" in sig.parameters and isinstance(
|
if "token_ids_to_fix" in sig.parameters and isinstance(
|
||||||
cfg.fix_untrained_tokens, list
|
cfg.fix_untrained_tokens, list
|
||||||
):
|
):
|
||||||
fix_untrained_tokens(
|
fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
|
||||||
model,
|
if "is_ds_zero3" in sig.parameters:
|
||||||
tokenizer,
|
fix_kwargs["is_ds_zero3"] = is_ds_zero3
|
||||||
train_dataset,
|
|
||||||
token_ids_to_fix=cfg.fix_untrained_tokens,
|
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
|
||||||
)
|
|
||||||
else:
|
|
||||||
fix_untrained_tokens(model, tokenizer, train_dataset)
|
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
model.save_pretrained(
|
model.save_pretrained(
|
||||||
|
|||||||
@@ -813,6 +813,15 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
)
|
)
|
||||||
except (FileNotFoundError, ConnectionError) as err:
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||||
|
# TODO if using deepspeed and it's a file, save deepspeed config too
|
||||||
|
if args.deepspeed and os.path.isfile(args.deepspeed):
|
||||||
|
LOG.info(f"DeepSpeed config has been saved to the WandB run.")
|
||||||
|
artifact = wandb.Artifact(
|
||||||
|
f"deepspeed-{wandb.run.id}", type="deepspeed-config"
|
||||||
|
)
|
||||||
|
artifact.add_file(args.deepspeed)
|
||||||
|
wandb.log_artifact(artifact)
|
||||||
|
wandb.save(args.deepspeed)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -173,10 +173,16 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
else:
|
else:
|
||||||
arrays = [
|
try:
|
||||||
np.array(item[feature]) for item in features_ if feature in item
|
arrays = [
|
||||||
]
|
np.array(item[feature])
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
for item in features_
|
||||||
|
if feature in item
|
||||||
|
]
|
||||||
|
if arrays[0].dtype != "object":
|
||||||
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -728,7 +728,7 @@ class AxolotlInputConfig(
|
|||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
||||||
)
|
)
|
||||||
dataset_processes: Optional[int] = Field(default=os.cpu_count())
|
dataset_processes: Optional[int] = Field(default=min(32, os.cpu_count())) # type: ignore[type-var]
|
||||||
dataset_exact_deduplication: Optional[bool] = None
|
dataset_exact_deduplication: Optional[bool] = None
|
||||||
dataset_keep_in_memory: Optional[bool] = None
|
dataset_keep_in_memory: Optional[bool] = None
|
||||||
dataloader_pin_memory: Optional[bool] = None
|
dataloader_pin_memory: Optional[bool] = None
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ from peft import (
|
|||||||
PeftModelForCausalLM,
|
PeftModelForCausalLM,
|
||||||
prepare_model_for_kbit_training,
|
prepare_model_for_kbit_training,
|
||||||
)
|
)
|
||||||
from peft.tuners.lora import QuantLinear
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import ( # noqa: F401
|
from transformers import ( # noqa: F401
|
||||||
AddedToken,
|
AddedToken,
|
||||||
@@ -1360,7 +1359,7 @@ def load_llama_adapter(model, cfg):
|
|||||||
|
|
||||||
|
|
||||||
def find_all_linear_names(model):
|
def find_all_linear_names(model):
|
||||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
||||||
lora_module_names = set()
|
lora_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ def fixture_cfg():
|
|||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"rl": True,
|
"rl": True,
|
||||||
"adam_beta1": 0.998,
|
"adam_beta1": 0.91,
|
||||||
"adam_beta2": 0.9,
|
"adam_beta2": 0.998,
|
||||||
"adam_epsilon": 0.00001,
|
"adam_epsilon": 0.00001,
|
||||||
"dataloader_num_workers": 1,
|
"dataloader_num_workers": 1,
|
||||||
"dataloader_pin_memory": True,
|
"dataloader_pin_memory": True,
|
||||||
@@ -60,8 +60,8 @@ class TestHFRLTrainerBuilder:
|
|||||||
def test_build_training_arguments(self, cfg, model, tokenizer):
|
def test_build_training_arguments(self, cfg, model, tokenizer):
|
||||||
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
|
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
|
||||||
training_arguments = builder.build_training_arguments(100)
|
training_arguments = builder.build_training_arguments(100)
|
||||||
assert training_arguments.adam_beta1 == 0.998
|
assert training_arguments.adam_beta1 == 0.91
|
||||||
assert training_arguments.adam_beta2 == 0.9
|
assert training_arguments.adam_beta2 == 0.998
|
||||||
assert training_arguments.adam_epsilon == 0.00001
|
assert training_arguments.adam_epsilon == 0.00001
|
||||||
assert training_arguments.dataloader_num_workers == 1
|
assert training_arguments.dataloader_num_workers == 1
|
||||||
assert training_arguments.dataloader_pin_memory is True
|
assert training_arguments.dataloader_pin_memory is True
|
||||||
|
|||||||
@@ -750,3 +750,66 @@ class TestMultiGPULlama:
|
|||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_fix_untrained_tokens(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"fix_untrained_tokens": True,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
"bos_token": "<|custom_im_start|>",
|
||||||
|
"eos_token": "<|custom_im_end|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"chat_template": "jinja",
|
||||||
|
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
|
||||||
|
"path": "mlabonne/FineTome-100k",
|
||||||
|
"type": "chat_template",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"message_field_role": "from",
|
||||||
|
"message_field_content": "value",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 5,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sample_packing": True,
|
||||||
|
"bf16": True,
|
||||||
|
"save_safetensors": True,
|
||||||
|
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
|
||||||
|
"use_tensorboard": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
|
||||||
|
)
|
||||||
|
|||||||
@@ -66,6 +66,54 @@ class TestLlama:
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
def test_fix_untrained_tokens(self, temp_dir):
|
def test_fix_untrained_tokens(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"fix_untrained_tokens": True,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
"bos_token": "<|custom_im_start|>",
|
||||||
|
"eos_token": "<|custom_im_end|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"chat_template": "jinja",
|
||||||
|
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
|
||||||
|
"path": "mlabonne/FineTome-100k",
|
||||||
|
"type": "chat_template",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"message_field_role": "from",
|
||||||
|
"message_field_content": "value",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 5,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sample_packing": True,
|
||||||
|
"bf16": True,
|
||||||
|
"save_safetensors": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
|
def test_fix_untrained_tokens_already_trained(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user