Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
7a7c56f018 fixes to support fsdp-qdora 2024-04-23 08:37:04 -04:00
24 changed files with 82 additions and 520 deletions

View File

@@ -32,11 +32,6 @@ jobs:
python_version: "3.11"
pytorch: 2.2.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v3

View File

@@ -30,11 +30,6 @@ jobs:
python_version: "3.11"
pytorch: 2.2.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -91,11 +86,6 @@ jobs:
python_version: "3.11"
pytorch: 2.2.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -29,11 +29,6 @@ jobs:
python_version: "3.11"
pytorch: 2.2.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -91,11 +86,6 @@ jobs:
python_version: "3.11"
pytorch: 2.2.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

1
.gitignore vendored
View File

@@ -133,7 +133,6 @@ venv/
ENV/
env.bak/
venv.bak/
venv3.10/
# Spyder project settings
.spyderproject

View File

@@ -138,7 +138,7 @@ test_datasets:
data_files:
- /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard', 'nca_pair'
# use RL training: 'dpo', 'ipo', 'kto_pair'
rl:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
@@ -227,12 +227,6 @@ lora_modules_to_save:
lora_fan_in_fan_out: false
# LoRA+ hyperparameters
# For more details about the following options, see:
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.
loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6.
peft:
# Configuration options for loftq initialization for LoRA
# https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
@@ -274,7 +268,6 @@ torch_compile_backend: # Optional[str]
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
gradient_accumulation_steps: 1
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
# Batch size per gpu = micro_batch_size * gradient_accumulation_steps
micro_batch_size: 2
eval_batch_size:
num_epochs: 4

View File

