Compare commits
1 commit: train-refa...fix/replac

| Author | SHA1 | Date |
|---|---|---|
|  | 10d18e6c97 |  |
@@ -24,8 +24,8 @@ class TrainDatasetMeta:
     """Dataclass with fields for training and validation datasets and metadata."""

     train_dataset: Dataset
-    eval_dataset: Dataset | None = None
-    total_num_steps: int | None = None
+    eval_dataset: Optional[Dataset] = None
+    total_num_steps: Optional[int] = None


 def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
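Note: the hunk above swaps PEP 604 union syntax back for `typing.Optional`. On Python 3.10+ the two spellings are interchangeable at runtime; a minimal sketch with toy names (not from the diff):

```python
from typing import Optional

# PEP 604 unions (Python >= 3.10) compare equal to their typing equivalents:
assert (int | None) == Optional[int]


def f(steps: int | None = None) -> int:
    # Same contract as `steps: Optional[int] = None`
    return steps or 0
```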
@@ -91,11 +91,13 @@ try:
 except ImportError:
     pass

-LOG = logging.getLogger(__name__)
+LOG = logging.getLogger("axolotl.core.trainer_builder")


 class TrainerBuilderBase(abc.ABC):
-    """Base class for trainer builder."""
+    """
+    Base class for trainer builder
+    """

     _train_dataset = None
     _eval_dataset = None
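Note: `logging.getLogger(__name__)` names the logger after the importing module's dotted path, so it tracks file moves automatically; the pinned literal keeps the historical name regardless. A small sketch (the module path comes from the diff; the demo is illustrative):

```python
import logging

# Loggers are singletons keyed by name, and dots define a hierarchy:
parent = logging.getLogger("axolotl.core")
child = logging.getLogger("axolotl.core.trainer_builder")
assert child.parent is parent

# Inside axolotl/core/trainer_builder.py, __name__ == "axolotl.core.trainer_builder",
# so getLogger(__name__) and the pinned literal return the same logger there;
# the literal simply stops tracking the module if the file is ever moved.
```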
@@ -108,9 +110,9 @@ class TrainerBuilderBase(abc.ABC):
         self.tokenizer = tokenizer
         self.processor = processor

-        # If the model supports tagging, add the axolotl tag.
+        # in case the model supports tagging, add the axolotl tag.
         # This makes sure the tag is correctly pushed even if a user calls
-        # model.push_to_hub instead of trainer.push_to_hub.
+        # model.push_to_hub instad of trainer.push_to_hub.
         if hasattr(model, "add_model_tags"):
             model.add_model_tags(["axolotl"])

@@ -225,8 +227,8 @@ class TrainerBuilderBase(abc.ABC):

 class HFCausalTrainerBuilder(TrainerBuilderBase):
     """
-    Build the HuggingFace training args/trainer for causal models and reward modeling
-    using TRL.
+    Build the HuggingFace training args/trainer for causal models
+    and reward modelling using TRL.
     """

     def get_callbacks(self):
@@ -870,7 +872,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):


 class HFRLTrainerBuilder(TrainerBuilderBase):
-    """Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
+    """
+    Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
+    """

     def get_callbacks(self):
         callbacks = super().get_callbacks()
@@ -1,29 +1,26 @@
 """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""

-import importlib
 import inspect
 import os
 import signal
 import sys
 import weakref
 from pathlib import Path
-from typing import Any
+from typing import Tuple, Union

 import torch
 import transformers.modelcard
 from accelerate.logging import get_logger
 from accelerate.utils import save_fsdp_model
-from datasets import Dataset
-from peft import PeftConfig, PeftModel
-from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+from peft import PeftModel
+from pkg_resources import get_distribution  # type: ignore
+from transformers import PreTrainedModel, PreTrainedTokenizer
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
-from transformers.trainer import Trainer

 from axolotl.common.datasets import TrainDatasetMeta
 from axolotl.contribs.lgpl.unsloth import (  # pylint: disable = no-name-in-module
     fix_untrained_tokens,
 )
-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.freeze import freeze_layers_except
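Note: the right-hand branch reads the package version through the deprecated `pkg_resources` API, while the left side uses the standard library. A sketch of the two side by side, assuming the `axolotl` package is installed:

```python
# Deprecated setuptools API (still works, but warns on modern setuptools):
from pkg_resources import get_distribution  # type: ignore

# Standard-library replacement, available since Python 3.8:
from importlib.metadata import version

legacy = get_distribution("axolotl").version
stdlib = version("axolotl")
assert legacy == stdlib  # both read the same installed package metadata
```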
@@ -35,25 +32,17 @@ try:
 except ImportError:
     BetterTransformer = None

+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+src_dir = os.path.join(project_root, "src")
+sys.path.insert(0, src_dir)

 configure_logging()
 LOG = get_logger(__name__)


-def setup_model_and_tokenizer(
-    cfg: DictDefault,
-) -> tuple[
-    PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
-]:
-    """
-    Load the tokenizer, processor (for multimodal models), and model based on configuration.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-
-    Returns:
-        Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else
-        `None`), and processor (if multimodal, else `None`).
-    """
+def train(
+    *, cfg: DictDefault, dataset_meta: TrainDatasetMeta
+) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
     # Load tokenizer
     LOG.debug(
         f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
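Note: the bare `*` in the right-hand `train(*, cfg, dataset_meta)` signature makes both parameters keyword-only, so call sites must spell out the argument names. A toy illustration (stand-in types, not the axolotl signature):

```python
def train(*, cfg: dict, dataset_meta: dict) -> None:
    """Everything after the bare * must be passed by keyword."""


train(cfg={}, dataset_meta={})  # OK
try:
    train({}, {})  # positional call is rejected
except TypeError as err:
    print(err)  # train() takes 0 positional arguments but 2 were given
```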
@@ -66,58 +55,11 @@ def setup_model_and_tokenizer(
     if cfg.is_multimodal:
         processor = load_processor(cfg, tokenizer)

-    # Load the model and peft_config
-    msg = "loading model"
-    if cfg.adapter:
-        msg += " and peft_config..."
-    LOG.debug(msg)
-
-    model, peft_config = load_model(cfg, tokenizer, processor=processor)
-    if model.generation_config is not None:
-        model.generation_config.do_sample = True
-
-    # Apply freezing if specified
-    if cfg.unfrozen_parameters:
-        freeze_layers_except(model, cfg.unfrozen_parameters)
-
-    return model, tokenizer, peft_config, processor
-
-
-def setup_reference_model(
-    cfg: DictDefault, tokenizer: PreTrainedTokenizer
-) -> PreTrainedModel | None:
-    """
-    Set up the reference model for RL training if needed.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-        tokenizer: The tokenizer to use for the reference model.
-
-    Returns:
-        Reference model if needed for RL training, `None` otherwise.
-    """
-    model_ref = None
-    if cfg.rl and cfg.rl != "orpo":
-        if cfg.adapter and not cfg.rl_adapter_ref_model:
-            # use built-in trl autounwrap
-            LOG.debug("Passing model_ref: None to RL trainer")
-            model_ref = None  # explicit setting to None
-        else:
-            # load the model again for model_ref/baseline
-            model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
-    return model_ref
-
-
-def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
-    """
-    Determine the checkpoint to resume from based on configuration.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-
-    Returns:
-        Path to the checkpoint to resume from, or `None` if not resuming.
-    """
+    # Get datasets
+    train_dataset = dataset_meta.train_dataset
+    eval_dataset = dataset_meta.eval_dataset
+    total_num_steps = dataset_meta.total_num_steps
+
     if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
         possible_checkpoints = [
             str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
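Note on the removed `setup_reference_model` helper (and the equivalent inline block later in this diff): RL methods such as DPO score completions against a frozen reference model, but when training a PEFT adapter TRL can recover reference logits by unwrapping the adapter, so no second model is loaded. A condensed sketch of that decision, with plain values standing in for the config object:

```python
def needs_separate_ref_model(rl, adapter, rl_adapter_ref_model) -> bool:
    # Mirrors the branch logic in the diff: no RL (or ORPO) -> no reference model.
    if not rl or rl == "orpo":
        return False
    # With a PEFT adapter, TRL's built-in auto-unwrap serves as the reference.
    if adapter and not rl_adapter_ref_model:
        return False
    # Otherwise the base model is loaded a second time as the frozen baseline.
    return True


assert needs_separate_ref_model("dpo", adapter=None, rl_adapter_ref_model=False)
assert not needs_separate_ref_model("dpo", adapter="lora", rl_adapter_ref_model=False)
```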
@@ -131,22 +73,77 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
             LOG.info(
                 f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
             )
-    return cfg.resume_from_checkpoint
+    resume_from_checkpoint = cfg.resume_from_checkpoint

-
-def setup_signal_handler(
-    cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
-):
-    """
-    Set up signal handler for graceful termination.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-        model: The model to save on termination
-        safe_serialization: Whether to use safe serialization when saving
-    """
-    # ray workers don't have access to this signal
-    if cfg.local_rank == 0 and not cfg.use_ray:
+    # Load the model and tokenizer
+    msg = "loading model"
+    if cfg.adapter:
+        msg += " and peft_config..."
+    LOG.debug(msg)
+    model, peft_config = load_model(cfg, tokenizer, processor=processor)
+    if model.generation_config is not None:
+        model.generation_config.do_sample = True
+
+    model_ref = None
+    if cfg.rl and cfg.rl != "orpo":
+        if cfg.adapter and not cfg.rl_adapter_ref_model:
+            # use built-in trl autounwrap
+            LOG.debug("Passing model_ref: None to RL trainer")
+            model_ref = None  # explicit setting to None
+        else:
+            # load the model again for model_ref/baseline
+            model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
+
+    safe_serialization = cfg.save_safetensors is True
+
+    if cfg.unfrozen_parameters:
+        freeze_layers_except(model, cfg.unfrozen_parameters)
+
+    trainer = setup_trainer(
+        cfg,
+        train_dataset,
+        eval_dataset,
+        (model, model_ref, peft_config),
+        tokenizer,
+        processor,
+        total_num_steps,
+    )
+
+    if cfg.fix_untrained_tokens:
+        # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
+        sig = inspect.signature(fix_untrained_tokens)
+        # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
+        if "token_ids_to_fix" in sig.parameters and isinstance(
+            cfg.fix_untrained_tokens, list
+        ):
+            fix_untrained_tokens(
+                model,
+                tokenizer,
+                train_dataset,
+                token_ids_to_fix=cfg.fix_untrained_tokens,
+            )
+        else:
+            fix_untrained_tokens(model, tokenizer, train_dataset)
+        if cfg.local_rank == 0:
+            model.save_pretrained(
+                str(Path(cfg.output_dir)), safe_serialization=safe_serialization
+            )
+
+    # go ahead and presave, so we have the adapter config available to inspect
+    if peft_config:
+        LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
+        peft_config.save_pretrained(cfg.output_dir)
+    # additionally presave the tokenizer and model configs
+    if not Path(cfg.output_dir).is_dir():
+        os.makedirs(cfg.output_dir, exist_ok=True)
+    tokenizer.save_pretrained(str(Path(cfg.output_dir)))
+    if hasattr(model, "config"):
+        model.config.save_pretrained(str(Path(cfg.output_dir)))
+
+    # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
+    if (
+        cfg.local_rank == 0 and not cfg.use_ray
+    ):  # ray workers don't have access to this signal

         def terminate_handler(_, __, model_weakref):
             if model_weakref() is not None:
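Note: the `inspect.signature` check in the block above is a version-tolerance pattern: pass the newer `token_ids_to_fix` keyword only if the installed `fix_untrained_tokens` accepts it. A generic, self-contained sketch of the same idea (names here are illustrative):

```python
import inspect


def fix_old(model, tokenizer, dataset):  # older library signature
    return "fixed all untrained tokens"


def fix_new(model, tokenizer, dataset, token_ids_to_fix=None):  # newer signature
    return f"fixed {token_ids_to_fix or 'all'}"


def call_compat(fix_fn, model, tokenizer, dataset, token_ids=None):
    # Forward the kwarg only when the target function declares it.
    if token_ids and "token_ids_to_fix" in inspect.signature(fix_fn).parameters:
        return fix_fn(model, tokenizer, dataset, token_ids_to_fix=token_ids)
    return fix_fn(model, tokenizer, dataset)


print(call_compat(fix_old, None, None, None, token_ids=[1, 2]))  # falls back
print(call_compat(fix_new, None, None, None, token_ids=[1, 2]))  # uses the kwarg
```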
@@ -164,22 +161,21 @@ def setup_signal_handler(
             lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
         )

-
-def execute_training(
-    cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
-):
-    """
-    Execute the training process with appropriate backend configurations.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-        trainer: The configured trainer object.
-        resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
-    """
+    badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
+    transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
+
+    if getattr(cfg, "axolotl_config_path"):
+        raw_axolotl_cfg = Path(cfg.axolotl_config_path)
+        version = get_distribution("axolotl").version
+        if raw_axolotl_cfg.is_file():
+            transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
+
     LOG.info("Starting trainer...")
     if cfg.group_by_length:
         LOG.info("hang tight... sorting dataset for group_by_length")

+    pretrain_hooks(cfg, trainer)
+
     if cfg.flash_optimum:
         with torch.backends.cuda.sdp_kernel(
             # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
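Note: both sides register the same ctrl-c handler: it holds only a `weakref` to the model, so the handler itself does not keep the model alive, and it saves before exiting. A self-contained sketch (the `Model` stub stands in for a real `PreTrainedModel`):

```python
import signal
import sys
import weakref


class Model:  # stub standing in for a PreTrainedModel
    def save_pretrained(self, path, safe_serialization=True):
        print(f"saving to {path} (safetensors={safe_serialization})")


def terminate_handler(_, __, model_weakref):
    # The weakref may be dead if the model was already garbage-collected.
    if model_weakref() is not None:
        model_weakref().save_pretrained("output_dir")
    sys.exit(0)


model = Model()
_model_weakref = weakref.ref(model)
signal.signal(
    signal.SIGINT,
    lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
)
```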
@@ -191,30 +187,15 @@ def execute_training(
     else:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)

-
-def save_trained_model(
-    cfg: DictDefault,
-    trainer: Any,
-    model: PreTrainedModel,
-    safe_serialization: bool,
-):
-    """
-    Save the trained model according to configuration and training setup.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-        trainer: The trainer object.
-        model: The trained model to save.
-        safe_serialization: Whether to use safe serialization.
-    """
-    LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
-
-    # Post training module hooks
+    post_train_hooks(cfg, trainer)
+
+    LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
+
+    # post training
     for name, module in model.named_modules():
         if hasattr(module, "_post_training"):
             module._post_training(model, name)  # pylint: disable=protected-access

-    # Handle FSDP state dict type
     state_dict_type = "FULL_STATE_DICT"
     if trainer.is_fsdp_enabled:
         if cfg.fsdp_final_state_dict_type:
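Note: the `_post_training` loop above is a duck-typed hook protocol: any submodule that defines `_post_training` gets a callback after training, with no registry needed. A minimal sketch, assuming PyTorch is available:

```python
import torch.nn as nn


class QuantizedLinear(nn.Linear):
    def _post_training(self, model: nn.Module, name: str) -> None:
        # e.g. fold scales back into the weights, drop helper buffers, etc.
        print(f"post-training cleanup for submodule {name!r}")


model = nn.Sequential(QuantizedLinear(8, 8), nn.ReLU())
for name, module in model.named_modules():
    if hasattr(module, "_post_training"):
        module._post_training(model, name)  # pylint: disable=protected-access
```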
@@ -222,18 +203,16 @@ def save_trained_model(
         trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
         LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")

-    # Handle ReLoRA early return case
     if cfg.relora_steps:
         if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
             model = model.merge_and_unload()
         else:
             # final model weights have already been saved by `ReLoRACallback.on_train_end`
-            return
+            return model, tokenizer

+    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
+    # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.fsdp:
-        # TODO: do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-        # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple
-        # processes attempt to write the same file
         if (
             state_dict_type == "SHARDED_STATE_DICT"
             and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
@@ -265,6 +244,7 @@ def save_trained_model(
                 os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
             except FileNotFoundError:
                 pass
+
     elif cfg.local_rank == 0:
         if cfg.flash_optimum and BetterTransformer:
             model = BetterTransformer.reverse(model)
@@ -275,239 +255,58 @@ def save_trained_model(
         )
     model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

-
-def create_model_card(cfg: DictDefault, trainer: Trainer):
-    """
-    Create a model card for the trained model if needed.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-        trainer: The trainer object with model card creation capabilities.
-    """
     if not cfg.hub_model_id:
-        # Guard since create_model_card may fail if dataset_tags is empty list
         try:
             model_card_kwarg = {
                 "model_name": cfg.output_dir.lstrip("./")
                 .encode("utf-8")
                 .decode("utf-8")
             }
-            # We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed.
-            rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model
-            if cfg.datasets is not None and not rl:
-                dataset_tags = [
-                    d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
-                ]
-                dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
-                if dataset_tags:
-                    model_card_kwarg["dataset_tags"] = dataset_tags
+            if cfg.datasets is not None:
+                if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model:
+                    dataset_tags = [
+                        d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
+                    ]
+                    dataset_tags = [
+                        d for d in dataset_tags if not d.startswith("https://")
+                    ]
+                    if dataset_tags:
+                        # guard as create_model_card may fail if dataset_tags is empty list
+                        model_card_kwarg["dataset_name"] = dataset_tags
+                else:
+                    dataset_tags = [
+                        d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
+                    ]
+                    dataset_tags = [
+                        d for d in dataset_tags if not d.startswith("https://")
+                    ]
+                    if dataset_tags:
+                        # guard as create_model_card may fail if dataset_tags is empty list
+                        model_card_kwarg["dataset_tags"] = dataset_tags

             trainer.create_model_card(**model_card_kwarg)
         except (AttributeError, UnicodeDecodeError):
             pass
     elif cfg.hub_model_id:
-        # Defensively push to the hub to ensure the model card is updated
+        # defensively push to the hub to ensure the model card is updated
         trainer.push_to_hub()

-
-def save_initial_configs(
-    cfg: DictDefault,
-    tokenizer: PreTrainedTokenizer,
-    model: PreTrainedModel,
-    peft_config: PeftConfig | None,
-):
-    """
-    Save initial configurations before training.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-        tokenizer: The tokenizer to save.
-        model: The model to save configuration for.
-        peft_config: The PEFT configuration to save if applicable.
-    """
-    # Create output_dir if it doesn't already exist
-    output_dir = Path(cfg.output_dir)
-    if not output_dir.is_dir():
-        os.makedirs(cfg.output_dir, exist_ok=True)
-
-    # Pre-save adapter config so it's available to inspect
-    if peft_config:
-        LOG.info(f"Pre-saving adapter config to {cfg.output_dir}...")
-        peft_config.save_pretrained(cfg.output_dir)
-
-    # Pre-save the tokenizer and model configs
-    LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
-    tokenizer.save_pretrained(str(output_dir))
-    if hasattr(model, "config"):
-        LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
-        model.config.save_pretrained(str(output_dir))
-
-
-def setup_model_card(cfg: DictDefault):
-    """
-    Set up the Axolotl badge and add the Axolotl config to the model card if available.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-    """
-    badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
-    transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
-
-    if getattr(cfg, "axolotl_config_path"):
-        raw_axolotl_cfg = Path(cfg.axolotl_config_path)
-        version = importlib.metadata.version("axolotl")
-        if raw_axolotl_cfg.is_file():
-            transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
-
-
-def handle_untrained_tokens_fix(
-    cfg: DictDefault,
-    model: PreTrainedModel,
-    tokenizer: PreTrainedTokenizer,
-    train_dataset: Dataset,
-    safe_serialization: bool,
-):
-    """
-    Apply fixes for untrained tokens if configured.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-        model: The model to apply fixes to.
-        tokenizer: The tokenizer for token identification.
-        train_dataset: The training dataset to use.
-        safe_serialization: Whether to use safe serialization when saving.
-    """
-    if not cfg.fix_untrained_tokens:
-        return
-
-    # Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
-    sig = inspect.signature(fix_untrained_tokens)
-
-    # If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
-    if "token_ids_to_fix" in sig.parameters and isinstance(
-        cfg.fix_untrained_tokens, list
-    ):
-        fix_untrained_tokens(
-            model,
-            tokenizer,
-            train_dataset,
-            token_ids_to_fix=cfg.fix_untrained_tokens,
-        )
-    else:
-        fix_untrained_tokens(model, tokenizer, train_dataset)
-
-    if cfg.local_rank == 0:
-        model.save_pretrained(
-            str(Path(cfg.output_dir)), safe_serialization=safe_serialization
-        )
-
-
-def setup_model_and_trainer(
-    cfg: DictDefault, dataset_meta: TrainDatasetMeta
-) -> tuple[
-    HFRLTrainerBuilder | HFCausalTrainerBuilder,
-    PeftModel | PreTrainedModel,
-    PreTrainedTokenizer,
-    PeftConfig | None,
-]:
-    """
-    Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
-    trainer setup.
-
-    Args:
-        cfg: The configuration dictionary with training parameters.
-        dataset_meta: Object with training, validation datasets and metadata.
-
-    Returns:
-        Tuple of:
-            - Trainer (Causal or RLHF)
-            - Model
-            - Tokenizer
-            - PEFT config
-    """
-    # Load tokenizer, processor and model
-    model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
-
-    # Set up reference model for RL if needed
-    model_ref = setup_reference_model(cfg, tokenizer)
-
-    # Get datasets from metadata
-    train_dataset = dataset_meta.train_dataset
-    eval_dataset = dataset_meta.eval_dataset
-    total_num_steps = dataset_meta.total_num_steps
-
-    # Set up trainer
-    trainer = setup_trainer(
-        cfg=cfg,
-        train_dataset=train_dataset,
-        eval_dataset=eval_dataset,
-        model=model,
-        tokenizer=tokenizer,
-        processor=processor,
-        total_num_steps=total_num_steps,
-        model_ref=model_ref,
-        peft_config=peft_config,
-    )
-
-    return (
-        trainer,
-        model,
-        tokenizer,
-        peft_config,
-    )
-
-
-def train(
-    cfg: DictDefault, dataset_meta: TrainDatasetMeta
-) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]:
-    """
-    Train a model on the given dataset.
-
-    Args:
-        cfg: The configuration dictionary with training parameters
-        dataset_meta: Object with training, validation datasets and metadata
-
-    Returns:
-        Tuple of (model, tokenizer) after training
-    """
-    # Setup model, tokenizer, (causal or RLHF) trainer etc.
-    (
-        trainer,
-        model,
-        tokenizer,
-        peft_config,
-    ) = setup_model_and_trainer(cfg, dataset_meta)
-
-    # Determine if we need to resume from a checkpoint
-    resume_from_checkpoint = determine_resume_checkpoint(cfg)
-
-    # Configuration for saving
-    safe_serialization = cfg.save_safetensors is True
-
-    # Handle untrained tokens if configured
-    train_dataset = dataset_meta.train_dataset
-    handle_untrained_tokens_fix(
-        cfg, model, tokenizer, train_dataset, safe_serialization
-    )
-
-    # Save initial configs
-    save_initial_configs(cfg, tokenizer, model, peft_config)
-
-    # Set up signal handler for graceful termination
-    setup_signal_handler(cfg, model, safe_serialization)
-
-    # Set up badges and config info for model card
-    setup_model_card(cfg)
-
-    # Execute the training
-    execute_training(cfg, trainer, resume_from_checkpoint)
-
-    # Save the trained model
-    save_trained_model(cfg, trainer, model, safe_serialization)
-
-    # Create model card
-    create_model_card(cfg, trainer)
-
     return model, tokenizer


+def pretrain_hooks(_cfg, _trainer):
+    """
+    Run hooks right before kicking off the training
+    :param cfg:
+    :param trainer:
+    :return:
+    """
+
+
+def post_train_hooks(_cfg, _trainer):
+    """
+    Run hooks right after training completes
+    :param cfg:
+    :param trainer:
+    :return:
+    """
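Note: both branches build the model card the same way: append the badge, then (if the run's YAML config file still exists) embed it in a collapsible `<details>` block, all by extending `transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT`. A condensed sketch, assuming `transformers` is installed and with an inline stand-in for the config text:

```python
import transformers.modelcard

badge_markdown = (
    '[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/'
    'image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>]'
    "(https://github.com/axolotl-ai-cloud/axolotl)"
)
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"

config_yaml = "base_model: HuggingFaceTB/SmolLM2-135M-Instruct\n"  # stand-in text
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += (
    "\n<details><summary>See axolotl config</summary>\n\n"
    f"```yaml\n{config_yaml}```\n\n</details><br>\n"
)
```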
@@ -574,40 +574,14 @@ def prepare_opinionated_env(cfg):


 def setup_trainer(
-    cfg,
-    train_dataset,
-    eval_dataset,
-    model,
-    tokenizer,
-    processor,
-    total_num_steps,
-    model_ref=None,
-    peft_config=None,
+    cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
 ):
-    """
-    Helper method for instantiating and building a (causal or RLHF) trainer.
-
-    Args:
-        cfg: Axolotl config object containing training parameters.
-        train_dataset: Dataset to use for training.
-        eval_dataset: Dataset to use for evaluation.
-        model: The model to train.
-        tokenizer: Tokenizer for processing text input.
-        processor: Processor for data preparation.
-        total_num_steps: The total number of training steps.
-        model_ref: Optional reference model for RLHF training. Default is None.
-        peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
-
-    Returns:
-        A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
-        on the provided parameters.
-    """
     if cfg.rl:
-        trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
-        trainer_builder.model_ref = model_ref
-        trainer_builder.peft_config = peft_config
+        trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
+        trainer_builder.model_ref = model[1]
+        trainer_builder.peft_config = model[2]
     else:
-        trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor)
+        trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)

     trainer_builder.train_dataset = train_dataset
     trainer_builder.eval_dataset = eval_dataset
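Note: this hunk is the pivot of the refactor. The left side gives `setup_trainer` explicit `model_ref` and `peft_config` keyword parameters, while the right side smuggles all three objects through one positional `model` tuple that the callee unpacks by index. A toy contrast with stand-in bodies:

```python
# Right-hand style: one positional slot carries three related objects,
# and every caller must remember the packing order.
def setup_trainer_packed(cfg, model, tokenizer):
    base, model_ref, peft_config = model  # index-based unpacking
    return base, model_ref, peft_config


# Left-hand style: each object is its own parameter; RL-only ones default to None.
def setup_trainer_kwargs(cfg, model, tokenizer, model_ref=None, peft_config=None):
    return model, model_ref, peft_config


assert setup_trainer_packed({}, ("m", None, None), "tok") == \
    setup_trainer_kwargs({}, "m", "tok")
```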
@@ -47,9 +47,9 @@ def download_smollm2_135m_model():


 @pytest.fixture(scope="session", autouse=True)
-def download_llama_68m_random_model():
+def download_smollm2_135m_instruct_model():
     # download the model
-    snapshot_download_w_retry("JackFram/llama-68m")
+    snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M-Instruct")


 @pytest.fixture(scope="session", autouse=True)
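Note: because the fixture above is `scope="session"` and `autouse=True`, pytest runs it once per test session, before any test, without tests requesting it by name. A minimal, self-contained sketch (the real download call is replaced by a print):

```python
import pytest


@pytest.fixture(scope="session", autouse=True)
def download_model():
    # Runs exactly once, before the first test in the session.
    print("downloading HuggingFaceTB/SmolLM2-135M-Instruct ...")


def test_uses_model():
    # No fixture argument needed; the autouse fixture already ran.
    assert True
```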
@@ -28,7 +28,7 @@ class Test4dMultipackLlama(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "flash_attention": False,
                 "sdp_attention": True,
                 "sample_packing": True,

@@ -72,7 +72,7 @@ class Test4dMultipackLlama(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "flash_attention": False,
                 "sdp_attention": False,
                 "sample_packing": True,

@@ -32,7 +32,7 @@ class TestFusedLlama(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "flash_attention": True,
                 "pad_to_sequence_len": True,
                 "flash_attn_fuse_qkv": True,

@@ -31,8 +31,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "sequence_len": 16384,
                 "sample_packing": False,
                 "flash_attention": True,

@@ -77,8 +76,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "sequence_len": 16384,
                 "sample_packing": False,
                 "flash_attention": True,

@@ -31,8 +31,7 @@ class TestLoraLlama(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "sequence_len": 1024,
                 "sample_packing": True,
                 "flash_attention": True,

@@ -43,6 +42,7 @@ class TestLoraLlama(unittest.TestCase):
                 "lora_dropout": 0.05,
                 "lora_target_linear": True,
                 "val_set_size": 0.2,
+                "lora_modules_to_save": ["lm_head", "embed_tokens"],
                 "special_tokens": {
                     "unk_token": "<unk>",
                     "bos_token": "<s>",

@@ -31,8 +31,7 @@ class TestDPOLlamaLora(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "sequence_len": 1024,
                 "load_in_8bit": True,
                 "adapter": "lora",

@@ -77,8 +76,7 @@ class TestDPOLlamaLora(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "sequence_len": 1024,
                 "load_in_8bit": True,
                 "adapter": "lora",

@@ -124,8 +122,7 @@ class TestDPOLlamaLora(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "sequence_len": 1024,
                 "load_in_8bit": True,
                 "adapter": "lora",

@@ -172,8 +169,7 @@ class TestDPOLlamaLora(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "sequence_len": 1024,
                 "load_in_8bit": True,
                 "adapter": "lora",

@@ -218,8 +214,7 @@ class TestDPOLlamaLora(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "sequence_len": 1024,
                 "load_in_8bit": True,
                 "adapter": "lora",

@@ -264,8 +259,7 @@ class TestDPOLlamaLora(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "sequence_len": 1024,
                 "load_in_8bit": True,
                 "adapter": "lora",

@@ -314,8 +308,7 @@ class TestDPOLlamaLora(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "sequence_len": 1024,
                 "load_in_8bit": True,
                 "adapter": "lora",

@@ -26,8 +26,7 @@ class TestLlama:
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "trust_remote_code": True,
                 "sequence_len": 512,
                 "val_set_size": 0.1,

@@ -26,9 +26,8 @@ class TestLoadModelUtils:
         # load config
         self.cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
-                "tokenizer_config": "JackFram/llama-68m",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
+                "tokenizer_config": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "sequence_len": 1024,
                 "load_in_8bit": False,
                 "adapter": "lora",

@@ -28,8 +28,7 @@ class TestLoraLlama(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "sequence_len": 1024,
                 "load_in_8bit": True,
                 "adapter": "lora",

@@ -37,6 +36,7 @@ class TestLoraLlama(unittest.TestCase):
                 "lora_alpha": 16,
                 "lora_dropout": 0.05,
                 "lora_target_linear": True,
+                "lora_modules_to_save": ["lm_head", "embed_tokens"],
                 "val_set_size": 0.1,
                 "special_tokens": {
                     "unk_token": "<unk>",

@@ -28,8 +28,7 @@ class TestCustomOptimizers(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "sequence_len": 1024,
                 "load_in_8bit": True,
                 "adapter": "lora",

@@ -74,8 +73,7 @@ class TestCustomOptimizers(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "sequence_len": 1024,
                 "load_in_8bit": True,
                 "adapter": "lora",

@@ -16,9 +16,8 @@ class NormalizeConfigTestCase(unittest.TestCase):
     def _get_base_cfg(self):
         return DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "base_model_config": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
+                "base_model_config": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "num_epochs": 1,
                 "micro_batch_size": 1,
                 "gradient_accumulation_steps": 1,

@@ -18,9 +18,8 @@ class TestModelsUtils:
         # load config
         self.cfg = DictDefault(  # pylint: disable=attribute-defined-outside-init
             {
-                "base_model": "JackFram/llama-68m",
+                "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
                 "model_type": "LlamaForCausalLM",
-                "tokenizer_type": "LlamaTokenizer",
                 "load_in_8bit": True,
                 "load_in_4bit": False,
                 "adapter": "lora",