Compare commits
1 Commits
nca-pair
...
fix-l3-lor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ce9b0760b |
5
.github/workflows/base.yml
vendored
5
.github/workflows/base.yml
vendored
@@ -32,11 +32,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.2.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "121"
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
10
.github/workflows/main.yml
vendored
10
.github/workflows/main.yml
vendored
@@ -30,11 +30,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.2.1
|
||||
axolotl_extras:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -91,11 +86,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.2.1
|
||||
axolotl_extras:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
10
.github/workflows/nightlies.yml
vendored
10
.github/workflows/nightlies.yml
vendored
@@ -29,11 +29,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.2.1
|
||||
axolotl_extras:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -91,11 +86,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.2.1
|
||||
axolotl_extras:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -133,7 +133,6 @@ venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
venv3.10/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
|
||||
@@ -138,7 +138,7 @@ test_datasets:
|
||||
data_files:
|
||||
- /workspace/data/eval.jsonl
|
||||
|
||||
# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard', 'nca_pair'
|
||||
# use RL training: 'dpo', 'ipo', 'kto_pair'
|
||||
rl:
|
||||
|
||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||
@@ -227,12 +227,6 @@ lora_modules_to_save:
|
||||
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# LoRA+ hyperparameters
|
||||
# For more details about the following options, see:
|
||||
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
|
||||
loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.
|
||||
loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6.
|
||||
|
||||
peft:
|
||||
# Configuration options for loftq initialization for LoRA
|
||||
# https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
|
||||
@@ -274,7 +268,6 @@ torch_compile_backend: # Optional[str]
|
||||
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
||||
gradient_accumulation_steps: 1
|
||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||
# Batch size per gpu = micro_batch_size * gradient_accumulation_steps
|
||||
micro_batch_size: 2
|
||||
eval_batch_size:
|
||||
num_epochs: 4
|
||||
|
||||
@@ -49,7 +49,7 @@ remove_unused_columns: false
|
||||
chat_template: chatml
|
||||
datasets:
|
||||
- path: argilla/ultrafeedback-binarized-preferences-cleaned
|
||||
type: chat_template.argilla
|
||||
type: orpo.chat_template
|
||||
```
|
||||
|
||||
#### Using local dataset files
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
base_model: meta-llama/Meta-Llama-3-8B
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
@@ -64,4 +64,4 @@ weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
pad_token: <|end_of_text|>
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
rl: orpo
|
||||
orpo_alpha: 0.1
|
||||
remove_unused_columns: false
|
||||
|
||||
chat_template: chatml
|
||||
datasets:
|
||||
- path: argilla/ultrafeedback-binarized-preferences-cleaned
|
||||
type: chat_template.argilla
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./mistral-qlora-orpo-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -11,7 +11,7 @@ addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
requests
|
||||
datasets==2.15.0
|
||||
datasets>=2.15.0
|
||||
flash-attn==2.5.5
|
||||
sentencepiece
|
||||
wandb
|
||||
@@ -28,7 +28,7 @@ scipy
|
||||
scikit-learn==1.2.2
|
||||
pynvml
|
||||
art
|
||||
fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8
|
||||
fschat==0.2.36
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
|
||||
@@ -39,6 +39,6 @@ s3fs
|
||||
gcsfs
|
||||
# adlfs
|
||||
|
||||
trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
|
||||
trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
|
||||
zstandard==0.22.0
|
||||
fastcore
|
||||
|
||||
@@ -33,7 +33,7 @@ fi
|
||||
|
||||
if [ "$JUPYTER_DISABLE" != "1" ]; then
|
||||
# Run Jupyter Lab in the background
|
||||
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
|
||||
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
|
||||
fi
|
||||
|
||||
# Execute the passed arguments (CMD)
|
||||
|
||||
@@ -264,8 +264,8 @@ def do_inference_gradio(
|
||||
with torch.no_grad():
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
|
||||
temperature=cfg.get("gradio_temperature", 0.9),
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
@@ -300,13 +300,7 @@ def do_inference_gradio(
|
||||
outputs="text",
|
||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||
)
|
||||
|
||||
demo.queue().launch(
|
||||
show_api=False,
|
||||
share=cfg.get("gradio_share", True),
|
||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||
server_port=cfg.get("gradio_server_port", None),
|
||||
)
|
||||
demo.queue().launch(show_api=False, share=True)
|
||||
|
||||
|
||||
def choose_config(path: Path):
|
||||
@@ -439,23 +433,6 @@ def load_rl_datasets(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
check_dataset_labels(
|
||||
train_dataset.select(
|
||||
[
|
||||
random.randrange(0, len(train_dataset) - 1) # nosec
|
||||
for _ in range(cli_args.debug_num_examples)
|
||||
]
|
||||
),
|
||||
tokenizer,
|
||||
num_examples=cli_args.debug_num_examples,
|
||||
text_only=cli_args.debug_text_only,
|
||||
rl_mode=True,
|
||||
)
|
||||
|
||||
return TrainDatasetMeta(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
|
||||
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
|
||||
if parsed_cfg.rl and parsed_cfg.rl != "orpo":
|
||||
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||
else:
|
||||
register_chatml_template()
|
||||
|
||||
if cfg.rl: # and cfg.rl != "orpo":
|
||||
if cfg.rl and cfg.rl != "orpo":
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -30,7 +30,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import seed_worker
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
|
||||
from trl import DPOTrainer
|
||||
from trl.trainer.utils import pad_to_length
|
||||
|
||||
from axolotl.loraplus import create_loraplus_optimizer
|
||||
@@ -43,7 +43,6 @@ from axolotl.utils.callbacks import (
|
||||
LossWatchDogCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
SaveModelOnTrainEndCallback,
|
||||
bench_eval_callback_factory,
|
||||
causal_lm_bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
@@ -55,7 +54,6 @@ from axolotl.utils.collators import (
|
||||
MambaDataCollator,
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from axolotl.utils.models import ensure_dtype
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
from axolotl.utils.schedulers import (
|
||||
get_cosine_schedule_with_min_lr,
|
||||
@@ -213,10 +211,6 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=None,
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
)
|
||||
curriculum_sampling: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
@@ -352,8 +346,6 @@ class AxolotlTrainer(Trainer):
|
||||
lengths=get_dataset_lengths(self.train_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
if self.args.curriculum_sampling:
|
||||
return SequentialSampler(self.train_dataset)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def _get_eval_sampler(
|
||||
@@ -818,14 +810,6 @@ class AxolotlDPOTrainer(DPOTrainer):
|
||||
return res
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
Base class for trainer builder
|
||||
@@ -889,14 +873,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_mlflow and is_mlflow_available():
|
||||
from axolotl.utils.callbacks.mlflow_ import (
|
||||
SaveAxolotlConfigtoMlflowCallback,
|
||||
)
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
return callbacks
|
||||
|
||||
@@ -942,11 +918,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
):
|
||||
callbacks.append(SaveBetterTransformerModelCallback())
|
||||
|
||||
if self.cfg.use_mlflow and is_mlflow_available():
|
||||
from axolotl.utils.callbacks.mlflow_ import (
|
||||
SaveAxolotlConfigtoMlflowCallback,
|
||||
)
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
|
||||
callbacks.append(SaveModelOnTrainEndCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
@@ -1201,7 +1184,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
False if self.cfg.ddp else None
|
||||
)
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||
report_to = None
|
||||
if self.cfg.use_wandb:
|
||||
report_to = "wandb"
|
||||
@@ -1422,15 +1404,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
)
|
||||
|
||||
|
||||
class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
class HFDPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
Trainer factory class for DPO Trainer
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
callbacks.append(SaveModelOnTrainEndCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
@@ -1466,7 +1446,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
else:
|
||||
training_args_kwargs["evaluation_strategy"] = "no"
|
||||
|
||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||
training_args_kwargs["bf16"] = True
|
||||
|
||||
@@ -1518,19 +1497,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.orpo_alpha:
|
||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||
|
||||
training_args_cls = TrainingArguments
|
||||
if self.cfg.rl == "orpo":
|
||||
training_args_cls = ORPOConfig
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]:
|
||||
training_args_cls = DPOConfig
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
training_args = training_args_cls(
|
||||
training_args = TrainingArguments(
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
max_steps=self.cfg.max_steps or total_num_steps,
|
||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||
@@ -1553,8 +1520,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
||||
if self.cfg.dpo_label_smoothing:
|
||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||
elif self.cfg.rl in ["kto_pair", "sppo_hard", "nca_pair"]:
|
||||
dpo_trainer_kwargs["loss_type"] = self.cfg.rl
|
||||
elif self.cfg.rl == "kto_pair":
|
||||
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
||||
if self.eval_dataset:
|
||||
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||
if self.cfg.adapter and self.peft_config:
|
||||
@@ -1563,34 +1530,20 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs[
|
||||
"precompute_ref_log_probs"
|
||||
] = self.cfg.precompute_ref_log_probs
|
||||
if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]:
|
||||
trainer_cls = AxolotlDPOTrainer
|
||||
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
|
||||
trainer_cls_args = [self.model, self.model_ref]
|
||||
|
||||
# these aren't used for the ORPO trainer
|
||||
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||
dpo_trainer_kwargs["max_target_length"] = None
|
||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
dpo_trainer_kwargs["generate_during_eval"] = True
|
||||
if self.cfg.rl == "dpo":
|
||||
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
elif self.cfg.rl == "orpo":
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
dpo_trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
dpo_trainer = AxolotlDPOTrainer(
|
||||
self.model,
|
||||
self.model_ref,
|
||||
args=training_args,
|
||||
beta=self.cfg.dpo_beta or 0.1,
|
||||
train_dataset=self.train_dataset,
|
||||
tokenizer=self.tokenizer,
|
||||
max_length=self.cfg.sequence_len,
|
||||
max_target_length=None,
|
||||
max_prompt_length=self.cfg.sequence_len,
|
||||
generate_during_eval=True,
|
||||
callbacks=self.get_callbacks(),
|
||||
**dpo_trainer_kwargs,
|
||||
)
|
||||
if self.cfg.fsdp:
|
||||
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
|
||||
|
||||
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
|
||||
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
|
||||
dpo_trainer.add_callback(callback)
|
||||
|
||||
@@ -123,14 +123,6 @@ def get_turns( # pylint: disable=too-many-return-statements
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.GEMMA:
|
||||
if self.system_message:
|
||||
raise ValueError("Gemma chat template does not support system messages")
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
prefix = "<bos>" if i == 0 else ""
|
||||
message_str = message if message else ""
|
||||
yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATGLM:
|
||||
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
"""
|
||||
DPO strategies for mistral instruct
|
||||
"""
|
||||
|
||||
|
||||
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
|
||||
sample["chosen"] = f"{sample['chosen']}"
|
||||
sample["rejected"] = f"{sample['rejected']}"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
|
||||
def argilla_chat(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
for argilla/dpo-mix-7k conversations
|
||||
"""
|
||||
|
||||
def transform_fn(sample):
|
||||
sample["prompt"] = f"[INST] {sample['chosen'][0]['content']} [/INST]"
|
||||
sample["chosen"] = f"{sample['chosen'][1]['content']}</s>"
|
||||
sample["rejected"] = f"{sample['rejected'][1]['content']}</s>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
@@ -6,4 +6,4 @@ from functools import partial
|
||||
|
||||
from ..base import load as load_base
|
||||
|
||||
load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
|
||||
load = partial(load_base, module="axolotl.prompt_strategies.orpo")
|
||||
|
||||
@@ -78,57 +78,6 @@ class ORPODatasetParsingStrategy:
|
||||
)
|
||||
return MessageList(messages=messages)
|
||||
|
||||
def get_prompt(self, prompt) -> MessageList:
|
||||
"""Map the data to extract everything up to the last turn"""
|
||||
total_msg_len = len(prompt["chosen"])
|
||||
total_msg_turns, remainder = divmod(total_msg_len, 2)
|
||||
assert remainder == 0, "invalid number of turns"
|
||||
|
||||
messages: List[Message] = []
|
||||
if system := prompt.get("system", None):
|
||||
messages.append(Message(role="system", content=system, label=False))
|
||||
for i in range(total_msg_turns):
|
||||
if "prompt" in prompt:
|
||||
messages.append(
|
||||
Message(role="user", content=prompt["prompt"], label=False)
|
||||
)
|
||||
else:
|
||||
messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content=prompt["chosen"][i * 2]["content"],
|
||||
label=False,
|
||||
)
|
||||
)
|
||||
if i < total_msg_turns - 1:
|
||||
messages.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content=prompt["chosen"][i * 2 + 1]["content"],
|
||||
label=False,
|
||||
)
|
||||
)
|
||||
|
||||
return MessageList(messages=messages)
|
||||
|
||||
def get_chosen(self, prompt) -> MessageList:
|
||||
res = self.get_prompt(prompt)
|
||||
res.messages.append(
|
||||
Message(
|
||||
role="assistant", content=prompt["chosen"][-1]["content"], label=True
|
||||
)
|
||||
)
|
||||
return res
|
||||
|
||||
def get_rejected(self, prompt) -> MessageList:
|
||||
res = self.get_prompt(prompt)
|
||||
res.messages.append(
|
||||
Message(
|
||||
role="assistant", content=prompt["rejected"][-1]["content"], label=True
|
||||
)
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
class ORPOTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
@@ -237,36 +186,3 @@ class ORPOPrompter(Prompter):
|
||||
chat_template=self.chat_template,
|
||||
tokenize=False,
|
||||
), True
|
||||
|
||||
|
||||
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
dataset_parser = ORPODatasetParsingStrategy()
|
||||
|
||||
chat_template_str = chat_templates(cfg.chat_template)
|
||||
|
||||
def transform_fn(sample, tokenizer=None):
|
||||
res = {}
|
||||
|
||||
res["prompt"] = tokenizer.apply_chat_template(
|
||||
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
|
||||
add_generation_prompt=True,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=False,
|
||||
)
|
||||
prompt_str_len = len(res["prompt"])
|
||||
res["chosen"] = tokenizer.apply_chat_template(
|
||||
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
|
||||
add_generation_prompt=False,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=False,
|
||||
)[prompt_str_len:]
|
||||
res["rejected"] = tokenizer.apply_chat_template(
|
||||
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
|
||||
add_generation_prompt=False,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=False,
|
||||
)[prompt_str_len:]
|
||||
|
||||
return res
|
||||
|
||||
return transform_fn
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
@@ -128,20 +127,14 @@ def train(
|
||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||
if cfg.local_rank == 0:
|
||||
|
||||
def terminate_handler(_, __, model_weakref):
|
||||
if model_weakref() is not None:
|
||||
_model = model_weakref()
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
_model = BetterTransformer.reverse(_model)
|
||||
_model.save_pretrained(
|
||||
cfg.output_dir, safe_serialization=safe_serialization
|
||||
)
|
||||
def terminate_handler(_, __, model):
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
sys.exit(0)
|
||||
|
||||
_model_weakref = weakref.ref(model)
|
||||
signal.signal(
|
||||
signal.SIGINT,
|
||||
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
|
||||
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
||||
)
|
||||
|
||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
||||
|
||||
@@ -773,13 +773,3 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||
return control
|
||||
|
||||
|
||||
class SaveModelOnTrainEndCallback(TrainerCallback):
|
||||
"""Callback to save model on train end"""
|
||||
|
||||
def on_train_end( # pylint: disable=unused-argument
|
||||
self, args, state, control, **kwargs
|
||||
):
|
||||
control.should_save = True
|
||||
return control
|
||||
|
||||
@@ -383,9 +383,9 @@ def legacy_validate_config(cfg):
|
||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||
)
|
||||
|
||||
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
|
||||
if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
|
||||
LOG.warning(
|
||||
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
|
||||
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.revision_of_model:
|
||||
@@ -448,14 +448,10 @@ def legacy_validate_config(cfg):
|
||||
raise ValueError(
|
||||
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
|
||||
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
||||
)
|
||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||
)
|
||||
if cfg.evals_per_epoch and cfg.eval_steps:
|
||||
raise ValueError(
|
||||
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||
@@ -468,6 +464,11 @@ def legacy_validate_config(cfg):
|
||||
raise ValueError(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.evaluation_strategy
|
||||
and cfg.eval_steps
|
||||
|
||||
@@ -133,8 +133,6 @@ class RLType(str, Enum):
|
||||
ipo = "ipo" # pylint: disable=invalid-name
|
||||
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
||||
orpo = "orpo" # pylint: disable=invalid-name
|
||||
sppo_hard = "sppo_hard" # pylint: disable=invalid-name
|
||||
nca_pair = "nca_pair" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChatTemplate(str, Enum):
|
||||
@@ -411,17 +409,6 @@ class WandbConfig(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class GradioConfig(BaseModel):
|
||||
"""Gradio configuration subset"""
|
||||
|
||||
gradio_title: Optional[str] = None
|
||||
gradio_share: Optional[bool] = None
|
||||
gradio_server_name: Optional[str] = None
|
||||
gradio_server_port: Optional[int] = None
|
||||
gradio_max_new_tokens: Optional[int] = None
|
||||
gradio_temperature: Optional[float] = None
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods,too-many-ancestors
|
||||
class AxolotlInputConfig(
|
||||
ModelInputConfig,
|
||||
@@ -432,7 +419,6 @@ class AxolotlInputConfig(
|
||||
WandbConfig,
|
||||
MLFlowConfig,
|
||||
LISAConfig,
|
||||
GradioConfig,
|
||||
RemappedParameters,
|
||||
DeprecatedParameters,
|
||||
BaseModel,
|
||||
@@ -517,17 +503,9 @@ class AxolotlInputConfig(
|
||||
unfrozen_parameters: Optional[List[str]] = None
|
||||
|
||||
sequence_len: int = Field(default=512)
|
||||
min_sample_len: Optional[int] = None
|
||||
sample_packing: Optional[bool] = None
|
||||
eval_sample_packing: Optional[bool] = None
|
||||
pad_to_sequence_len: Optional[bool] = None
|
||||
curriculum_sampling: Optional[bool] = None
|
||||
|
||||
# for PoSE context length extension
|
||||
use_pose: Optional[bool] = None
|
||||
pose_split_on_token_ids: Optional[List[int]] = None
|
||||
pose_max_context_len: Optional[int] = None
|
||||
pose_num_chunks: Optional[int] = None
|
||||
|
||||
pretrain_multipack_buffer_size: Optional[int] = 10_000
|
||||
pretrain_multipack_attn: Optional[bool] = Field(
|
||||
@@ -576,7 +554,6 @@ class AxolotlInputConfig(
|
||||
neftune_noise_alpha: Optional[float] = None
|
||||
|
||||
orpo_alpha: Optional[float] = None
|
||||
dpo_beta: Optional[float] = None
|
||||
|
||||
max_memory: Optional[
|
||||
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
||||
@@ -795,11 +772,11 @@ class AxolotlInputConfig(
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_push_save(cls, data):
|
||||
if data.get("hub_model_id") and (
|
||||
data.get("save_strategy") not in ["steps", "epoch", None]
|
||||
if data.get("hub_model_id") and not (
|
||||
data.get("save_steps") or data.get("saves_per_epoch")
|
||||
):
|
||||
LOG.warning(
|
||||
"hub_model_id is set without any models being saved. To save a model, set save_strategy."
|
||||
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""
|
||||
Data processing modules
|
||||
"""
|
||||
from axolotl.utils.data.dpo import load_prepare_dpo_datasets # noqa: F401
|
||||
from axolotl.utils.data.pretraining import ( # noqa: F401
|
||||
encode_pretraining,
|
||||
wrap_pretraining_dataset,
|
||||
)
|
||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
|
||||
from axolotl.utils.data.sft import ( # noqa: F401
|
||||
get_dataset_wrapper,
|
||||
load_prepare_datasets,
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
"""data handling specific to DPO"""
|
||||
import inspect
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
|
||||
import yaml
|
||||
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
|
||||
from datasets import concatenate_datasets, load_dataset, load_from_disk
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||
from axolotl.prompt_strategies.orpo import load as load_orpo
|
||||
from axolotl.utils.data.utils import md5
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -75,29 +72,16 @@ def load_prepare_dpo_datasets(cfg):
|
||||
)
|
||||
split_datasets.insert(i, ds)
|
||||
|
||||
tokenizer = None
|
||||
for i, data_set in enumerate(split_datasets):
|
||||
_type = dataset_cfgs[i]["type"]
|
||||
if _type:
|
||||
if isinstance(_type, DictDefault):
|
||||
_type = "user_defined.default"
|
||||
if _cfg.rl == "orpo":
|
||||
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
|
||||
else:
|
||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||
sig = inspect.signature(ds_transform_fn)
|
||||
if "tokenizer" in sig.parameters:
|
||||
if not tokenizer:
|
||||
tokenizer = load_tokenizer(_cfg)
|
||||
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
|
||||
|
||||
data_set = data_set.map(
|
||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||
split_datasets[i] = data_set.map(
|
||||
ds_transform_fn,
|
||||
desc="Mapping RL Dataset",
|
||||
)
|
||||
if isinstance(data_set, DatasetDict):
|
||||
data_set = data_set["train"]
|
||||
split_datasets[i] = data_set
|
||||
else:
|
||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||
# "prompt", "chosen" and "rejected" already preprocessed
|
||||
@@ -421,7 +421,7 @@ def load_tokenized_prepared_datasets(
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||
dataset.save_to_disk(str(prepared_ds_path))
|
||||
dataset.save_to_disk(prepared_ds_path)
|
||||
if cfg.push_dataset_to_hub:
|
||||
LOG.info(
|
||||
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||
|
||||
@@ -789,11 +789,7 @@ def load_model(
|
||||
if not reference_model or cfg.lora_model_dir:
|
||||
# if we're not loading the reference model, then we're loading the model for training
|
||||
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
||||
if (
|
||||
cfg.adapter
|
||||
and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]
|
||||
and not cfg.merge_lora
|
||||
):
|
||||
if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
|
||||
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
||||
else:
|
||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||
@@ -997,13 +993,3 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
setup_quantized_peft_meta_for_training(model)
|
||||
|
||||
return model, lora_config
|
||||
|
||||
|
||||
def ensure_dtype(model, dtype=torch.bfloat16):
|
||||
for name, module in model.named_modules():
|
||||
try:
|
||||
if module.weight.dtype != dtype:
|
||||
print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
|
||||
module.to(dtype)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Module for tokenization utilities"""
|
||||
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List
|
||||
@@ -9,19 +10,10 @@ from termcolor import colored
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def check_dataset_labels(
|
||||
dataset,
|
||||
tokenizer,
|
||||
num_examples=5,
|
||||
text_only=False,
|
||||
rl_mode=False,
|
||||
):
|
||||
def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
|
||||
# the dataset is already shuffled, so let's just check the first 5 elements
|
||||
for idx in range(num_examples):
|
||||
if not rl_mode:
|
||||
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
||||
else:
|
||||
check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
||||
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
||||
|
||||
|
||||
def check_example_labels(example, tokenizer, text_only=False):
|
||||
@@ -48,53 +40,6 @@ def check_example_labels(example, tokenizer, text_only=False):
|
||||
return " ".join(colored_tokens)
|
||||
|
||||
|
||||
def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
|
||||
"""Helper function to color tokens based on their type."""
|
||||
colored_text = colored(decoded_token, color)
|
||||
return (
|
||||
colored_text
|
||||
if text_only
|
||||
else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
|
||||
)
|
||||
|
||||
|
||||
def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
|
||||
"""Helper function to process and color tokens."""
|
||||
colored_tokens = [
|
||||
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
|
||||
for token in tokenizer.encode(tokens)
|
||||
]
|
||||
return colored_tokens
|
||||
|
||||
|
||||
def check_rl_example_labels(example, tokenizer, text_only=False):
|
||||
field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
|
||||
|
||||
input_tokens = example[field_prompt]
|
||||
labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
|
||||
|
||||
# Process and color each type of token
|
||||
colored_tokens = process_tokens_for_rl_debug(
|
||||
input_tokens, "yellow", tokenizer, text_only
|
||||
)
|
||||
colored_chosens = process_tokens_for_rl_debug(
|
||||
labels_chosen, "green", tokenizer, text_only
|
||||
)
|
||||
colored_rejecteds = process_tokens_for_rl_debug(
|
||||
labels_rejected, "red", tokenizer, text_only
|
||||
)
|
||||
|
||||
# Create a delimiter based on text_only flag
|
||||
delimiter = "" if text_only else " "
|
||||
|
||||
# Logging information
|
||||
LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
|
||||
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
|
||||
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
|
||||
|
||||
return delimiter.join(colored_tokens)
|
||||
|
||||
|
||||
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
|
||||
GLAIVE_TO_SHAREGPT_ROLE = {
|
||||
"SYSTEM": "system",
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""Module containing the Trainer class and related functions"""
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -14,7 +13,7 @@ from datasets import set_caching_enabled
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
|
||||
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
@@ -99,89 +98,17 @@ def add_position_ids(sample):
|
||||
return sample
|
||||
|
||||
|
||||
def add_pose_position_ids(
|
||||
sample,
|
||||
max_context_len=32768,
|
||||
split_on_token_ids: Optional[List[int]] = None,
|
||||
chunks: int = 2,
|
||||
):
|
||||
"""
|
||||
use the PoSE technique to extend the context length by randomly skipping
|
||||
positions in the context. We only want to skip right before tokens in
|
||||
the split_on_token_ids list. We should attempt to randomly distribute
|
||||
the skips, but we don't need the final position_ids to be the full
|
||||
context_len. There may be multiple turns in the context, so we want to
|
||||
make sure we take into account the maximum possible number of skips
|
||||
remaining in each sample.
|
||||
"""
|
||||
|
||||
input_ids = sample["input_ids"]
|
||||
sample_len = len(input_ids)
|
||||
max_skips = max_context_len - sample_len
|
||||
|
||||
if split_on_token_ids is None:
|
||||
split_on_token_ids = []
|
||||
|
||||
if split_on_token_ids:
|
||||
split_indices = [
|
||||
i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
|
||||
]
|
||||
else:
|
||||
chunk_len = sample_len // chunks
|
||||
split_indices = [i * chunk_len for i in range(1, chunks)]
|
||||
split_indices.append(len(input_ids)) # make sure we go to the end of the sample
|
||||
if split_indices[0] < 2:
|
||||
# drop the first split index if it's too close to the beginning
|
||||
split_indices = split_indices[1:]
|
||||
|
||||
position_ids = []
|
||||
prev_index = 0
|
||||
total_skips = 0
|
||||
|
||||
for split_index in split_indices:
|
||||
num_skips = (
|
||||
random.randint(0, max_skips) # nosec B311
|
||||
if prev_index != 0 and max_skips
|
||||
else 0
|
||||
)
|
||||
max_skips -= num_skips
|
||||
total_skips += num_skips
|
||||
|
||||
segment_position_ids = list(
|
||||
range(prev_index + total_skips, split_index + total_skips)
|
||||
)
|
||||
|
||||
position_ids.extend(segment_position_ids)
|
||||
prev_index = split_index
|
||||
|
||||
sample["sequence_len"] = position_ids[-1]
|
||||
position_ids = torch.tensor(position_ids)
|
||||
|
||||
sample["position_ids"] = position_ids
|
||||
sample["length"] = len(position_ids)
|
||||
assert len(position_ids) == len(input_ids)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
def add_length(sample):
|
||||
sample["length"] = len(sample["input_ids"])
|
||||
return sample
|
||||
|
||||
|
||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
return (
|
||||
len(sample["input_ids"]) <= sequence_len
|
||||
and len(sample["input_ids"]) >= min_sequence_len
|
||||
)
|
||||
def drop_long_seq(sample, sequence_len=2048):
|
||||
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
drop_long = partial(
|
||||
drop_long_seq,
|
||||
sequence_len=cfg.sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len or 2,
|
||||
)
|
||||
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
||||
with zero_first(is_main_process()):
|
||||
if cfg.is_preprocess:
|
||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
||||
@@ -226,32 +153,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
desc="Group By Length",
|
||||
)
|
||||
|
||||
if cfg.use_pose:
|
||||
pose_kwargs = {}
|
||||
if cfg.pose_num_chunks is not None:
|
||||
pose_kwargs["chunks"] = cfg.pose_num_chunks
|
||||
pose_fn = partial(
|
||||
add_pose_position_ids,
|
||||
max_context_len=cfg.pose_max_context_len,
|
||||
split_on_token_ids=cfg.pose_split_on_token_ids,
|
||||
**pose_kwargs,
|
||||
)
|
||||
train_dataset = train_dataset.map(
|
||||
pose_fn,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Add position_id column (PoSE)",
|
||||
)
|
||||
train_dataset = train_dataset.sort("sequence_len")
|
||||
if cfg.eval_sample_packing is not False:
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.map(
|
||||
pose_fn,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Add position_id column (PoSE)",
|
||||
)
|
||||
elif cfg.sample_packing:
|
||||
if cfg.sample_packing:
|
||||
train_dataset = train_dataset.map(
|
||||
add_position_ids,
|
||||
num_proc=cfg.dataset_processes,
|
||||
@@ -438,8 +340,8 @@ def prepare_optim_env(cfg):
|
||||
|
||||
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard", "nca_pair"]:
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
||||
if cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
||||
trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
|
||||
trainer_builder.model_ref = model[1]
|
||||
trainer_builder.peft_config = model[2]
|
||||
else:
|
||||
|
||||
@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.core.trainer_builder import HFRLTrainerBuilder
|
||||
from axolotl.core.trainer_builder import HFDPOTrainerBuilder
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
|
||||
return load_model(cfg, tokenizer)
|
||||
|
||||
|
||||
class TestHFRLTrainerBuilder:
|
||||
class TestHFDPOTrainerBuilder:
|
||||
"""
|
||||
TestCase class for DPO trainer builder
|
||||
"""
|
||||
|
||||
def test_build_training_arguments(self, cfg, model, tokenizer):
|
||||
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
|
||||
builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
|
||||
training_arguments = builder.build_training_arguments(100)
|
||||
assert training_arguments.adam_beta1 == 0.998
|
||||
assert training_arguments.adam_beta2 == 0.9
|
||||
|
||||
@@ -158,50 +158,3 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_orpo_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 64,
|
||||
"lora_alpha": 32,
|
||||
"lora_dropout": 0.1,
|
||||
"lora_target_linear": True,
|
||||
"special_tokens": {},
|
||||
"rl": "orpo",
|
||||
"orpo_alpha": 0.1,
|
||||
"remove_unused_columns": False,
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "argilla/ultrafeedback-binarized-preferences-cleaned",
|
||||
"type": "chat_template.argilla",
|
||||
"split": "train",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "paged_adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||
|
||||
@@ -110,7 +110,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
|
||||
self.dataset.save_to_disk(str(tmp_ds_name))
|
||||
self.dataset.save_to_disk(tmp_ds_name)
|
||||
|
||||
prepared_path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
|
||||
@@ -1067,52 +1067,18 @@ class TestValidation(BaseValidation):
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
|
||||
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 1
|
||||
|
||||
def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 1
|
||||
|
||||
def test_hub_model_id_save_value_steps(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault({"hub_model_id": "test", "save_strategy": "steps"})
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
def test_hub_model_id_save_value_epochs(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault({"hub_model_id": "test", "save_strategy": "epoch"})
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
def test_hub_model_id_save_value_none(self, minimal_cfg):
|
||||
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
|
||||
def test_hub_model_id_save_value_warns(self, minimal_cfg):
|
||||
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert (
|
||||
"set without any models being saved" in self._caplog.records[0].message
|
||||
)
|
||||
|
||||
def test_hub_model_id_save_value(self, minimal_cfg):
|
||||
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
Reference in New Issue
Block a user