@@ -49,7 +49,7 @@ remove_unused_columns: false
chat_template: chatml
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned
type: chat_template.argilla
type: orpo.chat_template
```
#### Using local dataset files

View File

@@ -22,6 +22,7 @@ wandb_watch:
wandb_name:
wandb_log_model:
qlora_fsdp_alt_loader: true
adapter: lora
lora_model_dir:
lora_r: 8

View File

@@ -22,6 +22,7 @@ wandb_watch:
wandb_name:
wandb_log_model:
qlora_fsdp_alt_loader: true
adapter: lora
lora_model_dir:
lora_r: 8

View File

@@ -1,82 +0,0 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false
chat_template: chatml
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned
type: chat_template.argilla
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./mistral-qlora-orpo-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -39,6 +39,6 @@ s3fs
gcsfs
# adlfs
trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
trl==0.8.5
zstandard==0.22.0
fastcore

View File

@@ -33,7 +33,7 @@ fi
if [ "$JUPYTER_DISABLE" != "1" ]; then
# Run Jupyter Lab in the background
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
fi
# Execute the passed arguments (CMD)

View File

@@ -264,8 +264,8 @@ def do_inference_gradio(
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
temperature=cfg.get("gradio_temperature", 0.9),
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
@@ -300,13 +300,7 @@ def do_inference_gradio(
outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),
)
demo.queue().launch(show_api=False, share=True)
def choose_config(path: Path):
@@ -439,23 +433,6 @@ def load_rl_datasets(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
tokenizer = load_tokenizer(cfg)
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,

View File

@@ -30,7 +30,7 @@ from transformers import (
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
from trl import DPOTrainer, ORPOConfig, ORPOTrainer
from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer
@@ -43,7 +43,6 @@ from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
SaveModelOnTrainEndCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
@@ -213,10 +212,6 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
class AxolotlTrainer(Trainer):
@@ -352,8 +347,6 @@ class AxolotlTrainer(Trainer):
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
if self.args.curriculum_sampling:
return SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
def _get_eval_sampler(
@@ -889,14 +882,6 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
return callbacks
@@ -942,11 +927,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
):
callbacks.append(SaveBetterTransformerModelCallback())
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
callbacks.append(SaveModelOnTrainEndCallback())
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -1201,7 +1193,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
False if self.cfg.ddp else None
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
report_to = None
if self.cfg.use_wandb:
report_to = "wandb"
@@ -1429,8 +1420,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks.append(SaveModelOnTrainEndCallback())
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -1466,7 +1455,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
else:
training_args_kwargs["evaluation_strategy"] = "no"
if self.cfg.bf16 or self.cfg.bfloat16:
training_args_kwargs["bf16"] = True
@@ -1525,10 +1513,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_cls = TrainingArguments
if self.cfg.rl == "orpo":
training_args_cls = ORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]:
training_args_cls = DPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args = training_args_cls(
per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1553,8 +1537,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
elif self.cfg.rl in ["kto_pair", "sppo_hard", "nca_pair"]:
dpo_trainer_kwargs["loss_type"] = self.cfg.rl
elif self.cfg.rl == "kto_pair":
dpo_trainer_kwargs["loss_type"] = "kto_pair"
if self.eval_dataset:
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
@@ -1563,7 +1547,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]:
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
trainer_cls = AxolotlDPOTrainer
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
trainer_cls_args = [self.model, self.model_ref]
@@ -1573,8 +1557,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = True
if self.cfg.rl == "dpo":
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]

View File

@@ -1,30 +0,0 @@
"""
DPO strategies for mistral instruct
"""
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
sample["chosen"] = f"{sample['chosen']}"
sample["rejected"] = f"{sample['rejected']}"
return sample
return transform_fn
def argilla_chat(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
for argilla/dpo-mix-7k conversations
"""
def transform_fn(sample):
sample["prompt"] = f"[INST] {sample['chosen'][0]['content']} [/INST]"
sample["chosen"] = f"{sample['chosen'][1]['content']}</s>"
sample["rejected"] = f"{sample['rejected'][1]['content']}</s>"
return sample
return transform_fn

View File

@@ -3,7 +3,6 @@
import os
import signal
import sys
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union
@@ -128,20 +127,14 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
_model = model_weakref()
if cfg.flash_optimum and BetterTransformer:
_model = BetterTransformer.reverse(_model)
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
def terminate_handler(_, __, model):
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)
_model_weakref = weakref.ref(model)
signal.signal(
signal.SIGINT,
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""

View File

@@ -773,13 +773,3 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control
class SaveModelOnTrainEndCallback(TrainerCallback):
"""Callback to save model on train end"""
def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs
):
control.should_save = True
return control

View File

@@ -383,9 +383,9 @@ def legacy_validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
)
if cfg.gptq and cfg.revision_of_model:
@@ -448,14 +448,10 @@ def legacy_validate_config(cfg):
raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if cfg.evals_per_epoch and cfg.eval_steps:
raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
@@ -468,6 +464,11 @@ def legacy_validate_config(cfg):
raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if (
cfg.evaluation_strategy
and cfg.eval_steps

View File

@@ -133,8 +133,6 @@ class RLType(str, Enum):
ipo = "ipo" # pylint: disable=invalid-name
kto_pair = "kto_pair" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
sppo_hard = "sppo_hard" # pylint: disable=invalid-name
nca_pair = "nca_pair" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
@@ -190,6 +188,7 @@ class LoraConfig(BaseModel):
peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
qlora_fsdp_alt_loader: Optional[bool] = None
lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None
@@ -411,17 +410,6 @@ class WandbConfig(BaseModel):
return data
class GradioConfig(BaseModel):
"""Gradio configuration subset"""
gradio_title: Optional[str] = None
gradio_share: Optional[bool] = None
gradio_server_name: Optional[str] = None
gradio_server_port: Optional[int] = None
gradio_max_new_tokens: Optional[int] = None
gradio_temperature: Optional[float] = None
# pylint: disable=too-many-public-methods,too-many-ancestors
class AxolotlInputConfig(
ModelInputConfig,
@@ -432,7 +420,6 @@ class AxolotlInputConfig(
WandbConfig,
MLFlowConfig,
LISAConfig,
GradioConfig,
RemappedParameters,
DeprecatedParameters,
BaseModel,
@@ -517,17 +504,9 @@ class AxolotlInputConfig(
unfrozen_parameters: Optional[List[str]] = None
sequence_len: int = Field(default=512)
min_sample_len: Optional[int] = None
sample_packing: Optional[bool] = None
eval_sample_packing: Optional[bool] = None
pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None
# for PoSE context length extension
use_pose: Optional[bool] = None
pose_split_on_token_ids: Optional[List[int]] = None
pose_max_context_len: Optional[int] = None
pose_num_chunks: Optional[int] = None
pretrain_multipack_buffer_size: Optional[int] = 10_000
pretrain_multipack_attn: Optional[bool] = Field(
@@ -576,7 +555,6 @@ class AxolotlInputConfig(
neftune_noise_alpha: Optional[float] = None
orpo_alpha: Optional[float] = None
dpo_beta: Optional[float] = None
max_memory: Optional[
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
@@ -795,11 +773,11 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_push_save(cls, data):
if data.get("hub_model_id") and (
data.get("save_strategy") not in ["steps", "epoch", None]
if data.get("hub_model_id") and not (
data.get("save_steps") or data.get("saves_per_epoch")
):
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set save_strategy."
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
)
return data

View File

@@ -70,6 +70,7 @@ def load_and_quantize(
to_meta: bool = False,
verbose: bool = False,
quant_method: str = "bnb",
is_dora: bool = False,
):
"""
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
@@ -108,6 +109,12 @@ def load_and_quantize(
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
if is_dora:
setattr(
submodule,
"dora_scale",
value.norm(p=2, dim=1).to(dtype=dtype).to("cpu"),
)
value = type(param)(
value.to(device=device, dtype=dtype).data, **param.__dict__
).cuda(device)
@@ -177,6 +184,7 @@ def load_sharded_model_quant(
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
model_config,
attn_implementation=model_config._attn_implementation, # pylint: disable=protected-access
trust_remote_code=cfg.trust_remote_code,
)
if hasattr(model, "transformer"):
@@ -249,6 +257,7 @@ def load_sharded_model_quant(
to_meta=(low_memory and cfg.local_rank != 0),
verbose=verbose,
quant_method=quant_method,
is_dora=cfg.peft_use_dora,
)
if cfg.local_rank == 0 and verbose:

View File

@@ -34,6 +34,7 @@ from transformers import ( # noqa: F401
PreTrainedTokenizerBase,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.quantizers import AutoHfQuantizer
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
@@ -568,7 +569,7 @@ def load_model(
elif (
qlora_fsdp
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.model_config_type == "dbrx"
and cfg.qlora_fsdp_alt_loader
):
quant_storage = cfg.torch_dtype
model = load_sharded_model_quant(
@@ -577,6 +578,11 @@ def load_model(
cfg,
quant_storage=quant_storage,
)
if model_kwargs["quantization_config"]:
hf_quantizer = AutoHfQuantizer.from_config(
model_kwargs["quantization_config"]
)
model.hf_quantizer = hf_quantizer
skip_move_to_device = True
elif (
model_config.model_type == "llama"
@@ -789,11 +795,7 @@ def load_model(
if not reference_model or cfg.lora_model_dir:
# if we're not loading the reference model, then we're loading the model for training
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
if (
cfg.adapter
and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]
and not cfg.merge_lora
):
if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
else:
model, lora_config = load_adapter(model, cfg, cfg.adapter)
@@ -1007,3 +1009,10 @@ def ensure_dtype(model, dtype=torch.bfloat16):
module.to(dtype)
except AttributeError:
pass
for name, param in model.named_parameters():
try:
if param.data.dtype != dtype:
print(f"Converting module {name}: {param.data.dtype} -> {dtype}")
param.data = param.data.to(dtype)
except AttributeError:
pass

View File

@@ -1,5 +1,6 @@
"""Module for tokenization utilities"""
import logging
import re
from typing import Dict, List
@@ -9,19 +10,10 @@ from termcolor import colored
LOG = logging.getLogger("axolotl")
def check_dataset_labels(
dataset,
tokenizer,
num_examples=5,
text_only=False,
rl_mode=False,
):
def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
# the dataset is already shuffled, so let's just check the first 5 elements
for idx in range(num_examples):
if not rl_mode:
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
else:
check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
def check_example_labels(example, tokenizer, text_only=False):
@@ -48,53 +40,6 @@ def check_example_labels(example, tokenizer, text_only=False):
return " ".join(colored_tokens)
def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
"""Helper function to color tokens based on their type."""
colored_text = colored(decoded_token, color)
return (
colored_text
if text_only
else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
)
def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
"""Helper function to process and color tokens."""
colored_tokens = [
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
for token in tokenizer.encode(tokens)
]
return colored_tokens
def check_rl_example_labels(example, tokenizer, text_only=False):
field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
input_tokens = example[field_prompt]
labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
# Process and color each type of token
colored_tokens = process_tokens_for_rl_debug(
input_tokens, "yellow", tokenizer, text_only
)
colored_chosens = process_tokens_for_rl_debug(
labels_chosen, "green", tokenizer, text_only
)
colored_rejecteds = process_tokens_for_rl_debug(
labels_rejected, "red", tokenizer, text_only
)
# Create a delimiter based on text_only flag
delimiter = "" if text_only else " "
# Logging information
LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
return delimiter.join(colored_tokens)
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
GLAIVE_TO_SHAREGPT_ROLE = {
"SYSTEM": "system",

View File

@@ -1,10 +1,9 @@
"""Module containing the Trainer class and related functions"""
import math
import os
import random
from contextlib import contextmanager
from functools import partial
from typing import List, Optional
from typing import List
import numpy as np
import torch
@@ -99,89 +98,17 @@ def add_position_ids(sample):
return sample
def add_pose_position_ids(
sample,
max_context_len=32768,
split_on_token_ids: Optional[List[int]] = None,
chunks: int = 2,
):
"""
use the PoSE technique to extend the context length by randomly skipping
positions in the context. We only want to skip right before tokens in
the split_on_token_ids list. We should attempt to randomly distribute
the skips, but we don't need the final position_ids to be the full
context_len. There may be multiple turns in the context, so we want to
make sure we take into account the maximum possible number of skips
remaining in each sample.
"""
input_ids = sample["input_ids"]
sample_len = len(input_ids)
max_skips = max_context_len - sample_len
if split_on_token_ids is None:
split_on_token_ids = []
if split_on_token_ids:
split_indices = [
i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
]
else:
chunk_len = sample_len // chunks
split_indices = [i * chunk_len for i in range(1, chunks)]
split_indices.append(len(input_ids)) # make sure we go to the end of the sample
if split_indices[0] < 2:
# drop the first split index if it's too close to the beginning
split_indices = split_indices[1:]
position_ids = []
prev_index = 0
total_skips = 0
for split_index in split_indices:
num_skips = (
random.randint(0, max_skips) # nosec B311
if prev_index != 0 and max_skips
else 0
)
max_skips -= num_skips
total_skips += num_skips
segment_position_ids = list(
range(prev_index + total_skips, split_index + total_skips)
)
position_ids.extend(segment_position_ids)
prev_index = split_index
sample["sequence_len"] = position_ids[-1]
position_ids = torch.tensor(position_ids)
sample["position_ids"] = position_ids
sample["length"] = len(position_ids)
assert len(position_ids) == len(input_ids)
return sample
def add_length(sample):
sample["length"] = len(sample["input_ids"])
return sample
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
return (
len(sample["input_ids"]) <= sequence_len
and len(sample["input_ids"]) >= min_sequence_len
)
def drop_long_seq(sample, sequence_len=2048):
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long = partial(
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len or 2,
)
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
with zero_first(is_main_process()):
if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset))
@@ -226,32 +153,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
desc="Group By Length",
)
if cfg.use_pose:
pose_kwargs = {}
if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial(
add_pose_position_ids,
max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs,
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
train_dataset = train_dataset.sort("sequence_len")
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
if cfg.sample_packing:
train_dataset = train_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
@@ -438,7 +340,7 @@ def prepare_optim_env(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard", "nca_pair"]:
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]

View File

@@ -158,50 +158,3 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@with_temp_dir
def test_orpo_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {},
"rl": "orpo",
"orpo_alpha": 0.1,
"remove_unused_columns": False,
"chat_template": "chatml",
"datasets": [
{
"path": "argilla/ultrafeedback-binarized-preferences-cleaned",
"type": "chat_template.argilla",
"split": "train",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

View File

@@ -1067,52 +1067,18 @@ class TestValidation(BaseValidation):
):
validate_config(cfg)
def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 1
def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 1
def test_hub_model_id_save_value_steps(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "steps"})
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_epochs(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "epoch"})
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_none(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
def test_hub_model_id_save_value_warns(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert (
"set without any models being saved" in self._caplog.records[0].message
)
def test_hub_model_id_save_value(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0