Compare commits
20 Commits
fsdp-qdora...sppo
| SHA1 |
|---|
| 6a9ac4ad27 |
| 027f7d54f0 |
| 0554105baa |
| f58fcd09ec |
| 60fecac367 |
| b301068098 |
| df645906eb |
| 7fea5822f0 |
| 3367fca732 |
| 1ac899800b |
| 70185763f6 |
| 120b809465 |
| 29cf15a28c |
| dde02fcb94 |
| b9bb169602 |
| 601c08b4c2 |
| cc5d31e0d9 |
| 1aeece6e24 |
| 5294653a2d |
| 98c25e15cb |
.github/workflows/base.yml (vendored, 5 changed lines)

@@ -32,6 +32,11 @@ jobs:
 python_version: "3.11"
 pytorch: 2.2.1
 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+- cuda: "121"
+  cuda_version: 12.1.0
+  python_version: "3.11"
+  pytorch: 2.3.0
+  torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
 steps:
 - name: Checkout
 uses: actions/checkout@v3
.github/workflows/main.yml (vendored, 10 changed lines)

@@ -30,6 +30,11 @@ jobs:
 python_version: "3.11"
 pytorch: 2.2.1
 axolotl_extras:
+- cuda: 121
+  cuda_version: 12.1.0
+  python_version: "3.11"
+  pytorch: 2.3.0
+  axolotl_extras:
 runs-on: axolotl-gpu-runner
 steps:
 - name: Checkout

@@ -86,6 +91,11 @@ jobs:
 python_version: "3.11"
 pytorch: 2.2.1
 axolotl_extras:
+- cuda: 121
+  cuda_version: 12.1.0
+  python_version: "3.11"
+  pytorch: 2.3.0
+  axolotl_extras:
 runs-on: axolotl-gpu-runner
 steps:
 - name: Checkout
.github/workflows/nightlies.yml (vendored, 10 changed lines)

@@ -29,6 +29,11 @@ jobs:
 python_version: "3.11"
 pytorch: 2.2.1
 axolotl_extras:
+- cuda: 121
+  cuda_version: 12.1.0
+  python_version: "3.11"
+  pytorch: 2.3.0
+  axolotl_extras:
 runs-on: axolotl-gpu-runner
 steps:
 - name: Checkout

@@ -86,6 +91,11 @@ jobs:
 python_version: "3.11"
 pytorch: 2.2.1
 axolotl_extras:
+- cuda: 121
+  cuda_version: 12.1.0
+  python_version: "3.11"
+  pytorch: 2.3.0
+  axolotl_extras:
 runs-on: axolotl-gpu-runner
 steps:
 - name: Checkout
.gitignore (vendored, 1 changed line)

@@ -133,6 +133,7 @@ venv/
 ENV/
 env.bak/
 venv.bak/
+venv3.10/
 
 # Spyder project settings
 .spyderproject
@@ -138,7 +138,7 @@ test_datasets:
 data_files:
 - /workspace/data/eval.jsonl
 
-# use RL training: 'dpo', 'ipo', 'kto_pair'
+# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard'
 rl:
 
 # Saves the desired chat template to the tokenizer_config.json for easier inferencing
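As a rough illustration of the newly documented `sppo_hard` value (not part of this diff): it is routed through the DPO trainer in this branch, so a config selects it much like DPO, and the new `dpo_beta` option applies.

```yaml
rl: sppo_hard
dpo_beta: 0.1   # sppo_hard reuses the DPO trainer; beta falls back to 0.1 when unset
```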
@@ -227,6 +227,12 @@ lora_modules_to_save:
 
 lora_fan_in_fan_out: false
 
+# LoRA+ hyperparameters
+# For more details about the following options, see:
+# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
+loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.
+loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6.
+
 peft:
 # Configuration options for loftq initialization for LoRA
 # https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
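A short sketch of how the LoRA+ options above might be combined with a LoRA adapter, using the values the comments recommend (illustrative, not taken from this diff):

```yaml
adapter: lora
lora_r: 32
lora_alpha: 16

# LoRA+ hyperparameters
loraplus_lr_ratio: 16        # lr_B / lr_A; the comment above recommends 2^4
loraplus_lr_embedding: 1e-6  # learning rate for the LoRA embedding layers (documented default)
```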
@@ -268,6 +274,7 @@ torch_compile_backend: # Optional[str]
 # If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
 gradient_accumulation_steps: 1
 # The number of samples to include in each batch. This is the number of samples sent to each GPU.
+# Batch size per gpu = micro_batch_size * gradient_accumulation_steps
 micro_batch_size: 2
 eval_batch_size:
 num_epochs: 4
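To make the new comment concrete, a worked example with illustrative values: each GPU below accumulates gradients over 4 micro-batches of 2 samples, i.e. an effective batch of 8 samples per optimizer step.

```yaml
gradient_accumulation_steps: 4
micro_batch_size: 2
# effective batch size per GPU = 2 * 4 = 8
```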
@@ -49,7 +49,7 @@ remove_unused_columns: false
 chat_template: chatml
 datasets:
 - path: argilla/ultrafeedback-binarized-preferences-cleaned
-type: orpo.chat_template
+type: chat_template.argilla
 ```
 
 #### Using local dataset files
examples/mistral/mistral-qlora-orpo.yml (new file, 82 lines)

@@ -0,0 +1,82 @@
+base_model: mistralai/Mistral-7B-v0.1
+model_type: MistralForCausalLM
+tokenizer_type: LlamaTokenizer
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+rl: orpo
+orpo_alpha: 0.1
+remove_unused_columns: false
+
+chat_template: chatml
+datasets:
+  - path: argilla/ultrafeedback-binarized-preferences-cleaned
+    type: chat_template.argilla
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+output_dir: ./mistral-qlora-orpo-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+lora_target_modules:
+  - gate_proj
+  - down_proj
+  - up_proj
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
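Assuming the usual axolotl entry point, this example would presumably be launched with something like `accelerate launch -m axolotl.cli.train examples/mistral/mistral-qlora-orpo.yml`.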
@@ -39,6 +39,6 @@ s3fs
 gcsfs
 # adlfs
 
-trl==0.8.5
+trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
 zstandard==0.22.0
 fastcore
@@ -33,7 +33,7 @@ fi
 
 if [ "$JUPYTER_DISABLE" != "1" ]; then
 # Run Jupyter Lab in the background
-jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
+jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
 fi
 
 # Execute the passed arguments (CMD)
@@ -264,8 +264,8 @@ def do_inference_gradio(
 with torch.no_grad():
 generation_config = GenerationConfig(
 repetition_penalty=1.1,
-max_new_tokens=1024,
+max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
-temperature=0.9,
+temperature=cfg.get("gradio_temperature", 0.9),
 top_p=0.95,
 top_k=40,
 bos_token_id=tokenizer.bos_token_id,

@@ -300,7 +300,13 @@ def do_inference_gradio(
 outputs="text",
 title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
 )
-demo.queue().launch(show_api=False, share=True)
+demo.queue().launch(
+show_api=False,
+share=cfg.get("gradio_share", True),
+server_name=cfg.get("gradio_server_name", "127.0.0.1"),
+server_port=cfg.get("gradio_server_port", None),
+)
 
 
 def choose_config(path: Path):
@@ -433,6 +439,23 @@ def load_rl_datasets(
 math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
 )
 
+if cli_args.debug or cfg.debug:
+LOG.info("check_dataset_labels...")
+
+tokenizer = load_tokenizer(cfg)
+check_dataset_labels(
+train_dataset.select(
+[
+random.randrange(0, len(train_dataset) - 1) # nosec
+for _ in range(cli_args.debug_num_examples)
+]
+),
+tokenizer,
+num_examples=cli_args.debug_num_examples,
+text_only=cli_args.debug_text_only,
+rl_mode=True,
+)
+
 return TrainDatasetMeta(
 train_dataset=train_dataset,
 eval_dataset=eval_dataset,
@@ -30,7 +30,7 @@ from transformers import (
 )
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer, ORPOConfig, ORPOTrainer
+from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
 from trl.trainer.utils import pad_to_length
 
 from axolotl.loraplus import create_loraplus_optimizer

@@ -43,6 +43,7 @@ from axolotl.utils.callbacks import (
 LossWatchDogCallback,
 SaveAxolotlConfigtoWandBCallback,
 SaveBetterTransformerModelCallback,
+SaveModelOnTrainEndCallback,
 bench_eval_callback_factory,
 causal_lm_bench_eval_callback_factory,
 log_prediction_callback_factory,

@@ -212,6 +213,10 @@ class AxolotlTrainingArguments(TrainingArguments):
 default=None,
 metadata={"help": "path under the model to access the layers"},
 )
+curriculum_sampling: Optional[bool] = field(
+default=None,
+metadata={"help": "whether to use sequential sampling for curriculum learning"},
+)
 
 
 class AxolotlTrainer(Trainer):

@@ -347,6 +352,8 @@ class AxolotlTrainer(Trainer):
 lengths=get_dataset_lengths(self.train_dataset),
 packing_efficiency_estimate=self.args.sample_packing_efficiency,
 )
+if self.args.curriculum_sampling:
+return SequentialSampler(self.train_dataset)
 return super()._get_train_sampler()
 
 def _get_eval_sampler(

@@ -882,6 +889,14 @@ class TrainerBuilderBase(abc.ABC):
 callbacks.append(
 SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
 )
+if self.cfg.use_mlflow and is_mlflow_available():
+from axolotl.utils.callbacks.mlflow_ import (
+SaveAxolotlConfigtoMlflowCallback,
+)
+
+callbacks.append(
+SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
+)
 
 return callbacks
 

@@ -927,18 +942,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 ):
 callbacks.append(SaveBetterTransformerModelCallback())
 
-if self.cfg.use_mlflow and is_mlflow_available():
-from axolotl.utils.callbacks.mlflow_ import (
-SaveAxolotlConfigtoMlflowCallback,
-)
-
-callbacks.append(
-SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
-)
-
 if self.cfg.loss_watchdog_threshold is not None:
 callbacks.append(LossWatchDogCallback(self.cfg))
 
+callbacks.append(SaveModelOnTrainEndCallback())
+
 return callbacks
 
 def get_post_trainer_create_callbacks(self, trainer):

@@ -1193,6 +1201,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 False if self.cfg.ddp else None
 )
 training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
+training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
 report_to = None
 if self.cfg.use_wandb:
 report_to = "wandb"

@@ -1420,6 +1429,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 
 def get_callbacks(self):
 callbacks = super().get_callbacks()
+callbacks.append(SaveModelOnTrainEndCallback())
+
 return callbacks
 
 def get_post_trainer_create_callbacks(self, trainer):

@@ -1455,6 +1466,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 training_args_kwargs["eval_steps"] = self.cfg.eval_steps
 else:
 training_args_kwargs["evaluation_strategy"] = "no"
+
 if self.cfg.bf16 or self.cfg.bfloat16:
 training_args_kwargs["bf16"] = True
 

@@ -1513,6 +1525,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 training_args_cls = TrainingArguments
 if self.cfg.rl == "orpo":
 training_args_cls = ORPOConfig
+training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
+elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
+training_args_cls = DPOConfig
+training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
 
 training_args = training_args_cls(
 per_device_train_batch_size=self.cfg.micro_batch_size,

@@ -1539,6 +1555,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
 elif self.cfg.rl == "kto_pair":
 dpo_trainer_kwargs["loss_type"] = "kto_pair"
+elif self.cfg.rl == "sppo_hard":
+dpo_trainer_kwargs["loss_type"] = "sppo_hard"
 if self.eval_dataset:
 dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
 if self.cfg.adapter and self.peft_config:

@@ -1547,7 +1565,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 dpo_trainer_kwargs[
 "precompute_ref_log_probs"
 ] = self.cfg.precompute_ref_log_probs
-if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
+if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
 trainer_cls = AxolotlDPOTrainer
 dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
 trainer_cls_args = [self.model, self.model_ref]

@@ -1557,6 +1575,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 dpo_trainer_kwargs["max_target_length"] = None
 dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
 dpo_trainer_kwargs["generate_during_eval"] = True
+if self.cfg.rl == "dpo":
+dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
 elif self.cfg.rl == "orpo":
 trainer_cls = AxolotlORPOTrainer
 trainer_cls_args = [self.model]
src/axolotl/prompt_strategies/dpo/mistral.py (new file, 30 lines)

@@ -0,0 +1,30 @@
+"""
+DPO strategies for mistral instruct
+"""
+
+
+def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
+        sample["chosen"] = f"{sample['chosen']}"
+        sample["rejected"] = f"{sample['rejected']}"
+        return sample
+
+    return transform_fn
+
+
+def argilla_chat(
+    cfg,
+    **kwargs,
+): # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    for argilla/dpo-mix-7k conversations
+    """
+
+    def transform_fn(sample):
+        sample["prompt"] = f"[INST] {sample['chosen'][0]['content']} [/INST]"
+        sample["chosen"] = f"{sample['chosen'][1]['content']}</s>"
+        sample["rejected"] = f"{sample['rejected'][1]['content']}</s>"
+        return sample
+
+    return transform_fn
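If these strategies resolve like the other DPO prompt strategies, a dataset entry might reference them roughly as below; the exact `type:` string is an assumption, not something confirmed by this diff.

```yaml
rl: dpo
datasets:
  - path: argilla/dpo-mix-7k
    type: mistral.argilla_chat  # assumed mapping to dpo/mistral.py:argilla_chat
    split: train
```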
@@ -3,6 +3,7 @@
 import os
 import signal
 import sys
+import weakref
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Tuple, Union

@@ -127,14 +128,20 @@ def train(
 # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
 if cfg.local_rank == 0:
 
-def terminate_handler(_, __, model):
-if cfg.flash_optimum and BetterTransformer:
-model = BetterTransformer.reverse(model)
-model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+def terminate_handler(_, __, model_weakref):
+if model_weakref() is not None:
+_model = model_weakref()
+if cfg.flash_optimum and BetterTransformer:
+_model = BetterTransformer.reverse(_model)
+_model.save_pretrained(
+cfg.output_dir, safe_serialization=safe_serialization
+)
 sys.exit(0)
 
+_model_weakref = weakref.ref(model)
 signal.signal(
-signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
+signal.SIGINT,
+lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
 )
 
 badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
@@ -773,3 +773,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
 except (FileNotFoundError, ConnectionError) as err:
 LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
 return control
+
+
+class SaveModelOnTrainEndCallback(TrainerCallback):
+"""Callback to save model on train end"""
+
+def on_train_end( # pylint: disable=unused-argument
+self, args, state, control, **kwargs
+):
+control.should_save = True
+return control
@@ -383,9 +383,9 @@ def legacy_validate_config(cfg):
 "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
 )
 
-if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
+if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
 LOG.warning(
-"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
+"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
 )
 
 if cfg.gptq and cfg.revision_of_model:

@@ -448,10 +448,14 @@ def legacy_validate_config(cfg):
 raise ValueError(
 "save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
 )
-if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
+if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
 raise ValueError(
 "save_strategy must be empty or set to `steps` when used with saves_per_epoch."
 )
+if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
+raise ValueError(
+"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
+)
 if cfg.evals_per_epoch and cfg.eval_steps:
 raise ValueError(
 "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."

@@ -464,11 +468,6 @@ def legacy_validate_config(cfg):
 raise ValueError(
 "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
 )
-if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
-raise ValueError(
-"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
-)
-
 if (
 cfg.evaluation_strategy
 and cfg.eval_steps
@@ -133,6 +133,7 @@ class RLType(str, Enum):
 ipo = "ipo" # pylint: disable=invalid-name
 kto_pair = "kto_pair" # pylint: disable=invalid-name
 orpo = "orpo" # pylint: disable=invalid-name
+sppo_hard = "sppo_hard" # pylint: disable=invalid-name
 
 
 class ChatTemplate(str, Enum):

@@ -409,6 +410,17 @@ class WandbConfig(BaseModel):
 return data
 
 
+class GradioConfig(BaseModel):
+"""Gradio configuration subset"""
+
+gradio_title: Optional[str] = None
+gradio_share: Optional[bool] = None
+gradio_server_name: Optional[str] = None
+gradio_server_port: Optional[int] = None
+gradio_max_new_tokens: Optional[int] = None
+gradio_temperature: Optional[float] = None
+
+
 # pylint: disable=too-many-public-methods,too-many-ancestors
 class AxolotlInputConfig(
 ModelInputConfig,
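A sketch of how the new Gradio fields could appear in a config; the values are illustrative, and the noted defaults come from the inference changes earlier in this diff.

```yaml
gradio_title: "My Axolotl Demo"
gradio_share: false              # launch default above is True
gradio_server_name: "0.0.0.0"    # launch default above is 127.0.0.1
gradio_server_port: 7860
gradio_max_new_tokens: 512       # generation default above is 1024
gradio_temperature: 0.7          # generation default above is 0.9
```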
@@ -419,6 +431,7 @@ class AxolotlInputConfig(
 WandbConfig,
 MLFlowConfig,
 LISAConfig,
+GradioConfig,
 RemappedParameters,
 DeprecatedParameters,
 BaseModel,
@@ -503,9 +516,17 @@ class AxolotlInputConfig(
 unfrozen_parameters: Optional[List[str]] = None
 
 sequence_len: int = Field(default=512)
+min_sample_len: Optional[int] = None
 sample_packing: Optional[bool] = None
 eval_sample_packing: Optional[bool] = None
 pad_to_sequence_len: Optional[bool] = None
+curriculum_sampling: Optional[bool] = None
+
+# for PoSE context length extension
+use_pose: Optional[bool] = None
+pose_split_on_token_ids: Optional[List[int]] = None
+pose_max_context_len: Optional[int] = None
+pose_num_chunks: Optional[int] = None
+
 pretrain_multipack_buffer_size: Optional[int] = 10_000
 pretrain_multipack_attn: Optional[bool] = Field(
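A hypothetical config snippet combining the new sampling and PoSE fields (all values illustrative):

```yaml
sequence_len: 8192
min_sample_len: 32          # drop samples shorter than this
curriculum_sampling: true   # sequential sampling for curriculum learning

# PoSE context length extension
use_pose: true
pose_max_context_len: 65536
pose_num_chunks: 2
pose_split_on_token_ids:
  - 2                       # e.g. a turn-boundary/EOS token id; illustrative only
```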
@@ -554,6 +575,7 @@ class AxolotlInputConfig(
 neftune_noise_alpha: Optional[float] = None
 
 orpo_alpha: Optional[float] = None
+dpo_beta: Optional[float] = None
 
 max_memory: Optional[
 Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
@@ -772,11 +794,11 @@ class AxolotlInputConfig(
 @model_validator(mode="before")
 @classmethod
 def check_push_save(cls, data):
-if data.get("hub_model_id") and not (
-data.get("save_steps") or data.get("saves_per_epoch")
+if data.get("hub_model_id") and (
+data.get("save_strategy") not in ["steps", "epoch", None]
 ):
 LOG.warning(
-"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
+"hub_model_id is set without any models being saved. To save a model, set save_strategy."
 )
 return data
 
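Under the updated check, `hub_model_id` only triggers the warning when `save_strategy` is explicitly set to something outside steps, epoch, or unset, for example:

```yaml
hub_model_id: my-org/my-model   # hypothetical repo id
save_strategy: epoch            # "steps", "epoch", or leaving it unset avoids the warning; "no" warns
```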
@@ -789,7 +789,11 @@ def load_model(
 if not reference_model or cfg.lora_model_dir:
 # if we're not loading the reference model, then we're loading the model for training
 # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
-if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
+if (
+cfg.adapter
+and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]
+and not cfg.merge_lora
+):
 _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
 else:
 model, lora_config = load_adapter(model, cfg, cfg.adapter)
@@ -1,6 +1,5 @@
 """Module for tokenization utilities"""
-
 
 import logging
 import re
 from typing import Dict, List

@@ -10,10 +9,19 @@ from termcolor import colored
 LOG = logging.getLogger("axolotl")
 
 
-def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
+def check_dataset_labels(
+dataset,
+tokenizer,
+num_examples=5,
+text_only=False,
+rl_mode=False,
+):
 # the dataset is already shuffled, so let's just check the first 5 elements
 for idx in range(num_examples):
-check_example_labels(dataset[idx], tokenizer, text_only=text_only)
+if not rl_mode:
+check_example_labels(dataset[idx], tokenizer, text_only=text_only)
+else:
+check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)
 
 
 def check_example_labels(example, tokenizer, text_only=False):

@@ -40,6 +48,53 @@ def check_example_labels(example, tokenizer, text_only=False):
 return " ".join(colored_tokens)
 
 
+def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
+"""Helper function to color tokens based on their type."""
+colored_text = colored(decoded_token, color)
+return (
+colored_text
+if text_only
+else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
+)
+
+
+def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
+"""Helper function to process and color tokens."""
+colored_tokens = [
+color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
+for token in tokenizer.encode(tokens)
+]
+return colored_tokens
+
+
+def check_rl_example_labels(example, tokenizer, text_only=False):
+field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
+
+input_tokens = example[field_prompt]
+labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
+
+# Process and color each type of token
+colored_tokens = process_tokens_for_rl_debug(
+input_tokens, "yellow", tokenizer, text_only
+)
+colored_chosens = process_tokens_for_rl_debug(
+labels_chosen, "green", tokenizer, text_only
+)
+colored_rejecteds = process_tokens_for_rl_debug(
+labels_rejected, "red", tokenizer, text_only
+)
+
+# Create a delimiter based on text_only flag
+delimiter = "" if text_only else " "
+
+# Logging information
+LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
+LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
+LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
+
+return delimiter.join(colored_tokens)
+
+
 GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
 GLAIVE_TO_SHAREGPT_ROLE = {
 "SYSTEM": "system",
@@ -1,9 +1,10 @@
 """Module containing the Trainer class and related functions"""
 import math
 import os
+import random
 from contextlib import contextmanager
 from functools import partial
-from typing import List
+from typing import List, Optional
 
 import numpy as np
 import torch

@@ -98,17 +99,89 @@ def add_position_ids(sample):
 return sample
 
 
+def add_pose_position_ids(
+sample,
+max_context_len=32768,
+split_on_token_ids: Optional[List[int]] = None,
+chunks: int = 2,
+):
+"""
+use the PoSE technique to extend the context length by randomly skipping
+positions in the context. We only want to skip right before tokens in
+the split_on_token_ids list. We should attempt to randomly distribute
+the skips, but we don't need the final position_ids to be the full
+context_len. There may be multiple turns in the context, so we want to
+make sure we take into account the maximum possible number of skips
+remaining in each sample.
+"""
+
+input_ids = sample["input_ids"]
+sample_len = len(input_ids)
+max_skips = max_context_len - sample_len
+
+if split_on_token_ids is None:
+split_on_token_ids = []
+
+if split_on_token_ids:
+split_indices = [
+i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
+]
+else:
+chunk_len = sample_len // chunks
+split_indices = [i * chunk_len for i in range(1, chunks)]
+split_indices.append(len(input_ids)) # make sure we go to the end of the sample
+if split_indices[0] < 2:
+# drop the first split index if it's too close to the beginning
+split_indices = split_indices[1:]
+
+position_ids = []
+prev_index = 0
+total_skips = 0
+
+for split_index in split_indices:
+num_skips = (
+random.randint(0, max_skips) # nosec B311
+if prev_index != 0 and max_skips
+else 0
+)
+max_skips -= num_skips
+total_skips += num_skips
+
+segment_position_ids = list(
+range(prev_index + total_skips, split_index + total_skips)
+)
+
+position_ids.extend(segment_position_ids)
+prev_index = split_index
+
+sample["sequence_len"] = position_ids[-1]
+position_ids = torch.tensor(position_ids)
+
+sample["position_ids"] = position_ids
+sample["length"] = len(position_ids)
+assert len(position_ids) == len(input_ids)
+
+return sample
+
+
 def add_length(sample):
 sample["length"] = len(sample["input_ids"])
 return sample
 
 
-def drop_long_seq(sample, sequence_len=2048):
-return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
+def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
+return (
+len(sample["input_ids"]) <= sequence_len
+and len(sample["input_ids"]) >= min_sequence_len
+)
 
 
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
-drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
+drop_long = partial(
+drop_long_seq,
+sequence_len=cfg.sequence_len,
+min_sequence_len=cfg.min_sample_len or 2,
+)
 with zero_first(is_main_process()):
 if cfg.is_preprocess:
 min_input_len = np.min(get_dataset_lengths(train_dataset))

@@ -153,7 +226,32 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
 desc="Group By Length",
 )
 
-if cfg.sample_packing:
+if cfg.use_pose:
+pose_kwargs = {}
+if cfg.pose_num_chunks is not None:
+pose_kwargs["chunks"] = cfg.pose_num_chunks
+pose_fn = partial(
+add_pose_position_ids,
+max_context_len=cfg.pose_max_context_len,
+split_on_token_ids=cfg.pose_split_on_token_ids,
+**pose_kwargs,
+)
+train_dataset = train_dataset.map(
+pose_fn,
+num_proc=cfg.dataset_processes,
+load_from_cache_file=not cfg.is_preprocess,
+desc="Add position_id column (PoSE)",
+)
+train_dataset = train_dataset.sort("sequence_len")
+if cfg.eval_sample_packing is not False:
+if eval_dataset:
+eval_dataset = eval_dataset.map(
+pose_fn,
+num_proc=cfg.dataset_processes,
+load_from_cache_file=not cfg.is_preprocess,
+desc="Add position_id column (PoSE)",
+)
+elif cfg.sample_packing:
 train_dataset = train_dataset.map(
 add_position_ids,
 num_proc=cfg.dataset_processes,

@@ -340,7 +438,7 @@ def prepare_optim_env(cfg):
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
+if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard"]:
 trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
 trainer_builder.model_ref = model[1]
 trainer_builder.peft_config = model[2]
@@ -158,3 +158,50 @@ class TestDPOLlamaLora(unittest.TestCase):
 
 train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
 assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+
+@with_temp_dir
+def test_orpo_lora(self, temp_dir):
+# pylint: disable=duplicate-code
+cfg = DictDefault(
+{
+"base_model": "JackFram/llama-68m",
+"tokenizer_type": "LlamaTokenizer",
+"sequence_len": 1024,
+"load_in_8bit": True,
+"adapter": "lora",
+"lora_r": 64,
+"lora_alpha": 32,
+"lora_dropout": 0.1,
+"lora_target_linear": True,
+"special_tokens": {},
+"rl": "orpo",
+"orpo_alpha": 0.1,
+"remove_unused_columns": False,
+"chat_template": "chatml",
+"datasets": [
+{
+"path": "argilla/ultrafeedback-binarized-preferences-cleaned",
+"type": "chat_template.argilla",
+"split": "train",
+},
+],
+"num_epochs": 1,
+"micro_batch_size": 4,
+"gradient_accumulation_steps": 1,
+"output_dir": temp_dir,
+"learning_rate": 0.00001,
+"optimizer": "paged_adamw_8bit",
+"lr_scheduler": "cosine",
+"max_steps": 20,
+"save_steps": 10,
+"warmup_steps": 5,
+"gradient_checkpointing": True,
+"gradient_checkpointing_kwargs": {"use_reentrant": True},
+}
+)
+normalize_config(cfg)
+cli_args = TrainerCliArgs()
+dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+
+train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@@ -1067,17 +1067,51 @@ class TestValidation(BaseValidation):
 ):
 validate_config(cfg)
 
-def test_hub_model_id_save_value_warns(self, minimal_cfg):
-cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
+def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
+cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg
 
 with self._caplog.at_level(logging.WARNING):
 validate_config(cfg)
-assert (
-"set without any models being saved" in self._caplog.records[0].message
-)
+assert len(self._caplog.records) == 1
 
-def test_hub_model_id_save_value(self, minimal_cfg):
-cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg
+def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg):
+cfg = (
+DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
+)
+
+with self._caplog.at_level(logging.WARNING):
+validate_config(cfg)
+assert len(self._caplog.records) == 1
+
+def test_hub_model_id_save_value_steps(self, minimal_cfg):
+cfg = (
+DictDefault({"hub_model_id": "test", "save_strategy": "steps"})
+| minimal_cfg
+)
+
+with self._caplog.at_level(logging.WARNING):
+validate_config(cfg)
+assert len(self._caplog.records) == 0
+
+def test_hub_model_id_save_value_epochs(self, minimal_cfg):
+cfg = (
+DictDefault({"hub_model_id": "test", "save_strategy": "epoch"})
+| minimal_cfg
+)
+
+with self._caplog.at_level(logging.WARNING):
+validate_config(cfg)
+assert len(self._caplog.records) == 0
+
+def test_hub_model_id_save_value_none(self, minimal_cfg):
+cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
+
+with self._caplog.at_level(logging.WARNING):
+validate_config(cfg)
+assert len(self._caplog.records) == 0
+
+def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
+cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
+
 with self._caplog.at_level(logging.WARNING):
 validate_config(cfg)