Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cdd8be7097 | ||
|
|
08143c7b0d | ||
|
|
e1915f5625 | ||
|
|
844331005c | ||
|
|
61aa291119 | ||
|
|
b98d7d7098 | ||
|
|
d7eea2ff34 |
@@ -205,7 +205,7 @@ ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
|
|||||||
hi there!. goodbye farewell</s>
|
hi there!. goodbye farewell</s>
|
||||||
```
|
```
|
||||||
|
|
||||||
We can check that the right tokens are ingored by comparing the labels
|
We can check that the right tokens are ignored by comparing the labels
|
||||||
to each token:
|
to each token:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|||||||
28
docs/multimodal.qmd
Normal file
28
docs/multimodal.qmd
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# MultiModal / Vision Language Models (BETA)
|
||||||
|
|
||||||
|
### Supported Models
|
||||||
|
|
||||||
|
- Mllama, i.e. llama with vision models
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA,
|
||||||
|
you'll need to use the following in YAML in combination with the rest of the required hyperparams.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
|
||||||
|
chat_template: llama3_2_vision
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
# only finetune the Language model, leave the vision model and vision tower frozen
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
```
|
||||||
63
examples/llama-3-vision/lora-11b.yaml
Normal file
63
examples/llama-3-vision/lora-11b.yaml
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: llama3_2_vision
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.12.0
|
peft==0.13.0
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git@0963229e287501bed52ae1dabc17922524de6992
|
transformers==4.45.1
|
||||||
tokenizers>=0.19.1
|
tokenizers>=0.19.1
|
||||||
bitsandbytes==0.43.3
|
bitsandbytes==0.44.0
|
||||||
accelerate==0.34.2
|
accelerate==0.34.2
|
||||||
datasets==2.21.0
|
datasets==2.21.0
|
||||||
deepspeed==0.14.4
|
deepspeed==0.14.4
|
||||||
@@ -34,7 +34,7 @@ tensorboard
|
|||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
autoawq>=0.2.5
|
autoawq>=0.2.5
|
||||||
triton>=2.3.0
|
triton>=2.3.0
|
||||||
liger-kernel==0.2.1
|
liger-kernel==0.3.0
|
||||||
|
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
normalize_config,
|
normalize_config,
|
||||||
@@ -39,7 +40,7 @@ from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_processor, load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
@@ -234,7 +235,8 @@ def do_inference_gradio(
|
|||||||
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||||
prompter = cli_args.prompter
|
prompter = cli_args.prompter
|
||||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||||
|
default_tokens: Dict[str, str] = {}
|
||||||
|
|
||||||
for token, symbol in default_tokens.items():
|
for token, symbol in default_tokens.items():
|
||||||
# If the token isn't already specified in the config, add it
|
# If the token isn't already specified in the config, add it
|
||||||
@@ -242,10 +244,13 @@ def do_inference_gradio(
|
|||||||
tokenizer.add_special_tokens({token: symbol})
|
tokenizer.add_special_tokens({token: symbol})
|
||||||
|
|
||||||
prompter_module = None
|
prompter_module = None
|
||||||
|
chat_template_str = None
|
||||||
if prompter:
|
if prompter:
|
||||||
prompter_module = getattr(
|
prompter_module = getattr(
|
||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
elif cfg.chat_template:
|
||||||
|
chat_template_str = chat_templates(cfg.chat_template)
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
@@ -259,7 +264,24 @@ def do_inference_gradio(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt = instruction.strip()
|
prompt = instruction.strip()
|
||||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
|
||||||
|
if chat_template_str:
|
||||||
|
batch = tokenizer.apply_chat_template(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
return_tensors="pt",
|
||||||
|
add_special_tokens=True,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
chat_template=chat_template_str,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -282,6 +304,7 @@ def do_inference_gradio(
|
|||||||
streamer = TextIteratorStreamer(tokenizer)
|
streamer = TextIteratorStreamer(tokenizer)
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
"inputs": batch["input_ids"].to(cfg.device),
|
"inputs": batch["input_ids"].to(cfg.device),
|
||||||
|
"attention_mask": batch["attention_mask"].to(cfg.device),
|
||||||
"generation_config": generation_config,
|
"generation_config": generation_config,
|
||||||
"streamer": streamer,
|
"streamer": streamer,
|
||||||
}
|
}
|
||||||
@@ -407,9 +430,12 @@ def load_datasets(
|
|||||||
cli_args: TrainerCliArgs,
|
cli_args: TrainerCliArgs,
|
||||||
) -> TrainDatasetMeta:
|
) -> TrainDatasetMeta:
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
||||||
|
|
||||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||||
cfg, tokenizer
|
cfg,
|
||||||
|
tokenizer,
|
||||||
|
processor=processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cli_args.debug or cfg.debug:
|
if cli_args.debug or cfg.debug:
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
@@ -45,7 +46,6 @@ from trl import (
|
|||||||
)
|
)
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.loraplus import create_loraplus_optimizer
|
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||||
from axolotl.utils import is_mlflow_available
|
from axolotl.utils import is_mlflow_available
|
||||||
@@ -61,12 +61,14 @@ from axolotl.utils.callbacks import (
|
|||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||||
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
from axolotl.utils.collators import (
|
from axolotl.utils.collators import (
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
MambaDataCollator,
|
MambaDataCollator,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
from axolotl.utils.models import ensure_dtype
|
from axolotl.utils.models import ensure_dtype
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.utils.schedulers import (
|
||||||
@@ -250,6 +252,10 @@ class AxolotlTrainingMixins:
|
|||||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
chat_template: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Chat template converting chat messages to text"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -456,14 +462,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
if self.args.loraplus_lr_ratio is not None:
|
if self.args.loraplus_lr_ratio is not None:
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
loraplus_lr_embedding = getattr(
|
loraplus_lr_embedding = getattr(
|
||||||
self.args, "loraplus_lr_embedding", None
|
self.args, "loraplus_lr_embedding", 1e-6
|
||||||
)
|
)
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
opt_model,
|
opt_model,
|
||||||
optimizer_cls,
|
optimizer_cls,
|
||||||
optimizer_kwargs,
|
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||||
loraplus_lr_ratio,
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
loraplus_lr_embedding,
|
**optimizer_kwargs,
|
||||||
)
|
)
|
||||||
elif self.args.alternate_optimizer == "optimi_adamw":
|
elif self.args.alternate_optimizer == "optimi_adamw":
|
||||||
from optimi import AdamW
|
from optimi import AdamW
|
||||||
@@ -969,9 +975,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
opt_model,
|
opt_model,
|
||||||
optimizer_cls,
|
optimizer_cls,
|
||||||
optimizer_kwargs,
|
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||||
loraplus_lr_ratio,
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
loraplus_lr_embedding,
|
**optimizer_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
@@ -1043,10 +1049,11 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
_model_ref = None
|
_model_ref = None
|
||||||
_peft_config = None
|
_peft_config = None
|
||||||
|
|
||||||
def __init__(self, cfg, model, tokenizer):
|
def __init__(self, cfg, model, tokenizer, processor=None):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
self.processor = processor
|
||||||
|
|
||||||
# in case the model supports tagging, add the axolotl tag.
|
# in case the model supports tagging, add the axolotl tag.
|
||||||
# This makes sure the tag is correctly pushed even if a user calls
|
# This makes sure the tag is correctly pushed even if a user calls
|
||||||
@@ -1417,6 +1424,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
report_to = []
|
report_to = []
|
||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
report_to.append("wandb")
|
report_to.append("wandb")
|
||||||
|
if self.cfg.wandb_name:
|
||||||
|
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
|
||||||
if self.cfg.use_mlflow:
|
if self.cfg.use_mlflow:
|
||||||
report_to.append("mlflow")
|
report_to.append("mlflow")
|
||||||
if self.cfg.use_tensorboard:
|
if self.cfg.use_tensorboard:
|
||||||
@@ -1513,6 +1522,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||||
|
if self.cfg.chat_template:
|
||||||
|
training_arguments_kwargs["chat_template"] = chat_templates(
|
||||||
|
self.cfg.chat_template
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.rl == "orpo":
|
if self.cfg.rl == "orpo":
|
||||||
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
|
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
|
||||||
@@ -1574,6 +1587,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
training_args = self.hook_post_create_training_args(training_args)
|
training_args = self.hook_post_create_training_args(training_args)
|
||||||
|
|
||||||
|
# unset run_name so wandb sets up experiment names
|
||||||
|
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
|
||||||
|
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
data_collator_kwargs = {
|
data_collator_kwargs = {
|
||||||
"padding": True, # True/"longest" is the default
|
"padding": True, # True/"longest" is the default
|
||||||
}
|
}
|
||||||
@@ -1653,7 +1672,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
collator = BatchSamplerDataCollatorForSeq2Seq
|
collator = BatchSamplerDataCollatorForSeq2Seq
|
||||||
else:
|
else:
|
||||||
collator = DataCollatorForSeq2Seq
|
if self.cfg.processor_type and self.processor:
|
||||||
|
collator = MultiModalChatDataCollator
|
||||||
|
kwargs["processor"] = self.processor
|
||||||
|
kwargs["chat_template"] = training_args.chat_template
|
||||||
|
else:
|
||||||
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|
||||||
return collator(
|
return collator(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
|
|||||||
@@ -1,133 +0,0 @@
|
|||||||
"""Module for LoRA+"""
|
|
||||||
|
|
||||||
# MIT License
|
|
||||||
#
|
|
||||||
# Copyright (c) 2024 nikhil-ghosh-berkeley
|
|
||||||
# https://github.com/nikhil-ghosh-berkeley/loraplus
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from functools import reduce
|
|
||||||
|
|
||||||
from peft.tuners import lora
|
|
||||||
from torch import nn
|
|
||||||
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.loraplus")
|
|
||||||
|
|
||||||
|
|
||||||
def get_module(name, opt_model):
|
|
||||||
"""
|
|
||||||
Retrieve a module from a model using its parameter name.
|
|
||||||
Args:
|
|
||||||
name (str): Full name of the parameter, typically including module path.
|
|
||||||
opt_model (torch.nn.Module): The model from which to retrieve the module.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Module corresponding to the given name.
|
|
||||||
"""
|
|
||||||
parent_idx = 2 if "lora" in name else 1
|
|
||||||
module_names = name.split(sep=".")[:-parent_idx]
|
|
||||||
module = reduce(getattr, module_names, opt_model)
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
def create_loraplus_optimizer(
|
|
||||||
opt_model,
|
|
||||||
optimizer_cls,
|
|
||||||
optimizer_kwargs,
|
|
||||||
loraplus_lr_ratio,
|
|
||||||
loraplus_lr_embedding=None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
opt_model (torch.nn.Module): The model for which the optimizer is being created.
|
|
||||||
optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
|
|
||||||
optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
|
|
||||||
loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
|
|
||||||
loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
|
|
||||||
"""
|
|
||||||
|
|
||||||
assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
|
|
||||||
|
|
||||||
if loraplus_lr_embedding is None:
|
|
||||||
loraplus_lr_embedding = 1e-6
|
|
||||||
|
|
||||||
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
|
|
||||||
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
|
||||||
param_groups = {
|
|
||||||
"groupA": {},
|
|
||||||
"groupB": {},
|
|
||||||
"groupB_no_decay": {},
|
|
||||||
"embedding": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, param in opt_model.named_parameters():
|
|
||||||
if not param.requires_grad:
|
|
||||||
continue
|
|
||||||
|
|
||||||
module = get_module(name, opt_model)
|
|
||||||
if isinstance(module, lora.Embedding):
|
|
||||||
param_groups["embedding"][name] = param
|
|
||||||
elif "lora_B" in name or param.ndim == 1:
|
|
||||||
if name in decay_parameters:
|
|
||||||
param_groups["groupB"][name] = param
|
|
||||||
else:
|
|
||||||
param_groups["groupB_no_decay"][name] = param
|
|
||||||
else:
|
|
||||||
param_groups["groupA"][name] = param
|
|
||||||
|
|
||||||
assigned_param_groups = ""
|
|
||||||
for group, group_params in param_groups.items():
|
|
||||||
assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
|
|
||||||
LOG.info(assigned_param_groups)
|
|
||||||
|
|
||||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
|
||||||
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
|
|
||||||
|
|
||||||
optimizer_grouped_parameters = [
|
|
||||||
{
|
|
||||||
"params": list(param_groups["groupA"].values()),
|
|
||||||
"weight_decay": weight_decay,
|
|
||||||
"lr": lr,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": list(param_groups["embedding"].values()),
|
|
||||||
"weight_decay": weight_decay,
|
|
||||||
"lr": loraplus_lr_embedding,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": list(param_groups["groupB"].values()),
|
|
||||||
"weight_decay": weight_decay,
|
|
||||||
"lr": lr * loraplus_lr_ratio,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": list(param_groups["groupB_no_decay"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": lr * loraplus_lr_ratio,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
|
||||||
if optimizer_cls.__name__ == "Adam8bit":
|
|
||||||
import bitsandbytes
|
|
||||||
|
|
||||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
|
||||||
|
|
||||||
skipped = 0
|
|
||||||
for module in opt_model.modules():
|
|
||||||
if isinstance(module, nn.Embedding):
|
|
||||||
skipped += sum(
|
|
||||||
{p.data_ptr(): p.numel() for p in module.parameters()}.values()
|
|
||||||
)
|
|
||||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
|
||||||
manager.register_module_override(module, "weight", {"optim_bits": 32})
|
|
||||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
|
||||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
|
||||||
|
|
||||||
return optimizer
|
|
||||||
229
src/axolotl/monkeypatch/attention/mllama.py
Normal file
229
src/axolotl/monkeypatch/attention/mllama.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
"""
|
||||||
|
Monkeypatch for Vision Llama for FA2 support
|
||||||
|
"""
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_func
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||||
|
from transformers.models.mllama.configuration_mllama import MllamaTextConfig
|
||||||
|
from transformers.models.mllama.modeling_mllama import (
|
||||||
|
MllamaTextCrossAttention,
|
||||||
|
MllamaTextSelfAttention,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
repeat_kv,
|
||||||
|
)
|
||||||
|
from transformers.utils import is_flash_attn_greater_or_equal_2_10
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaTextCrossFlashAttention2(MllamaTextCrossAttention):
|
||||||
|
"""
|
||||||
|
Mllama flash cross-attention module. This module inherits from `MllamaTextCrossAttention` and
|
||||||
|
implements the forward pass using Flash Attention for improved performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# Check if flash attention version is greater or equal to 2.1
|
||||||
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
cross_attention_states: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
attention_mask: Optional[ # pylint: disable=unused-argument
|
||||||
|
torch.Tensor
|
||||||
|
] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False, # pylint: disable=unused-argument
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
query_states = self.q_norm(query_states)
|
||||||
|
|
||||||
|
if cross_attention_states is not None:
|
||||||
|
key_states = self.k_proj(cross_attention_states)
|
||||||
|
value_states = self.v_proj(cross_attention_states)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, -1, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, -1, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
key_states = self.k_norm(key_states)
|
||||||
|
if past_key_value is not None:
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
self.layer_idx,
|
||||||
|
{"cache_position": cache_position},
|
||||||
|
)
|
||||||
|
elif cache_position[0] != 0:
|
||||||
|
key_states, value_states = (
|
||||||
|
past_key_value.key_cache[self.layer_idx],
|
||||||
|
past_key_value.value_cache[self.layer_idx],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Transpose to get the expected layout for flash attention
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
# Apply Flash Attention
|
||||||
|
dropout_rate = self.dropout if self.training else 0.0
|
||||||
|
output = flash_attn_func(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
dropout_p=dropout_rate,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=False,
|
||||||
|
return_attn_probs=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = output.contiguous().view(bsz, q_len, -1)
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaTextSelfFlashAttention2(MllamaTextSelfAttention):
|
||||||
|
"""
|
||||||
|
Mllama flash self-attention module. This module inherits from `MllamaTextSelfAttention` and
|
||||||
|
implements the forward pass using Flash Attention for improved performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: MllamaTextConfig, layer_idx: int, *args, **kwargs):
|
||||||
|
super().__init__(config, layer_idx, *args, **kwargs)
|
||||||
|
|
||||||
|
# Check if flash attention version is greater or equal to 2.1
|
||||||
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False, # pylint: disable=unused-argument
|
||||||
|
past_key_value=None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
output_attentions = False
|
||||||
|
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
# Flash attention requires the input to have the shape
|
||||||
|
# batch_size x seq_length x num_heads x head_dim
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin
|
||||||
|
)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
key_states, value_states, self.layer_idx, cache_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
# Transpose to get the expected layout for flash attention
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
dropout_rate = self.dropout if self.training else 0.0
|
||||||
|
|
||||||
|
# Handle potential silent casting to float32
|
||||||
|
input_dtype = query_states.dtype
|
||||||
|
if input_dtype == torch.float32:
|
||||||
|
if torch.is_autocast_enabled():
|
||||||
|
target_dtype = torch.get_autocast_gpu_dtype()
|
||||||
|
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
|
target_dtype = (
|
||||||
|
self.config._pre_quantization_dtype # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
target_dtype = self.q_proj.weight.dtype
|
||||||
|
|
||||||
|
query_states = query_states.to(target_dtype)
|
||||||
|
key_states = key_states.to(target_dtype)
|
||||||
|
value_states = value_states.to(target_dtype)
|
||||||
|
|
||||||
|
attn_output = _flash_attention_forward(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attention_mask,
|
||||||
|
q_len,
|
||||||
|
dropout=dropout_rate,
|
||||||
|
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||||
|
is_causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def patch_mllama():
|
||||||
|
from transformers.models.mllama.modeling_mllama import (
|
||||||
|
MLLAMA_TEXT_ATTENTION_CLASSES,
|
||||||
|
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES,
|
||||||
|
MLLAMA_VISION_ATTENTION_CLASSES,
|
||||||
|
MllamaPreTrainedModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
MllamaPreTrainedModel._supports_flash_attn_2 = ( # pylint: disable=protected-access
|
||||||
|
True
|
||||||
|
)
|
||||||
|
MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
|
||||||
|
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
|
||||||
|
"flash_attention_2"
|
||||||
|
] = MllamaTextCrossFlashAttention2
|
||||||
|
# fallback to SDPA
|
||||||
|
MLLAMA_VISION_ATTENTION_CLASSES[
|
||||||
|
"flash_attention_2"
|
||||||
|
] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
|
||||||
@@ -10,6 +10,7 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
|||||||
from axolotl.monkeypatch.utils import get_unpad_data
|
from axolotl.monkeypatch.utils import get_unpad_data
|
||||||
|
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||||
|
"mllama_text_model",
|
||||||
"llama",
|
"llama",
|
||||||
"mistral",
|
"mistral",
|
||||||
"mixtral",
|
"mixtral",
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
# This code is based off the following work:
|
# This code is based off the following work:
|
||||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
""" PyTorch StableLM Epoch model. """
|
""" PyTorch StableLM Epoch model. """
|
||||||
import importlib
|
import importlib
|
||||||
import math
|
import math
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
|||||||
LOG = logging.getLogger("axolotl.prompt_strategies")
|
LOG = logging.getLogger("axolotl.prompt_strategies")
|
||||||
|
|
||||||
|
|
||||||
def load(strategy, tokenizer, cfg, ds_cfg):
|
def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
||||||
try:
|
try:
|
||||||
load_fn = "load"
|
load_fn = "load"
|
||||||
if strategy.split(".")[-1].startswith("load_"):
|
if strategy.split(".")[-1].startswith("load_"):
|
||||||
@@ -24,6 +24,8 @@ def load(strategy, tokenizer, cfg, ds_cfg):
|
|||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
if "ds_cfg" in sig.parameters:
|
if "ds_cfg" in sig.parameters:
|
||||||
load_kwargs["ds_cfg"] = ds_cfg
|
load_kwargs["ds_cfg"] = ds_cfg
|
||||||
|
if "processor" in sig.parameters:
|
||||||
|
load_kwargs["processor"] = processor
|
||||||
return func(tokenizer, cfg, **load_kwargs)
|
return func(tokenizer, cfg, **load_kwargs)
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ HF Chat Templates prompt strategy
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from transformers import ProcessorMixin
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
@@ -20,6 +22,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
processor=None,
|
||||||
chat_template=None,
|
chat_template=None,
|
||||||
max_length=2048,
|
max_length=2048,
|
||||||
message_field_role: str = "from",
|
message_field_role: str = "from",
|
||||||
@@ -44,11 +47,12 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
self.message_field_training = message_field_training
|
self.message_field_training = message_field_training
|
||||||
self.message_field_training_detail = message_field_training_detail
|
self.message_field_training_detail = message_field_training_detail
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
self.processor: ProcessorMixin = processor
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.drop_system_message = drop_system_message
|
self.drop_system_message = drop_system_message
|
||||||
|
|
||||||
def build_prompt(self, conversation, add_generation_prompt=False):
|
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
|
||||||
turns = [
|
turns = [
|
||||||
{
|
{
|
||||||
"role": self.roles[t[self.message_field_role]],
|
"role": self.roles[t[self.message_field_role]],
|
||||||
@@ -61,6 +65,28 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
if self.drop_system_message and turns[0]["role"] == "system":
|
if self.drop_system_message and turns[0]["role"] == "system":
|
||||||
turns = turns[1:]
|
turns = turns[1:]
|
||||||
|
|
||||||
|
if self.processor:
|
||||||
|
text = self.processor.apply_chat_template(
|
||||||
|
turns,
|
||||||
|
chat_template=self.chat_template,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=add_generation_prompt,
|
||||||
|
)
|
||||||
|
batch = self.processor(
|
||||||
|
text=text,
|
||||||
|
images=images,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.max_length,
|
||||||
|
)
|
||||||
|
# workaround since processor works in batches instead of single examples
|
||||||
|
for k, val in batch.items():
|
||||||
|
if k in ["pixel_values"]:
|
||||||
|
batch[k] = val.tolist()
|
||||||
|
else:
|
||||||
|
batch[k] = val.squeeze().tolist()
|
||||||
|
return batch
|
||||||
|
|
||||||
return self.tokenizer.apply_chat_template(
|
return self.tokenizer.apply_chat_template(
|
||||||
turns,
|
turns,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
@@ -191,6 +217,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
||||||
self.roles_to_train = roles_to_train if roles_to_train is not None else []
|
self.roles_to_train = roles_to_train if roles_to_train is not None else []
|
||||||
self.train_on_eos = train_on_eos
|
self.train_on_eos = train_on_eos
|
||||||
|
self.images = "images"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def messages(self):
|
def messages(self):
|
||||||
@@ -209,10 +236,21 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
and not self.prompter.message_field_training_detail
|
and not self.prompter.message_field_training_detail
|
||||||
):
|
):
|
||||||
turns = self.get_conversation_thread(prompt)
|
turns = self.get_conversation_thread(prompt)
|
||||||
|
images = self.get_images(prompt)
|
||||||
prompt_ids = self.prompter.build_prompt(
|
prompt_ids = self.prompter.build_prompt(
|
||||||
turns[:-1], add_generation_prompt=True
|
turns[:-1],
|
||||||
|
add_generation_prompt=True,
|
||||||
|
images=images,
|
||||||
)
|
)
|
||||||
input_ids = self.prompter.build_prompt(turns)
|
tokenized_res = self.prompter.build_prompt(turns, images=images)
|
||||||
|
tokenized_prompt = {}
|
||||||
|
if isinstance(tokenized_res, list):
|
||||||
|
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
|
||||||
|
tokenized_prompt["input_ids"] = input_ids
|
||||||
|
tokenized_prompt["attention_mask"] = [1] * len(input_ids)
|
||||||
|
else:
|
||||||
|
input_ids = tokenized_res["input_ids"]
|
||||||
|
tokenized_prompt = tokenized_res
|
||||||
|
|
||||||
if not self.train_on_inputs:
|
if not self.train_on_inputs:
|
||||||
user_prompt_len = len(prompt_ids)
|
user_prompt_len = len(prompt_ids)
|
||||||
@@ -220,17 +258,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
else:
|
else:
|
||||||
labels = input_ids
|
labels = input_ids
|
||||||
|
|
||||||
tokenized_prompt = {
|
tokenized_prompt["labels"] = labels
|
||||||
"input_ids": input_ids,
|
|
||||||
"labels": labels,
|
|
||||||
"attention_mask": [1] * len(input_ids),
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenized_prompt
|
return tokenized_prompt
|
||||||
LOG.info(self.roles_to_train)
|
|
||||||
LOG.info(self.train_on_eos)
|
|
||||||
LOG.info(self.prompter.message_field_training)
|
|
||||||
LOG.info(self.prompter.message_field_training_detail)
|
|
||||||
|
|
||||||
turns = prompt[self.messages]
|
turns = prompt[self.messages]
|
||||||
input_ids = self.prompter.build_prompt(turns)
|
input_ids = self.prompter.build_prompt(turns)
|
||||||
@@ -368,15 +398,18 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
def get_conversation_thread(self, prompt):
|
def get_conversation_thread(self, prompt):
|
||||||
return prompt[self.messages]
|
return prompt[self.messages]
|
||||||
|
|
||||||
|
def get_images(self, prompt):
|
||||||
|
return prompt.get(self.images, None)
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
||||||
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
|
||||||
ds_cfg = ds_cfg or {}
|
ds_cfg = ds_cfg or {}
|
||||||
|
|
||||||
prompter_params = {
|
prompter_params = {
|
||||||
"tokenizer": tokenizer,
|
"tokenizer": tokenizer,
|
||||||
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
|
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
|
||||||
"message_field_role": ds_cfg.get("message_field_role", "from"),
|
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||||
"message_field_content": ds_cfg.get("message_field_content", "value"),
|
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||||
"message_field_training_detail": ds_cfg.get(
|
"message_field_training_detail": ds_cfg.get(
|
||||||
"message_field_training_detail",
|
"message_field_training_detail",
|
||||||
@@ -386,6 +419,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||||
"max_length": cfg.sequence_len + 1,
|
"max_length": cfg.sequence_len + 1,
|
||||||
|
"processor": processor,
|
||||||
}
|
}
|
||||||
|
|
||||||
strategy_params = {
|
strategy_params = {
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from axolotl.core.tokenizer_utils import fix_untrained_tokens
|
|||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -69,6 +69,9 @@ def train(
|
|||||||
main_process_only=True,
|
main_process_only=True,
|
||||||
)
|
)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
processor = None
|
||||||
|
if cfg.is_multimodal:
|
||||||
|
processor = load_processor(cfg, tokenizer)
|
||||||
|
|
||||||
train_dataset = dataset_meta.train_dataset
|
train_dataset = dataset_meta.train_dataset
|
||||||
eval_dataset = dataset_meta.eval_dataset
|
eval_dataset = dataset_meta.eval_dataset
|
||||||
@@ -96,7 +99,9 @@ def train(
|
|||||||
LOG.debug(msg)
|
LOG.debug(msg)
|
||||||
# we wait unitl the last possible moment to setup Accelerator
|
# we wait unitl the last possible moment to setup Accelerator
|
||||||
Accelerator()
|
Accelerator()
|
||||||
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
model, peft_config = load_model(
|
||||||
|
cfg, tokenizer, processor=processor, inference=cli_args.inference
|
||||||
|
)
|
||||||
model.generation_config.do_sample = True
|
model.generation_config.do_sample = True
|
||||||
|
|
||||||
model_ref = None
|
model_ref = None
|
||||||
@@ -122,6 +127,7 @@ def train(
|
|||||||
eval_dataset,
|
eval_dataset,
|
||||||
(model, model_ref, peft_config),
|
(model, model_ref, peft_config),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
processor,
|
||||||
total_num_steps,
|
total_num_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
10
src/axolotl/utils/collators/__init__.py
Normal file
10
src/axolotl/utils/collators/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
shared axolotl collators for multipack, mamba, multimodal
|
||||||
|
"""
|
||||||
|
from .batching import ( # noqa: F401
|
||||||
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
|
DataCollatorForSeq2Seq,
|
||||||
|
PretrainingBatchSamplerDataCollatorForSeq2Seq,
|
||||||
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
|
)
|
||||||
|
from .mamba import MambaDataCollator # noqa: F401
|
||||||
@@ -1,17 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Optional, Sequence, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
IGNORE_INDEX = -100
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForSeq2Seq:
|
class DataCollatorForSeq2Seq:
|
||||||
@@ -183,34 +180,6 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MambaDataCollator:
|
|
||||||
"""
|
|
||||||
Collator for State Space Models (Mamba)
|
|
||||||
"""
|
|
||||||
|
|
||||||
tokenizer: transformers.PreTrainedTokenizer
|
|
||||||
|
|
||||||
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
|
||||||
input_ids, labels = tuple(
|
|
||||||
[torch.LongTensor(instance[key]) for instance in instances]
|
|
||||||
for key in ("input_ids", "labels")
|
|
||||||
)
|
|
||||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
|
||||||
input_ids,
|
|
||||||
batch_first=True,
|
|
||||||
padding_value=self.tokenizer.pad_token_id,
|
|
||||||
)
|
|
||||||
labels = torch.nn.utils.rnn.pad_sequence(
|
|
||||||
labels, batch_first=True, padding_value=IGNORE_INDEX
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"labels": labels,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
"""
|
"""
|
||||||
4
src/axolotl/utils/collators/core.py
Normal file
4
src/axolotl/utils/collators/core.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
"""
|
||||||
|
basic shared collator constants
|
||||||
|
"""
|
||||||
|
IGNORE_INDEX = -100
|
||||||
38
src/axolotl/utils/collators/mamba.py
Normal file
38
src/axolotl/utils/collators/mamba.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""
|
||||||
|
collators for Mamba
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
from axolotl.utils.collators.core import IGNORE_INDEX
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MambaDataCollator:
|
||||||
|
"""
|
||||||
|
Collator for State Space Models (Mamba)
|
||||||
|
"""
|
||||||
|
|
||||||
|
tokenizer: transformers.PreTrainedTokenizer
|
||||||
|
|
||||||
|
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
||||||
|
input_ids, labels = tuple(
|
||||||
|
[torch.LongTensor(instance[key]) for instance in instances]
|
||||||
|
for key in ("input_ids", "labels")
|
||||||
|
)
|
||||||
|
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
input_ids,
|
||||||
|
batch_first=True,
|
||||||
|
padding_value=self.tokenizer.pad_token_id,
|
||||||
|
)
|
||||||
|
labels = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
labels, batch_first=True, padding_value=IGNORE_INDEX
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"labels": labels,
|
||||||
|
}
|
||||||
179
src/axolotl/utils/collators/mm_chat.py
Normal file
179
src/axolotl/utils/collators/mm_chat.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
"""
|
||||||
|
Collators for multi-modal chat messages and packing
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from transformers import PreTrainedTokenizerBase, ProcessorMixin
|
||||||
|
from transformers.data.data_collator import DataCollatorMixin
|
||||||
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MultiModalChatDataCollator(DataCollatorMixin):
|
||||||
|
"""
|
||||||
|
Collator for multi-modal chat messages
|
||||||
|
"""
|
||||||
|
|
||||||
|
tokenizer: PreTrainedTokenizerBase
|
||||||
|
processor: ProcessorMixin
|
||||||
|
return_tensors: str = "pt"
|
||||||
|
chat_template: Optional[str] = None
|
||||||
|
packing: bool = False
|
||||||
|
sequence_length: Optional[int] = None
|
||||||
|
max_images: int = -1
|
||||||
|
padding: Union[bool, str, PaddingStrategy] = True
|
||||||
|
pad_to_multiple_of: Optional[int] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.packing:
|
||||||
|
raise ValueError("Packing is currently not supported.")
|
||||||
|
|
||||||
|
def torch_call(
|
||||||
|
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
# Handle dict or lists with proper padding and conversion to tensor.
|
||||||
|
if self.packing:
|
||||||
|
return self.__class__.process_rows_packing(
|
||||||
|
examples,
|
||||||
|
self.processor,
|
||||||
|
self.chat_template,
|
||||||
|
self.max_images,
|
||||||
|
self.sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.__class__.process_rows(
|
||||||
|
examples, self.processor, self.chat_template, self.max_images
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_rows_packing(
|
||||||
|
examples,
|
||||||
|
processor,
|
||||||
|
chat_template,
|
||||||
|
max_images,
|
||||||
|
sequence_length,
|
||||||
|
length_only=False,
|
||||||
|
):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Perform sample packing within a batch
|
||||||
|
|
||||||
|
if processor.tokenizer.sep_token is None:
|
||||||
|
sep_token = "[SEP]"
|
||||||
|
processor.tokenizer.add_tokens([sep_token])
|
||||||
|
processor.tokenizer.sep_token = sep_token
|
||||||
|
sep_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||||
|
processor.tokenizer.sep_token
|
||||||
|
)
|
||||||
|
pad_token_id = processor.tokenizer.pad_token_id
|
||||||
|
|
||||||
|
texts = [
|
||||||
|
processor.apply_chat_template(
|
||||||
|
example["messages"], chat_template=chat_template, tokenize=False
|
||||||
|
)
|
||||||
|
for example in examples
|
||||||
|
]
|
||||||
|
images = [example["images"] for example in examples]
|
||||||
|
|
||||||
|
if max_images > 0:
|
||||||
|
images = [img_batch[:max_images] for img_batch in images]
|
||||||
|
|
||||||
|
batch = processor(text=texts, images=images, padding=False)
|
||||||
|
|
||||||
|
n_sequence = len(examples)
|
||||||
|
n_seq_in_batch = 0
|
||||||
|
pack_len = 0
|
||||||
|
features_pack = {}
|
||||||
|
packed = {}
|
||||||
|
features = list[batch.keys()]
|
||||||
|
for feature in features:
|
||||||
|
features_pack[feature] = []
|
||||||
|
packed[feature] = []
|
||||||
|
features.remove("input_ids")
|
||||||
|
|
||||||
|
for seq_in_batch_id in range(n_sequence):
|
||||||
|
next_seq_len = len(batch["input_ids"][seq_in_batch_id])
|
||||||
|
if not pack_len + next_seq_len + 1 < sequence_length:
|
||||||
|
n_seq_in_batch += 1
|
||||||
|
pack_len += next_seq_len + 1
|
||||||
|
features_pack["input_ids"] += batch["input_ids"][seq_in_batch_id] + [
|
||||||
|
sep_token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
"""
|
||||||
|
Do something with attention mask and cross-attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
for feature in features:
|
||||||
|
features_pack[feature] += batch[feature][seq_in_batch_id]
|
||||||
|
|
||||||
|
else:
|
||||||
|
for _ in range(sequence_length - pack_len):
|
||||||
|
features_pack["input_ids"] += [pad_token_id]
|
||||||
|
|
||||||
|
packed["input_ids"].append(
|
||||||
|
torch.tensor(features_pack["input_ids"].copy())
|
||||||
|
)
|
||||||
|
|
||||||
|
for feature in features:
|
||||||
|
packed[feature].append(torch.tensor(features_pack[feature].copy()))
|
||||||
|
features_pack[feature] = []
|
||||||
|
|
||||||
|
pack_len = 0
|
||||||
|
|
||||||
|
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||||
|
processor.image_token
|
||||||
|
)
|
||||||
|
labels = [pack.clone() for pack in packed["input_ids"]]
|
||||||
|
for label_id, label in enumerate(labels):
|
||||||
|
labels[label_id][label == processor.tokenizer.pad_token_id] = -100 #
|
||||||
|
# Ignore the image token index in the loss computation (model specific)
|
||||||
|
|
||||||
|
labels[label_id][label == image_token_id] = -100
|
||||||
|
packed["labels"] = labels
|
||||||
|
|
||||||
|
if length_only:
|
||||||
|
return {
|
||||||
|
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
|
||||||
|
}
|
||||||
|
return packed
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_rows(examples, processor, chat_template, max_images, length_only=False):
|
||||||
|
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
||||||
|
# see also DataCollatorWithFlattening and DefaultDataCollator
|
||||||
|
|
||||||
|
# *** This is COPIED from the trl example sft_vlm.py code ***
|
||||||
|
# use this as a starting point
|
||||||
|
|
||||||
|
# Get the texts and images, and apply the chat template
|
||||||
|
texts = [
|
||||||
|
processor.apply_chat_template(
|
||||||
|
example["messages"], chat_template=chat_template, tokenize=False
|
||||||
|
)
|
||||||
|
for example in examples
|
||||||
|
]
|
||||||
|
images = [example["images"] for example in examples]
|
||||||
|
|
||||||
|
if max_images > 0:
|
||||||
|
images = [img_batch[:max_images] for img_batch in images]
|
||||||
|
|
||||||
|
# Tokenize the texts and process the images
|
||||||
|
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||||
|
|
||||||
|
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||||
|
labels = batch["input_ids"].clone()
|
||||||
|
labels[labels == processor.tokenizer.pad_token_id] = -100 #
|
||||||
|
# Ignore the image token index in the loss computation (model specific)
|
||||||
|
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||||
|
processor.image_token
|
||||||
|
)
|
||||||
|
labels[labels == image_token_id] = -100
|
||||||
|
batch["labels"] = labels
|
||||||
|
|
||||||
|
if length_only:
|
||||||
|
return {
|
||||||
|
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
|
||||||
|
}
|
||||||
|
return batch
|
||||||
@@ -121,15 +121,36 @@ def normalize_config(cfg):
|
|||||||
cfg.base_model_config = cfg.base_model
|
cfg.base_model_config = cfg.base_model
|
||||||
|
|
||||||
model_config = load_model_config(cfg)
|
model_config = load_model_config(cfg)
|
||||||
cfg.model_config_type = model_config.model_type
|
|
||||||
|
|
||||||
cfg.tokenizer_config = (
|
cfg.tokenizer_config = (
|
||||||
cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
|
cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg.is_multimodal = (
|
||||||
|
hasattr(model_config, "model_type")
|
||||||
|
and model_config.model_type in ["llava", "mllama"]
|
||||||
|
or any(
|
||||||
|
multimodal_name in cfg.base_model.lower()
|
||||||
|
for multimodal_name in [
|
||||||
|
"pixtral",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
or cfg.is_multimodal
|
||||||
|
)
|
||||||
|
if cfg.is_multimodal:
|
||||||
|
cfg.processor_config = (
|
||||||
|
cfg.processor_config or cfg.base_model_config or cfg.base_model
|
||||||
|
)
|
||||||
|
model_config = model_config.text_config
|
||||||
|
|
||||||
|
cfg.model_config_type = model_config.model_type
|
||||||
|
|
||||||
# figure out if the model is llama
|
# figure out if the model is llama
|
||||||
cfg.is_llama_derived_model = (
|
cfg.is_llama_derived_model = (
|
||||||
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
(
|
||||||
|
hasattr(model_config, "model_type")
|
||||||
|
and model_config.model_type == ["llama", "mllama_text_model"]
|
||||||
|
)
|
||||||
or cfg.is_llama_derived_model
|
or cfg.is_llama_derived_model
|
||||||
or "llama" in cfg.base_model.lower()
|
or "llama" in cfg.base_model.lower()
|
||||||
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
|
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
|
||||||
|
|||||||
@@ -188,6 +188,7 @@ class ChatTemplate(str, Enum):
|
|||||||
gemma = "gemma" # pylint: disable=invalid-name
|
gemma = "gemma" # pylint: disable=invalid-name
|
||||||
cohere = "cohere" # pylint: disable=invalid-name
|
cohere = "cohere" # pylint: disable=invalid-name
|
||||||
llama3 = "llama3" # pylint: disable=invalid-name
|
llama3 = "llama3" # pylint: disable=invalid-name
|
||||||
|
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
||||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||||
phi_35 = "phi_35" # pylint: disable=invalid-name
|
phi_35 = "phi_35" # pylint: disable=invalid-name
|
||||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||||
@@ -228,11 +229,12 @@ class LoraConfig(BaseModel):
|
|||||||
lora_r: Optional[int] = None
|
lora_r: Optional[int] = None
|
||||||
lora_alpha: Optional[int] = None
|
lora_alpha: Optional[int] = None
|
||||||
lora_fan_in_fan_out: Optional[bool] = None
|
lora_fan_in_fan_out: Optional[bool] = None
|
||||||
lora_target_modules: Optional[List[str]] = None
|
lora_target_modules: Optional[Union[str, List[str]]] = None
|
||||||
lora_target_linear: Optional[bool] = None
|
lora_target_linear: Optional[bool] = None
|
||||||
lora_modules_to_save: Optional[List[str]] = None
|
lora_modules_to_save: Optional[List[str]] = None
|
||||||
lora_dropout: Optional[float] = 0.0
|
lora_dropout: Optional[float] = 0.0
|
||||||
peft_layers_to_transform: Optional[List[int]] = None
|
peft_layers_to_transform: Optional[List[int]] = None
|
||||||
|
peft_layers_pattern: Optional[List[str]] = None
|
||||||
peft: Optional[PeftConfig] = None
|
peft: Optional[PeftConfig] = None
|
||||||
peft_use_dora: Optional[bool] = None
|
peft_use_dora: Optional[bool] = None
|
||||||
peft_use_rslora: Optional[bool] = None
|
peft_use_rslora: Optional[bool] = None
|
||||||
@@ -298,6 +300,13 @@ class LoraConfig(BaseModel):
|
|||||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@field_validator("loraplus_lr_embedding")
|
||||||
|
@classmethod
|
||||||
|
def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
|
||||||
|
if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
|
||||||
|
loraplus_lr_embedding = float(loraplus_lr_embedding)
|
||||||
|
return loraplus_lr_embedding
|
||||||
|
|
||||||
|
|
||||||
class ReLoRAConfig(BaseModel):
|
class ReLoRAConfig(BaseModel):
|
||||||
"""ReLoRA configuration subset"""
|
"""ReLoRA configuration subset"""
|
||||||
@@ -321,6 +330,9 @@ class ModelInputConfig(BaseModel):
|
|||||||
tokenizer_type: Optional[str] = Field(
|
tokenizer_type: Optional[str] = Field(
|
||||||
default=None, metadata={"help": "transformers tokenizer class"}
|
default=None, metadata={"help": "transformers tokenizer class"}
|
||||||
)
|
)
|
||||||
|
processor_type: Optional[str] = Field(
|
||||||
|
default=None, metadata={"help": "transformers processor class"}
|
||||||
|
)
|
||||||
trust_remote_code: Optional[bool] = None
|
trust_remote_code: Optional[bool] = None
|
||||||
|
|
||||||
model_kwargs: Optional[Dict[str, Any]] = None
|
model_kwargs: Optional[Dict[str, Any]] = None
|
||||||
@@ -523,6 +535,7 @@ class AxolotlInputConfig(
|
|||||||
dataset_prepared_path: Optional[str] = None
|
dataset_prepared_path: Optional[str] = None
|
||||||
dataset_shard_num: Optional[int] = None
|
dataset_shard_num: Optional[int] = None
|
||||||
dataset_shard_idx: Optional[int] = None
|
dataset_shard_idx: Optional[int] = None
|
||||||
|
skip_prepare_dataset: Optional[bool] = False
|
||||||
|
|
||||||
pretraining_dataset: Optional[ # type: ignore
|
pretraining_dataset: Optional[ # type: ignore
|
||||||
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
|
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
|
||||||
@@ -990,6 +1003,18 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_mm_prepare(cls, data):
|
||||||
|
if data.get("skip_prepare_dataset"):
|
||||||
|
if data.get("remove_unused_columns") is None:
|
||||||
|
LOG.info(
|
||||||
|
"setting `remove_unused_columns: false` for skip_prepare_dataset"
|
||||||
|
)
|
||||||
|
data["remove_unused_columns"] = False
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_warmup(cls, data):
|
def check_warmup(cls, data):
|
||||||
@@ -1017,12 +1042,20 @@ class AxolotlInputConfig(
|
|||||||
return neftune_noise_alpha
|
return neftune_noise_alpha
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check(self):
|
def check_rl_beta(self):
|
||||||
if self.dpo_beta and not self.rl_beta:
|
if self.dpo_beta and not self.rl_beta:
|
||||||
self.rl_beta = self.dpo_beta
|
self.rl_beta = self.dpo_beta
|
||||||
del self.dpo_beta
|
del self.dpo_beta
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_simpo_warmup(self):
|
||||||
|
if self.rl == "simpo" and self.warmup_ratio:
|
||||||
|
raise ValueError(
|
||||||
|
"warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_frozen(cls, data):
|
def check_frozen(cls, data):
|
||||||
@@ -1037,6 +1070,15 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_peft_layers_pattern(cls, data):
|
||||||
|
if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"):
|
||||||
|
raise ValueError(
|
||||||
|
"peft_layers_pattern requires peft_layers_to_transform to be set"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_fft_possible_bad_config(self):
|
def check_fft_possible_bad_config(self):
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -51,20 +51,31 @@ from axolotl.utils.trainer import (
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataset(cfg, tokenizer):
|
def prepare_dataset(cfg, tokenizer, processor=None):
|
||||||
prompters = []
|
prompters = []
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
with zero_first(is_local_main_process()):
|
with zero_first(is_local_main_process()):
|
||||||
if cfg.test_datasets:
|
if cfg.test_datasets:
|
||||||
train_dataset, _, prompters = load_prepare_datasets(
|
train_dataset, _, prompters = load_prepare_datasets(
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
|
tokenizer,
|
||||||
|
cfg,
|
||||||
|
DEFAULT_DATASET_PREPARED_PATH,
|
||||||
|
split="train",
|
||||||
|
processor=processor,
|
||||||
)
|
)
|
||||||
_, eval_dataset, _ = load_prepare_datasets(
|
_, eval_dataset, _ = load_prepare_datasets(
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
|
tokenizer,
|
||||||
|
cfg,
|
||||||
|
DEFAULT_DATASET_PREPARED_PATH,
|
||||||
|
split="test",
|
||||||
|
processor=processor,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
tokenizer,
|
||||||
|
cfg,
|
||||||
|
DEFAULT_DATASET_PREPARED_PATH,
|
||||||
|
processor=processor,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
path = cfg.pretraining_dataset
|
path = cfg.pretraining_dataset
|
||||||
@@ -123,6 +134,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
cfg,
|
cfg,
|
||||||
default_dataset_prepared_path,
|
default_dataset_prepared_path,
|
||||||
split="train",
|
split="train",
|
||||||
|
processor=None,
|
||||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
) -> Tuple[DatasetDict, List[Prompter]]:
|
||||||
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
||||||
tokenizer_name = cfg.tokenizer_config
|
tokenizer_name = cfg.tokenizer_config
|
||||||
@@ -180,6 +192,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
cfg.dataset_prepared_path
|
cfg.dataset_prepared_path
|
||||||
and any(prepared_ds_path.glob("*"))
|
and any(prepared_ds_path.glob("*"))
|
||||||
and not cfg.is_preprocess
|
and not cfg.is_preprocess
|
||||||
|
and not cfg.skip_prepare_dataset
|
||||||
):
|
):
|
||||||
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
||||||
dataset = load_from_disk(str(prepared_ds_path))
|
dataset = load_from_disk(str(prepared_ds_path))
|
||||||
@@ -423,12 +436,16 @@ def load_tokenized_prepared_datasets(
|
|||||||
dataset=ds,
|
dataset=ds,
|
||||||
d_base_type=d_base_type,
|
d_base_type=d_base_type,
|
||||||
d_prompt_style=d_prompt_style,
|
d_prompt_style=d_prompt_style,
|
||||||
|
processor=processor,
|
||||||
)
|
)
|
||||||
datasets.append(dataset_wrapper)
|
datasets.append(dataset_wrapper)
|
||||||
prompters.append(dataset_prompter)
|
prompters.append(dataset_prompter)
|
||||||
|
|
||||||
LOG.info("merging datasets")
|
if len(datasets) == 1:
|
||||||
dataset = concatenate_datasets(datasets)
|
dataset = datasets[0]
|
||||||
|
else:
|
||||||
|
LOG.info("merging datasets")
|
||||||
|
dataset = concatenate_datasets(datasets)
|
||||||
|
|
||||||
if len(datasets) > 1:
|
if len(datasets) > 1:
|
||||||
if cfg.shuffle_merged_datasets:
|
if cfg.shuffle_merged_datasets:
|
||||||
@@ -437,9 +454,10 @@ def load_tokenized_prepared_datasets(
|
|||||||
else:
|
else:
|
||||||
LOG.debug("NOT shuffling merged datasets")
|
LOG.debug("NOT shuffling merged datasets")
|
||||||
|
|
||||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
if not cfg.skip_prepare_dataset:
|
||||||
|
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
||||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||||
dataset.save_to_disk(str(prepared_ds_path))
|
dataset.save_to_disk(str(prepared_ds_path))
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
@@ -478,9 +496,14 @@ def load_prepare_datasets(
|
|||||||
cfg,
|
cfg,
|
||||||
default_dataset_prepared_path,
|
default_dataset_prepared_path,
|
||||||
split="train",
|
split="train",
|
||||||
|
processor=None,
|
||||||
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
||||||
dataset, prompters = load_tokenized_prepared_datasets(
|
dataset, prompters = load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path, split=split
|
tokenizer,
|
||||||
|
cfg,
|
||||||
|
default_dataset_prepared_path,
|
||||||
|
split=split,
|
||||||
|
processor=processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
||||||
@@ -546,6 +569,7 @@ def get_dataset_wrapper(
|
|||||||
d_base_type,
|
d_base_type,
|
||||||
dataset,
|
dataset,
|
||||||
d_prompt_style=None,
|
d_prompt_style=None,
|
||||||
|
processor=None,
|
||||||
):
|
):
|
||||||
dataset_wrapper = None
|
dataset_wrapper = None
|
||||||
dataset_prompter = None
|
dataset_prompter = None
|
||||||
@@ -578,7 +602,11 @@ def get_dataset_wrapper(
|
|||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
)
|
)
|
||||||
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
elif cfg.skip_prepare_dataset:
|
||||||
|
dataset_wrapper = dataset
|
||||||
|
elif ds_strategy := load(
|
||||||
|
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
|
||||||
|
):
|
||||||
dataset_prompter = UnsupportedPrompter()
|
dataset_prompter = UnsupportedPrompter()
|
||||||
dataset_wrapper = TokenizedPromptDataset(
|
dataset_wrapper = TokenizedPromptDataset(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
|
|||||||
@@ -28,12 +28,17 @@ from transformers import ( # noqa: F401
|
|||||||
AddedToken,
|
AddedToken,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
|
AutoModelForVision2Seq,
|
||||||
|
AutoProcessor,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
AwqConfig,
|
AwqConfig,
|
||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
|
LlavaForConditionalGeneration,
|
||||||
|
MllamaForConditionalGeneration,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
|
ProcessorMixin,
|
||||||
)
|
)
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
@@ -80,6 +85,9 @@ def get_module_class_from_name(module, name):
|
|||||||
|
|
||||||
|
|
||||||
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
||||||
|
if cfg.is_multimodal:
|
||||||
|
model_config = model_config.text_config
|
||||||
|
|
||||||
quant_config_exists = (
|
quant_config_exists = (
|
||||||
hasattr(model_config, "quantization_config")
|
hasattr(model_config, "quantization_config")
|
||||||
and model_config.quantization_config
|
and model_config.quantization_config
|
||||||
@@ -299,11 +307,31 @@ def load_tokenizer(cfg):
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||||
|
processor_kwargs: Dict[str, Any] = {} # do we actually need this?
|
||||||
|
|
||||||
|
processor_cls = AutoProcessor
|
||||||
|
if cfg.processor_type:
|
||||||
|
processor_cls = getattr(transformers, cfg.processor_type)
|
||||||
|
|
||||||
|
processor = processor_cls.from_pretrained(
|
||||||
|
cfg.processor_config,
|
||||||
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
**processor_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return processor
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
*,
|
||||||
|
processor: ProcessorMixin = None, # pylint: disable=unused-argument
|
||||||
inference: bool = False,
|
inference: bool = False,
|
||||||
reference_model: bool = False,
|
reference_model: bool = False,
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||||
"""
|
"""
|
||||||
Load a model for a given configuration and tokenizer.
|
Load a model for a given configuration and tokenizer.
|
||||||
@@ -319,12 +347,23 @@ def load_model(
|
|||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
plugin_manager.pre_model_load(cfg)
|
plugin_manager.pre_model_load(cfg)
|
||||||
|
|
||||||
|
if cfg.is_multimodal:
|
||||||
|
text_model_config = model_config.text_config
|
||||||
|
else:
|
||||||
|
text_model_config = model_config
|
||||||
|
|
||||||
# TODO refactor as a kwarg
|
# TODO refactor as a kwarg
|
||||||
load_in_8bit = cfg.load_in_8bit
|
load_in_8bit = cfg.load_in_8bit
|
||||||
|
|
||||||
if cfg.gradient_checkpointing == "unsloth":
|
if cfg.gradient_checkpointing == "unsloth":
|
||||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
||||||
|
|
||||||
|
if hasattr(model_config, "model_type") and model_config.model_type == "mllama":
|
||||||
|
if cfg.flash_attention:
|
||||||
|
from axolotl.monkeypatch.attention.mllama import patch_mllama
|
||||||
|
|
||||||
|
patch_mllama()
|
||||||
|
|
||||||
if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
|
if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
|
||||||
if cfg.flash_attention:
|
if cfg.flash_attention:
|
||||||
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
|
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
|
||||||
@@ -461,6 +500,19 @@ def load_model(
|
|||||||
max_memory = cfg.max_memory
|
max_memory = cfg.max_memory
|
||||||
device_map = cfg.device_map
|
device_map = cfg.device_map
|
||||||
|
|
||||||
|
AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||||
|
if cfg.is_multimodal:
|
||||||
|
if model_config.model_type == "llava":
|
||||||
|
AutoModelLoader = ( # pylint: disable=invalid-name
|
||||||
|
LlavaForConditionalGeneration
|
||||||
|
)
|
||||||
|
elif model_config.model_type == "mllama":
|
||||||
|
AutoModelLoader = ( # pylint: disable=invalid-name
|
||||||
|
MllamaForConditionalGeneration
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
AutoModelLoader = AutoModelForVision2Seq # pylint: disable=invalid-name
|
||||||
|
|
||||||
if cfg.gpu_memory_limit:
|
if cfg.gpu_memory_limit:
|
||||||
gpu_memory_limit = (
|
gpu_memory_limit = (
|
||||||
str(cfg.gpu_memory_limit) + "GiB"
|
str(cfg.gpu_memory_limit) + "GiB"
|
||||||
@@ -478,7 +530,7 @@ def load_model(
|
|||||||
from accelerate import infer_auto_device_map
|
from accelerate import infer_auto_device_map
|
||||||
|
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model_canvas = AutoModelForCausalLM.from_config(
|
model_canvas = AutoModelLoader.from_config(
|
||||||
model_config, trust_remote_code=cfg.trust_remote_code or False
|
model_config, trust_remote_code=cfg.trust_remote_code or False
|
||||||
)
|
)
|
||||||
model_canvas.tie_weights()
|
model_canvas.tie_weights()
|
||||||
@@ -633,6 +685,8 @@ def load_model(
|
|||||||
quantization_config = (
|
quantization_config = (
|
||||||
quantization_config or model_kwargs["quantization_config"]
|
quantization_config or model_kwargs["quantization_config"]
|
||||||
)
|
)
|
||||||
|
if cfg.is_multimodal:
|
||||||
|
model_config.text_config = text_model_config
|
||||||
model = load_sharded_model_quant(
|
model = load_sharded_model_quant(
|
||||||
base_model,
|
base_model,
|
||||||
model_config,
|
model_config,
|
||||||
@@ -651,7 +705,9 @@ def load_model(
|
|||||||
if "device_map" in model_kwargs:
|
if "device_map" in model_kwargs:
|
||||||
del model_kwargs["device_map"]
|
del model_kwargs["device_map"]
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
if cfg.is_multimodal:
|
||||||
|
model_config.text_config = text_model_config
|
||||||
|
model = AutoModelLoader.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -690,13 +746,17 @@ def load_model(
|
|||||||
and not cfg.trust_remote_code
|
and not cfg.trust_remote_code
|
||||||
):
|
):
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
if cfg.is_multimodal:
|
||||||
|
model_config.text_config = text_model_config
|
||||||
|
model = AutoModelLoader.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if cfg.is_multimodal:
|
||||||
|
model_config.text_config = text_model_config
|
||||||
model = getattr(transformers, model_type).from_pretrained(
|
model = getattr(transformers, model_type).from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
@@ -707,21 +767,23 @@ def load_model(
|
|||||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
||||||
# when training starts
|
# when training starts
|
||||||
if (
|
if (
|
||||||
hasattr(model_config, "max_seq_len")
|
hasattr(text_model_config, "max_seq_len")
|
||||||
and model_config.max_seq_len
|
and text_model_config.max_seq_len
|
||||||
and cfg.sequence_len > model_config.max_seq_len
|
and cfg.sequence_len > model_config.max_seq_len
|
||||||
):
|
):
|
||||||
model_config.max_seq_len = cfg.sequence_len
|
text_model_config.max_seq_len = cfg.sequence_len
|
||||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||||
elif (
|
elif (
|
||||||
hasattr(model_config, "max_sequence_length")
|
hasattr(text_model_config, "max_sequence_length")
|
||||||
and model_config.max_sequence_length
|
and text_model_config.max_sequence_length
|
||||||
and cfg.sequence_len > model_config.max_sequence_length
|
and cfg.sequence_len > text_model_config.max_sequence_length
|
||||||
):
|
):
|
||||||
model_config.max_sequence_length = cfg.sequence_len
|
text_model_config.max_sequence_length = cfg.sequence_len
|
||||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
if cfg.is_multimodal:
|
||||||
|
model_config.text_config = text_model_config
|
||||||
|
model = AutoModelLoader.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
@@ -734,7 +796,9 @@ def load_model(
|
|||||||
if "device_map" in model_kwargs:
|
if "device_map" in model_kwargs:
|
||||||
del model_kwargs["device_map"]
|
del model_kwargs["device_map"]
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
if cfg.is_multimodal:
|
||||||
|
model_config.text_config = text_model_config
|
||||||
|
model = AutoModelLoader.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
@@ -1016,12 +1080,17 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|||||||
|
|
||||||
from peft import LoraConfig, get_peft_model
|
from peft import LoraConfig, get_peft_model
|
||||||
|
|
||||||
lora_target_modules = list(cfg.lora_target_modules or [])
|
lora_target_modules = cfg.lora_target_modules or []
|
||||||
|
|
||||||
if cfg.lora_target_linear:
|
if cfg.lora_target_linear:
|
||||||
linear_names = find_all_linear_names(model)
|
linear_names = find_all_linear_names(model)
|
||||||
LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
|
LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
|
||||||
lora_target_modules = list(set(lora_target_modules + linear_names))
|
lora_target_modules_as_list = (
|
||||||
|
lora_target_modules
|
||||||
|
if isinstance(lora_target_modules, list)
|
||||||
|
else [lora_target_modules]
|
||||||
|
)
|
||||||
|
lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
|
||||||
|
|
||||||
lora_config_kwargs = {}
|
lora_config_kwargs = {}
|
||||||
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
|
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
|
||||||
@@ -1040,6 +1109,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|||||||
lora_alpha=cfg.lora_alpha,
|
lora_alpha=cfg.lora_alpha,
|
||||||
target_modules=lora_target_modules,
|
target_modules=lora_target_modules,
|
||||||
layers_to_transform=cfg.peft_layers_to_transform,
|
layers_to_transform=cfg.peft_layers_to_transform,
|
||||||
|
layers_pattern=cfg.peft_layers_pattern,
|
||||||
lora_dropout=cfg.lora_dropout,
|
lora_dropout=cfg.lora_dropout,
|
||||||
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
||||||
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
||||||
|
|||||||
@@ -306,7 +306,7 @@ def process_pretraining_datasets_for_packing(
|
|||||||
|
|
||||||
|
|
||||||
def calculate_total_num_steps(cfg, train_dataset, update=True):
|
def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||||
if not cfg.total_num_tokens:
|
if not cfg.total_num_tokens and not cfg.skip_prepare_dataset:
|
||||||
total_num_tokens = np.sum(
|
total_num_tokens = np.sum(
|
||||||
train_dataset.data.column("input_ids")
|
train_dataset.data.column("input_ids")
|
||||||
.to_pandas()
|
.to_pandas()
|
||||||
@@ -319,7 +319,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
|
|
||||||
skip_estimates = cfg.model_config_type == "mamba"
|
skip_estimates = cfg.model_config_type == "mamba"
|
||||||
|
|
||||||
if not skip_estimates and not cfg.total_supervised_tokens:
|
if (
|
||||||
|
not skip_estimates
|
||||||
|
and not cfg.total_supervised_tokens
|
||||||
|
and not cfg.skip_prepare_dataset
|
||||||
|
):
|
||||||
total_supervised_tokens = (
|
total_supervised_tokens = (
|
||||||
train_dataset.data.column("labels")
|
train_dataset.data.column("labels")
|
||||||
.to_pandas()
|
.to_pandas()
|
||||||
@@ -478,13 +482,15 @@ def prepare_opinionated_env(cfg):
|
|||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(
|
||||||
|
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
||||||
|
):
|
||||||
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
trainer_builder.peft_config = model[2]
|
trainer_builder.peft_config = model[2]
|
||||||
else:
|
else:
|
||||||
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer)
|
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||||
|
|
||||||
trainer_builder.train_dataset = train_dataset
|
trainer_builder.train_dataset = train_dataset
|
||||||
trainer_builder.eval_dataset = eval_dataset
|
trainer_builder.eval_dataset = eval_dataset
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
chat_templates("llama3"),
|
chat_template=chat_templates("llama3"),
|
||||||
message_field_role="role",
|
message_field_role="role",
|
||||||
message_field_content="content",
|
message_field_content="content",
|
||||||
roles={
|
roles={
|
||||||
@@ -113,7 +113,7 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
phi35_tokenizer,
|
phi35_tokenizer,
|
||||||
chat_templates("phi_35"),
|
chat_template=chat_templates("phi_35"),
|
||||||
message_field_role="role",
|
message_field_role="role",
|
||||||
message_field_content="content",
|
message_field_content="content",
|
||||||
roles={
|
roles={
|
||||||
@@ -171,7 +171,7 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
chat_templates("llama3"),
|
chat_template=chat_templates("llama3"),
|
||||||
message_field_role="role",
|
message_field_role="role",
|
||||||
message_field_content="content",
|
message_field_content="content",
|
||||||
message_field_training="training",
|
message_field_training="training",
|
||||||
@@ -227,8 +227,11 @@ class TestSharegptChatTemplateLlama3:
|
|||||||
|
|
||||||
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
|
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
|
||||||
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
|
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||||
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
train_on_eos="none",
|
train_on_eos="none",
|
||||||
@@ -277,8 +280,11 @@ class TestSharegptChatTemplateLlama3:
|
|||||||
|
|
||||||
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
|
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
|
||||||
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
|
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||||
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
train_on_eos="none",
|
train_on_eos="none",
|
||||||
@@ -327,8 +333,11 @@ class TestSharegptChatTemplateLlama3:
|
|||||||
|
|
||||||
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
|
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
|
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||||
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
train_on_eos="none",
|
train_on_eos="none",
|
||||||
|
|||||||
@@ -34,7 +34,9 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
|
def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing with train_on_inputs=True")
|
LOG.info("Testing with train_on_inputs=True")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||||
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=True,
|
train_on_inputs=True,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -77,7 +79,9 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
|
def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing with train_on_inputs=False")
|
LOG.info("Testing with train_on_inputs=False")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||||
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -118,7 +122,9 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
|
def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing roles_to_train with assistant only")
|
LOG.info("Testing roles_to_train with assistant only")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||||
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -144,7 +150,9 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
|
def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing roles_to_train with all roles")
|
LOG.info("Testing roles_to_train with all roles")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||||
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=True,
|
train_on_inputs=True,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -175,7 +183,9 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
|
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing with empty roles_to_train")
|
LOG.info("Testing with empty roles_to_train")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||||
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -194,7 +204,9 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
|
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing with train_on_eos='all'")
|
LOG.info("Testing with train_on_eos='all'")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||||
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -219,7 +231,9 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
|
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing with train_on_eos='turn'")
|
LOG.info("Testing with train_on_eos='turn'")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||||
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -267,7 +281,9 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
|
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing with train_on_eos='last'")
|
LOG.info("Testing with train_on_eos='last'")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||||
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -298,7 +314,9 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
|
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing with train_on_eos='none'")
|
LOG.info("Testing with train_on_eos='none'")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||||
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -324,7 +342,9 @@ class TestChatTemplateConfigurations:
|
|||||||
LOG.info("Testing with drop_system_message=True")
|
LOG.info("Testing with drop_system_message=True")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
|
llama3_tokenizer,
|
||||||
|
chat_template=chat_templates("llama3"),
|
||||||
|
drop_system_message=True,
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -350,7 +370,9 @@ class TestChatTemplateConfigurations:
|
|||||||
}
|
}
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
|
llama3_tokenizer,
|
||||||
|
chat_template=chat_templates("llama3"),
|
||||||
|
roles=custom_roles,
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -402,7 +424,7 @@ class TestChatTemplateConfigurations:
|
|||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
chat_templates("llama3"),
|
chat_template=chat_templates("llama3"),
|
||||||
message_field_training="train",
|
message_field_training="train",
|
||||||
message_field_training_detail="train_detail",
|
message_field_training_detail="train_detail",
|
||||||
),
|
),
|
||||||
|
|||||||
Reference in New Issue
Block a user