Compare commits
1 Commits
sppo
...
fix-l3-lor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ce9b0760b |
5
.github/workflows/base.yml
vendored
5
.github/workflows/base.yml
vendored
@@ -32,11 +32,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.1
|
pytorch: 2.2.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
- cuda: "121"
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.3.0
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|||||||
10
.github/workflows/main.yml
vendored
10
.github/workflows/main.yml
vendored
@@ -30,11 +30,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.1
|
pytorch: 2.2.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.3.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -91,11 +86,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.1
|
pytorch: 2.2.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.3.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
10
.github/workflows/nightlies.yml
vendored
10
.github/workflows/nightlies.yml
vendored
@@ -29,11 +29,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.1
|
pytorch: 2.2.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.3.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -91,11 +86,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.1
|
pytorch: 2.2.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.3.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -133,7 +133,6 @@ venv/
|
|||||||
ENV/
|
ENV/
|
||||||
env.bak/
|
env.bak/
|
||||||
venv.bak/
|
venv.bak/
|
||||||
venv3.10/
|
|
||||||
|
|
||||||
# Spyder project settings
|
# Spyder project settings
|
||||||
.spyderproject
|
.spyderproject
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ test_datasets:
|
|||||||
data_files:
|
data_files:
|
||||||
- /workspace/data/eval.jsonl
|
- /workspace/data/eval.jsonl
|
||||||
|
|
||||||
# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard'
|
# use RL training: 'dpo', 'ipo', 'kto_pair'
|
||||||
rl:
|
rl:
|
||||||
|
|
||||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||||
@@ -227,12 +227,6 @@ lora_modules_to_save:
|
|||||||
|
|
||||||
lora_fan_in_fan_out: false
|
lora_fan_in_fan_out: false
|
||||||
|
|
||||||
# LoRA+ hyperparameters
|
|
||||||
# For more details about the following options, see:
|
|
||||||
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
|
|
||||||
loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.
|
|
||||||
loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6.
|
|
||||||
|
|
||||||
peft:
|
peft:
|
||||||
# Configuration options for loftq initialization for LoRA
|
# Configuration options for loftq initialization for LoRA
|
||||||
# https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
|
# https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
|
||||||
@@ -274,7 +268,6 @@ torch_compile_backend: # Optional[str]
|
|||||||
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||||
# Batch size per gpu = micro_batch_size * gradient_accumulation_steps
|
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
eval_batch_size:
|
eval_batch_size:
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ remove_unused_columns: false
|
|||||||
chat_template: chatml
|
chat_template: chatml
|
||||||
datasets:
|
datasets:
|
||||||
- path: argilla/ultrafeedback-binarized-preferences-cleaned
|
- path: argilla/ultrafeedback-binarized-preferences-cleaned
|
||||||
type: chat_template.argilla
|
type: orpo.chat_template
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Using local dataset files
|
#### Using local dataset files
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: meta-llama/Meta-Llama-3-8B
|
base_model: NousResearch/Llama-2-7b-hf
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
@@ -64,4 +64,4 @@ weight_decay: 0.0
|
|||||||
fsdp:
|
fsdp:
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|||||||
@@ -1,82 +0,0 @@
|
|||||||
base_model: mistralai/Mistral-7B-v0.1
|
|
||||||
model_type: MistralForCausalLM
|
|
||||||
tokenizer_type: LlamaTokenizer
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
rl: orpo
|
|
||||||
orpo_alpha: 0.1
|
|
||||||
remove_unused_columns: false
|
|
||||||
|
|
||||||
chat_template: chatml
|
|
||||||
datasets:
|
|
||||||
- path: argilla/ultrafeedback-binarized-preferences-cleaned
|
|
||||||
type: chat_template.argilla
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./mistral-qlora-orpo-out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
lora_target_modules:
|
|
||||||
- gate_proj
|
|
||||||
- down_proj
|
|
||||||
- up_proj
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
loss_watchdog_threshold: 5.0
|
|
||||||
loss_watchdog_patience: 3
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
evals_per_epoch: 4
|
|
||||||
eval_table_size:
|
|
||||||
eval_max_new_tokens: 128
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
@@ -11,7 +11,7 @@ addict
|
|||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
requests
|
requests
|
||||||
datasets==2.15.0
|
datasets>=2.15.0
|
||||||
flash-attn==2.5.5
|
flash-attn==2.5.5
|
||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
@@ -28,7 +28,7 @@ scipy
|
|||||||
scikit-learn==1.2.2
|
scikit-learn==1.2.2
|
||||||
pynvml
|
pynvml
|
||||||
art
|
art
|
||||||
fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8
|
fschat==0.2.36
|
||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
tensorboard
|
tensorboard
|
||||||
|
|
||||||
@@ -39,6 +39,6 @@ s3fs
|
|||||||
gcsfs
|
gcsfs
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
|
trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ fi
|
|||||||
|
|
||||||
if [ "$JUPYTER_DISABLE" != "1" ]; then
|
if [ "$JUPYTER_DISABLE" != "1" ]; then
|
||||||
# Run Jupyter Lab in the background
|
# Run Jupyter Lab in the background
|
||||||
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
|
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Execute the passed arguments (CMD)
|
# Execute the passed arguments (CMD)
|
||||||
|
|||||||
@@ -264,8 +264,8 @@ def do_inference_gradio(
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
generation_config = GenerationConfig(
|
generation_config = GenerationConfig(
|
||||||
repetition_penalty=1.1,
|
repetition_penalty=1.1,
|
||||||
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
|
max_new_tokens=1024,
|
||||||
temperature=cfg.get("gradio_temperature", 0.9),
|
temperature=0.9,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=40,
|
top_k=40,
|
||||||
bos_token_id=tokenizer.bos_token_id,
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
@@ -300,13 +300,7 @@ def do_inference_gradio(
|
|||||||
outputs="text",
|
outputs="text",
|
||||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||||
)
|
)
|
||||||
|
demo.queue().launch(show_api=False, share=True)
|
||||||
demo.queue().launch(
|
|
||||||
show_api=False,
|
|
||||||
share=cfg.get("gradio_share", True),
|
|
||||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
|
||||||
server_port=cfg.get("gradio_server_port", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def choose_config(path: Path):
|
def choose_config(path: Path):
|
||||||
@@ -439,23 +433,6 @@ def load_rl_datasets(
|
|||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
if cli_args.debug or cfg.debug:
|
|
||||||
LOG.info("check_dataset_labels...")
|
|
||||||
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
|
||||||
check_dataset_labels(
|
|
||||||
train_dataset.select(
|
|
||||||
[
|
|
||||||
random.randrange(0, len(train_dataset) - 1) # nosec
|
|
||||||
for _ in range(cli_args.debug_num_examples)
|
|
||||||
]
|
|
||||||
),
|
|
||||||
tokenizer,
|
|
||||||
num_examples=cli_args.debug_num_examples,
|
|
||||||
text_only=cli_args.debug_text_only,
|
|
||||||
rl_mode=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return TrainDatasetMeta(
|
return TrainDatasetMeta(
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
LOG.warning(msg)
|
LOG.warning(msg)
|
||||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
|
||||||
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
|
if parsed_cfg.rl and parsed_cfg.rl != "orpo":
|
||||||
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
else:
|
else:
|
||||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|||||||
else:
|
else:
|
||||||
register_chatml_template()
|
register_chatml_template()
|
||||||
|
|
||||||
if cfg.rl: # and cfg.rl != "orpo":
|
if cfg.rl and cfg.rl != "orpo":
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
|
from trl import DPOTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.loraplus import create_loraplus_optimizer
|
from axolotl.loraplus import create_loraplus_optimizer
|
||||||
@@ -43,7 +43,6 @@ from axolotl.utils.callbacks import (
|
|||||||
LossWatchDogCallback,
|
LossWatchDogCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SaveModelOnTrainEndCallback,
|
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
causal_lm_bench_eval_callback_factory,
|
causal_lm_bench_eval_callback_factory,
|
||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
@@ -55,7 +54,6 @@ from axolotl.utils.collators import (
|
|||||||
MambaDataCollator,
|
MambaDataCollator,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from axolotl.utils.models import ensure_dtype
|
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.utils.schedulers import (
|
||||||
get_cosine_schedule_with_min_lr,
|
get_cosine_schedule_with_min_lr,
|
||||||
@@ -213,10 +211,6 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "path under the model to access the layers"},
|
metadata={"help": "path under the model to access the layers"},
|
||||||
)
|
)
|
||||||
curriculum_sampling: Optional[bool] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(Trainer):
|
class AxolotlTrainer(Trainer):
|
||||||
@@ -352,8 +346,6 @@ class AxolotlTrainer(Trainer):
|
|||||||
lengths=get_dataset_lengths(self.train_dataset),
|
lengths=get_dataset_lengths(self.train_dataset),
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
)
|
)
|
||||||
if self.args.curriculum_sampling:
|
|
||||||
return SequentialSampler(self.train_dataset)
|
|
||||||
return super()._get_train_sampler()
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
def _get_eval_sampler(
|
def _get_eval_sampler(
|
||||||
@@ -818,14 +810,6 @@ class AxolotlDPOTrainer(DPOTrainer):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(ORPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""
|
||||||
Base class for trainer builder
|
Base class for trainer builder
|
||||||
@@ -889,14 +873,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
if self.cfg.use_mlflow and is_mlflow_available():
|
|
||||||
from axolotl.utils.callbacks.mlflow_ import (
|
|
||||||
SaveAxolotlConfigtoMlflowCallback,
|
|
||||||
)
|
|
||||||
|
|
||||||
callbacks.append(
|
|
||||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
|
||||||
)
|
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
@@ -942,11 +918,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
):
|
):
|
||||||
callbacks.append(SaveBetterTransformerModelCallback())
|
callbacks.append(SaveBetterTransformerModelCallback())
|
||||||
|
|
||||||
|
if self.cfg.use_mlflow and is_mlflow_available():
|
||||||
|
from axolotl.utils.callbacks.mlflow_ import (
|
||||||
|
SaveAxolotlConfigtoMlflowCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
callbacks.append(
|
||||||
|
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.loss_watchdog_threshold is not None:
|
if self.cfg.loss_watchdog_threshold is not None:
|
||||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||||
|
|
||||||
callbacks.append(SaveModelOnTrainEndCallback())
|
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
@@ -1201,7 +1184,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
False if self.cfg.ddp else None
|
False if self.cfg.ddp else None
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
|
||||||
report_to = None
|
report_to = None
|
||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
report_to = "wandb"
|
report_to = "wandb"
|
||||||
@@ -1422,15 +1404,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class HFRLTrainerBuilder(TrainerBuilderBase):
|
class HFDPOTrainerBuilder(TrainerBuilderBase):
|
||||||
"""
|
"""
|
||||||
Trainer factory class for DPO Trainer
|
Trainer factory class for DPO Trainer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
callbacks = super().get_callbacks()
|
callbacks = super().get_callbacks()
|
||||||
callbacks.append(SaveModelOnTrainEndCallback())
|
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
@@ -1466,7 +1446,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["evaluation_strategy"] = "no"
|
training_args_kwargs["evaluation_strategy"] = "no"
|
||||||
|
|
||||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||||
training_args_kwargs["bf16"] = True
|
training_args_kwargs["bf16"] = True
|
||||||
|
|
||||||
@@ -1518,19 +1497,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
# default to saving each epoch if not defined
|
# default to saving each epoch if not defined
|
||||||
training_args_kwargs["save_strategy"] = "epoch"
|
training_args_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
if self.cfg.orpo_alpha:
|
training_args = TrainingArguments(
|
||||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
|
||||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
|
||||||
|
|
||||||
training_args_cls = TrainingArguments
|
|
||||||
if self.cfg.rl == "orpo":
|
|
||||||
training_args_cls = ORPOConfig
|
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
|
||||||
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
|
|
||||||
training_args_cls = DPOConfig
|
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
|
||||||
|
|
||||||
training_args = training_args_cls(
|
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
max_steps=self.cfg.max_steps or total_num_steps,
|
max_steps=self.cfg.max_steps or total_num_steps,
|
||||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||||
@@ -1555,8 +1522,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
elif self.cfg.rl == "kto_pair":
|
elif self.cfg.rl == "kto_pair":
|
||||||
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
||||||
elif self.cfg.rl == "sppo_hard":
|
|
||||||
dpo_trainer_kwargs["loss_type"] = "sppo_hard"
|
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
@@ -1565,34 +1530,20 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs[
|
dpo_trainer_kwargs[
|
||||||
"precompute_ref_log_probs"
|
"precompute_ref_log_probs"
|
||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
|
dpo_trainer = AxolotlDPOTrainer(
|
||||||
trainer_cls = AxolotlDPOTrainer
|
self.model,
|
||||||
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
|
self.model_ref,
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
|
||||||
|
|
||||||
# these aren't used for the ORPO trainer
|
|
||||||
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
|
||||||
dpo_trainer_kwargs["max_target_length"] = None
|
|
||||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
|
||||||
dpo_trainer_kwargs["generate_during_eval"] = True
|
|
||||||
if self.cfg.rl == "dpo":
|
|
||||||
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
|
||||||
elif self.cfg.rl == "orpo":
|
|
||||||
trainer_cls = AxolotlORPOTrainer
|
|
||||||
trainer_cls_args = [self.model]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
|
||||||
dpo_trainer = trainer_cls(
|
|
||||||
*trainer_cls_args,
|
|
||||||
args=training_args,
|
args=training_args,
|
||||||
|
beta=self.cfg.dpo_beta or 0.1,
|
||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
|
max_length=self.cfg.sequence_len,
|
||||||
|
max_target_length=None,
|
||||||
|
max_prompt_length=self.cfg.sequence_len,
|
||||||
|
generate_during_eval=True,
|
||||||
callbacks=self.get_callbacks(),
|
callbacks=self.get_callbacks(),
|
||||||
**dpo_trainer_kwargs,
|
**dpo_trainer_kwargs,
|
||||||
)
|
)
|
||||||
if self.cfg.fsdp:
|
|
||||||
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
|
|
||||||
|
|
||||||
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
|
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
|
||||||
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
|
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
|
||||||
dpo_trainer.add_callback(callback)
|
dpo_trainer.add_callback(callback)
|
||||||
|
|||||||
@@ -123,14 +123,6 @@ def get_turns( # pylint: disable=too-many-return-statements
|
|||||||
else:
|
else:
|
||||||
yield role, ""
|
yield role, ""
|
||||||
return
|
return
|
||||||
if self.sep_style == SeparatorStyle.GEMMA:
|
|
||||||
if self.system_message:
|
|
||||||
raise ValueError("Gemma chat template does not support system messages")
|
|
||||||
for i, (role, message) in enumerate(self.messages):
|
|
||||||
prefix = "<bos>" if i == 0 else ""
|
|
||||||
message_str = message if message else ""
|
|
||||||
yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
|
|
||||||
return
|
|
||||||
if self.sep_style == SeparatorStyle.CHATGLM:
|
if self.sep_style == SeparatorStyle.CHATGLM:
|
||||||
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||||
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||||
|
|||||||
@@ -1,30 +0,0 @@
|
|||||||
"""
|
|
||||||
DPO strategies for mistral instruct
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
|
||||||
def transform_fn(sample):
|
|
||||||
sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
|
|
||||||
sample["chosen"] = f"{sample['chosen']}"
|
|
||||||
sample["rejected"] = f"{sample['rejected']}"
|
|
||||||
return sample
|
|
||||||
|
|
||||||
return transform_fn
|
|
||||||
|
|
||||||
|
|
||||||
def argilla_chat(
|
|
||||||
cfg,
|
|
||||||
**kwargs,
|
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
|
||||||
"""
|
|
||||||
for argilla/dpo-mix-7k conversations
|
|
||||||
"""
|
|
||||||
|
|
||||||
def transform_fn(sample):
|
|
||||||
sample["prompt"] = f"[INST] {sample['chosen'][0]['content']} [/INST]"
|
|
||||||
sample["chosen"] = f"{sample['chosen'][1]['content']}</s>"
|
|
||||||
sample["rejected"] = f"{sample['rejected'][1]['content']}</s>"
|
|
||||||
return sample
|
|
||||||
|
|
||||||
return transform_fn
|
|
||||||
@@ -6,4 +6,4 @@ from functools import partial
|
|||||||
|
|
||||||
from ..base import load as load_base
|
from ..base import load as load_base
|
||||||
|
|
||||||
load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
|
load = partial(load_base, module="axolotl.prompt_strategies.orpo")
|
||||||
|
|||||||
@@ -78,57 +78,6 @@ class ORPODatasetParsingStrategy:
|
|||||||
)
|
)
|
||||||
return MessageList(messages=messages)
|
return MessageList(messages=messages)
|
||||||
|
|
||||||
def get_prompt(self, prompt) -> MessageList:
|
|
||||||
"""Map the data to extract everything up to the last turn"""
|
|
||||||
total_msg_len = len(prompt["chosen"])
|
|
||||||
total_msg_turns, remainder = divmod(total_msg_len, 2)
|
|
||||||
assert remainder == 0, "invalid number of turns"
|
|
||||||
|
|
||||||
messages: List[Message] = []
|
|
||||||
if system := prompt.get("system", None):
|
|
||||||
messages.append(Message(role="system", content=system, label=False))
|
|
||||||
for i in range(total_msg_turns):
|
|
||||||
if "prompt" in prompt:
|
|
||||||
messages.append(
|
|
||||||
Message(role="user", content=prompt["prompt"], label=False)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
messages.append(
|
|
||||||
Message(
|
|
||||||
role="user",
|
|
||||||
content=prompt["chosen"][i * 2]["content"],
|
|
||||||
label=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if i < total_msg_turns - 1:
|
|
||||||
messages.append(
|
|
||||||
Message(
|
|
||||||
role="assistant",
|
|
||||||
content=prompt["chosen"][i * 2 + 1]["content"],
|
|
||||||
label=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return MessageList(messages=messages)
|
|
||||||
|
|
||||||
def get_chosen(self, prompt) -> MessageList:
|
|
||||||
res = self.get_prompt(prompt)
|
|
||||||
res.messages.append(
|
|
||||||
Message(
|
|
||||||
role="assistant", content=prompt["chosen"][-1]["content"], label=True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return res
|
|
||||||
|
|
||||||
def get_rejected(self, prompt) -> MessageList:
|
|
||||||
res = self.get_prompt(prompt)
|
|
||||||
res.messages.append(
|
|
||||||
Message(
|
|
||||||
role="assistant", content=prompt["rejected"][-1]["content"], label=True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
class ORPOTokenizingStrategy(PromptTokenizingStrategy):
|
class ORPOTokenizingStrategy(PromptTokenizingStrategy):
|
||||||
"""
|
"""
|
||||||
@@ -237,36 +186,3 @@ class ORPOPrompter(Prompter):
|
|||||||
chat_template=self.chat_template,
|
chat_template=self.chat_template,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
), True
|
), True
|
||||||
|
|
||||||
|
|
||||||
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
|
||||||
dataset_parser = ORPODatasetParsingStrategy()
|
|
||||||
|
|
||||||
chat_template_str = chat_templates(cfg.chat_template)
|
|
||||||
|
|
||||||
def transform_fn(sample, tokenizer=None):
|
|
||||||
res = {}
|
|
||||||
|
|
||||||
res["prompt"] = tokenizer.apply_chat_template(
|
|
||||||
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
|
|
||||||
add_generation_prompt=True,
|
|
||||||
chat_template=chat_template_str,
|
|
||||||
tokenize=False,
|
|
||||||
)
|
|
||||||
prompt_str_len = len(res["prompt"])
|
|
||||||
res["chosen"] = tokenizer.apply_chat_template(
|
|
||||||
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
|
|
||||||
add_generation_prompt=False,
|
|
||||||
chat_template=chat_template_str,
|
|
||||||
tokenize=False,
|
|
||||||
)[prompt_str_len:]
|
|
||||||
res["rejected"] = tokenizer.apply_chat_template(
|
|
||||||
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
|
|
||||||
add_generation_prompt=False,
|
|
||||||
chat_template=chat_template_str,
|
|
||||||
tokenize=False,
|
|
||||||
)[prompt_str_len:]
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
return transform_fn
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import weakref
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
@@ -128,20 +127,14 @@ def train(
|
|||||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
|
|
||||||
def terminate_handler(_, __, model_weakref):
|
def terminate_handler(_, __, model):
|
||||||
if model_weakref() is not None:
|
if cfg.flash_optimum and BetterTransformer:
|
||||||
_model = model_weakref()
|
model = BetterTransformer.reverse(model)
|
||||||
if cfg.flash_optimum and BetterTransformer:
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
_model = BetterTransformer.reverse(_model)
|
|
||||||
_model.save_pretrained(
|
|
||||||
cfg.output_dir, safe_serialization=safe_serialization
|
|
||||||
)
|
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
_model_weakref = weakref.ref(model)
|
|
||||||
signal.signal(
|
signal.signal(
|
||||||
signal.SIGINT,
|
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
||||||
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
||||||
|
|||||||
@@ -773,13 +773,3 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
except (FileNotFoundError, ConnectionError) as err:
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
class SaveModelOnTrainEndCallback(TrainerCallback):
|
|
||||||
"""Callback to save model on train end"""
|
|
||||||
|
|
||||||
def on_train_end( # pylint: disable=unused-argument
|
|
||||||
self, args, state, control, **kwargs
|
|
||||||
):
|
|
||||||
control.should_save = True
|
|
||||||
return control
|
|
||||||
|
|||||||
@@ -383,9 +383,9 @@ def legacy_validate_config(cfg):
|
|||||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
|
if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
|
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.gptq and cfg.revision_of_model:
|
if cfg.gptq and cfg.revision_of_model:
|
||||||
@@ -448,14 +448,10 @@ def legacy_validate_config(cfg):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||||
)
|
)
|
||||||
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
|
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
||||||
)
|
)
|
||||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
|
||||||
raise ValueError(
|
|
||||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
|
||||||
)
|
|
||||||
if cfg.evals_per_epoch and cfg.eval_steps:
|
if cfg.evals_per_epoch and cfg.eval_steps:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||||
@@ -468,6 +464,11 @@ def legacy_validate_config(cfg):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||||
)
|
)
|
||||||
|
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||||
|
raise ValueError(
|
||||||
|
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
cfg.evaluation_strategy
|
cfg.evaluation_strategy
|
||||||
and cfg.eval_steps
|
and cfg.eval_steps
|
||||||
|
|||||||
@@ -133,7 +133,6 @@ class RLType(str, Enum):
|
|||||||
ipo = "ipo" # pylint: disable=invalid-name
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
orpo = "orpo" # pylint: disable=invalid-name
|
||||||
sppo_hard = "sppo_hard" # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplate(str, Enum):
|
class ChatTemplate(str, Enum):
|
||||||
@@ -410,17 +409,6 @@ class WandbConfig(BaseModel):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class GradioConfig(BaseModel):
|
|
||||||
"""Gradio configuration subset"""
|
|
||||||
|
|
||||||
gradio_title: Optional[str] = None
|
|
||||||
gradio_share: Optional[bool] = None
|
|
||||||
gradio_server_name: Optional[str] = None
|
|
||||||
gradio_server_port: Optional[int] = None
|
|
||||||
gradio_max_new_tokens: Optional[int] = None
|
|
||||||
gradio_temperature: Optional[float] = None
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-public-methods,too-many-ancestors
|
# pylint: disable=too-many-public-methods,too-many-ancestors
|
||||||
class AxolotlInputConfig(
|
class AxolotlInputConfig(
|
||||||
ModelInputConfig,
|
ModelInputConfig,
|
||||||
@@ -431,7 +419,6 @@ class AxolotlInputConfig(
|
|||||||
WandbConfig,
|
WandbConfig,
|
||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
LISAConfig,
|
LISAConfig,
|
||||||
GradioConfig,
|
|
||||||
RemappedParameters,
|
RemappedParameters,
|
||||||
DeprecatedParameters,
|
DeprecatedParameters,
|
||||||
BaseModel,
|
BaseModel,
|
||||||
@@ -516,17 +503,9 @@ class AxolotlInputConfig(
|
|||||||
unfrozen_parameters: Optional[List[str]] = None
|
unfrozen_parameters: Optional[List[str]] = None
|
||||||
|
|
||||||
sequence_len: int = Field(default=512)
|
sequence_len: int = Field(default=512)
|
||||||
min_sample_len: Optional[int] = None
|
|
||||||
sample_packing: Optional[bool] = None
|
sample_packing: Optional[bool] = None
|
||||||
eval_sample_packing: Optional[bool] = None
|
eval_sample_packing: Optional[bool] = None
|
||||||
pad_to_sequence_len: Optional[bool] = None
|
pad_to_sequence_len: Optional[bool] = None
|
||||||
curriculum_sampling: Optional[bool] = None
|
|
||||||
|
|
||||||
# for PoSE context length extension
|
|
||||||
use_pose: Optional[bool] = None
|
|
||||||
pose_split_on_token_ids: Optional[List[int]] = None
|
|
||||||
pose_max_context_len: Optional[int] = None
|
|
||||||
pose_num_chunks: Optional[int] = None
|
|
||||||
|
|
||||||
pretrain_multipack_buffer_size: Optional[int] = 10_000
|
pretrain_multipack_buffer_size: Optional[int] = 10_000
|
||||||
pretrain_multipack_attn: Optional[bool] = Field(
|
pretrain_multipack_attn: Optional[bool] = Field(
|
||||||
@@ -575,7 +554,6 @@ class AxolotlInputConfig(
|
|||||||
neftune_noise_alpha: Optional[float] = None
|
neftune_noise_alpha: Optional[float] = None
|
||||||
|
|
||||||
orpo_alpha: Optional[float] = None
|
orpo_alpha: Optional[float] = None
|
||||||
dpo_beta: Optional[float] = None
|
|
||||||
|
|
||||||
max_memory: Optional[
|
max_memory: Optional[
|
||||||
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
||||||
@@ -794,11 +772,11 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_push_save(cls, data):
|
def check_push_save(cls, data):
|
||||||
if data.get("hub_model_id") and (
|
if data.get("hub_model_id") and not (
|
||||||
data.get("save_strategy") not in ["steps", "epoch", None]
|
data.get("save_steps") or data.get("saves_per_epoch")
|
||||||
):
|
):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"hub_model_id is set without any models being saved. To save a model, set save_strategy."
|
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""
|
"""
|
||||||
Data processing modules
|
Data processing modules
|
||||||
"""
|
"""
|
||||||
|
from axolotl.utils.data.dpo import load_prepare_dpo_datasets # noqa: F401
|
||||||
from axolotl.utils.data.pretraining import ( # noqa: F401
|
from axolotl.utils.data.pretraining import ( # noqa: F401
|
||||||
encode_pretraining,
|
encode_pretraining,
|
||||||
wrap_pretraining_dataset,
|
wrap_pretraining_dataset,
|
||||||
)
|
)
|
||||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
|
|
||||||
from axolotl.utils.data.sft import ( # noqa: F401
|
from axolotl.utils.data.sft import ( # noqa: F401
|
||||||
get_dataset_wrapper,
|
get_dataset_wrapper,
|
||||||
load_prepare_datasets,
|
load_prepare_datasets,
|
||||||
|
|||||||
@@ -1,20 +1,17 @@
|
|||||||
"""data handling specific to DPO"""
|
"""data handling specific to DPO"""
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
|
from datasets import concatenate_datasets, load_dataset, load_from_disk
|
||||||
|
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||||
from axolotl.prompt_strategies.orpo import load as load_orpo
|
|
||||||
from axolotl.utils.data.utils import md5
|
from axolotl.utils.data.utils import md5
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
from axolotl.utils.models import load_tokenizer
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -75,29 +72,16 @@ def load_prepare_dpo_datasets(cfg):
|
|||||||
)
|
)
|
||||||
split_datasets.insert(i, ds)
|
split_datasets.insert(i, ds)
|
||||||
|
|
||||||
tokenizer = None
|
|
||||||
for i, data_set in enumerate(split_datasets):
|
for i, data_set in enumerate(split_datasets):
|
||||||
_type = dataset_cfgs[i]["type"]
|
_type = dataset_cfgs[i]["type"]
|
||||||
if _type:
|
if _type:
|
||||||
if isinstance(_type, DictDefault):
|
if isinstance(_type, DictDefault):
|
||||||
_type = "user_defined.default"
|
_type = "user_defined.default"
|
||||||
if _cfg.rl == "orpo":
|
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||||
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
|
split_datasets[i] = data_set.map(
|
||||||
else:
|
|
||||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
|
||||||
sig = inspect.signature(ds_transform_fn)
|
|
||||||
if "tokenizer" in sig.parameters:
|
|
||||||
if not tokenizer:
|
|
||||||
tokenizer = load_tokenizer(_cfg)
|
|
||||||
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
|
|
||||||
|
|
||||||
data_set = data_set.map(
|
|
||||||
ds_transform_fn,
|
ds_transform_fn,
|
||||||
desc="Mapping RL Dataset",
|
desc="Mapping RL Dataset",
|
||||||
)
|
)
|
||||||
if isinstance(data_set, DatasetDict):
|
|
||||||
data_set = data_set["train"]
|
|
||||||
split_datasets[i] = data_set
|
|
||||||
else:
|
else:
|
||||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||||
# "prompt", "chosen" and "rejected" already preprocessed
|
# "prompt", "chosen" and "rejected" already preprocessed
|
||||||
@@ -421,7 +421,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||||
dataset.save_to_disk(str(prepared_ds_path))
|
dataset.save_to_disk(prepared_ds_path)
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||||
|
|||||||
@@ -789,11 +789,7 @@ def load_model(
|
|||||||
if not reference_model or cfg.lora_model_dir:
|
if not reference_model or cfg.lora_model_dir:
|
||||||
# if we're not loading the reference model, then we're loading the model for training
|
# if we're not loading the reference model, then we're loading the model for training
|
||||||
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
||||||
if (
|
if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
|
||||||
cfg.adapter
|
|
||||||
and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]
|
|
||||||
and not cfg.merge_lora
|
|
||||||
):
|
|
||||||
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
||||||
else:
|
else:
|
||||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
@@ -997,13 +993,3 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|||||||
setup_quantized_peft_meta_for_training(model)
|
setup_quantized_peft_meta_for_training(model)
|
||||||
|
|
||||||
return model, lora_config
|
return model, lora_config
|
||||||
|
|
||||||
|
|
||||||
def ensure_dtype(model, dtype=torch.bfloat16):
|
|
||||||
for name, module in model.named_modules():
|
|
||||||
try:
|
|
||||||
if module.weight.dtype != dtype:
|
|
||||||
print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
|
|
||||||
module.to(dtype)
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Module for tokenization utilities"""
|
"""Module for tokenization utilities"""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
@@ -9,19 +10,10 @@ from termcolor import colored
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def check_dataset_labels(
|
def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
|
||||||
dataset,
|
|
||||||
tokenizer,
|
|
||||||
num_examples=5,
|
|
||||||
text_only=False,
|
|
||||||
rl_mode=False,
|
|
||||||
):
|
|
||||||
# the dataset is already shuffled, so let's just check the first 5 elements
|
# the dataset is already shuffled, so let's just check the first 5 elements
|
||||||
for idx in range(num_examples):
|
for idx in range(num_examples):
|
||||||
if not rl_mode:
|
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
||||||
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
|
||||||
else:
|
|
||||||
check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
|
||||||
|
|
||||||
|
|
||||||
def check_example_labels(example, tokenizer, text_only=False):
|
def check_example_labels(example, tokenizer, text_only=False):
|
||||||
@@ -48,53 +40,6 @@ def check_example_labels(example, tokenizer, text_only=False):
|
|||||||
return " ".join(colored_tokens)
|
return " ".join(colored_tokens)
|
||||||
|
|
||||||
|
|
||||||
def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
|
|
||||||
"""Helper function to color tokens based on their type."""
|
|
||||||
colored_text = colored(decoded_token, color)
|
|
||||||
return (
|
|
||||||
colored_text
|
|
||||||
if text_only
|
|
||||||
else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
|
|
||||||
"""Helper function to process and color tokens."""
|
|
||||||
colored_tokens = [
|
|
||||||
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
|
|
||||||
for token in tokenizer.encode(tokens)
|
|
||||||
]
|
|
||||||
return colored_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def check_rl_example_labels(example, tokenizer, text_only=False):
|
|
||||||
field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
|
|
||||||
|
|
||||||
input_tokens = example[field_prompt]
|
|
||||||
labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
|
|
||||||
|
|
||||||
# Process and color each type of token
|
|
||||||
colored_tokens = process_tokens_for_rl_debug(
|
|
||||||
input_tokens, "yellow", tokenizer, text_only
|
|
||||||
)
|
|
||||||
colored_chosens = process_tokens_for_rl_debug(
|
|
||||||
labels_chosen, "green", tokenizer, text_only
|
|
||||||
)
|
|
||||||
colored_rejecteds = process_tokens_for_rl_debug(
|
|
||||||
labels_rejected, "red", tokenizer, text_only
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a delimiter based on text_only flag
|
|
||||||
delimiter = "" if text_only else " "
|
|
||||||
|
|
||||||
# Logging information
|
|
||||||
LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
|
|
||||||
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
|
|
||||||
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
|
|
||||||
|
|
||||||
return delimiter.join(colored_tokens)
|
|
||||||
|
|
||||||
|
|
||||||
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
|
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
|
||||||
GLAIVE_TO_SHAREGPT_ROLE = {
|
GLAIVE_TO_SHAREGPT_ROLE = {
|
||||||
"SYSTEM": "system",
|
"SYSTEM": "system",
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
"""Module containing the Trainer class and related functions"""
|
"""Module containing the Trainer class and related functions"""
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -14,7 +13,7 @@ from datasets import set_caching_enabled
|
|||||||
from torch.utils.data import DataLoader, RandomSampler
|
from torch.utils.data import DataLoader, RandomSampler
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
|
||||||
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
|
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
@@ -99,89 +98,17 @@ def add_position_ids(sample):
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def add_pose_position_ids(
|
|
||||||
sample,
|
|
||||||
max_context_len=32768,
|
|
||||||
split_on_token_ids: Optional[List[int]] = None,
|
|
||||||
chunks: int = 2,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
use the PoSE technique to extend the context length by randomly skipping
|
|
||||||
positions in the context. We only want to skip right before tokens in
|
|
||||||
the split_on_token_ids list. We should attempt to randomly distribute
|
|
||||||
the skips, but we don't need the final position_ids to be the full
|
|
||||||
context_len. There may be multiple turns in the context, so we want to
|
|
||||||
make sure we take into account the maximum possible number of skips
|
|
||||||
remaining in each sample.
|
|
||||||
"""
|
|
||||||
|
|
||||||
input_ids = sample["input_ids"]
|
|
||||||
sample_len = len(input_ids)
|
|
||||||
max_skips = max_context_len - sample_len
|
|
||||||
|
|
||||||
if split_on_token_ids is None:
|
|
||||||
split_on_token_ids = []
|
|
||||||
|
|
||||||
if split_on_token_ids:
|
|
||||||
split_indices = [
|
|
||||||
i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
chunk_len = sample_len // chunks
|
|
||||||
split_indices = [i * chunk_len for i in range(1, chunks)]
|
|
||||||
split_indices.append(len(input_ids)) # make sure we go to the end of the sample
|
|
||||||
if split_indices[0] < 2:
|
|
||||||
# drop the first split index if it's too close to the beginning
|
|
||||||
split_indices = split_indices[1:]
|
|
||||||
|
|
||||||
position_ids = []
|
|
||||||
prev_index = 0
|
|
||||||
total_skips = 0
|
|
||||||
|
|
||||||
for split_index in split_indices:
|
|
||||||
num_skips = (
|
|
||||||
random.randint(0, max_skips) # nosec B311
|
|
||||||
if prev_index != 0 and max_skips
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
max_skips -= num_skips
|
|
||||||
total_skips += num_skips
|
|
||||||
|
|
||||||
segment_position_ids = list(
|
|
||||||
range(prev_index + total_skips, split_index + total_skips)
|
|
||||||
)
|
|
||||||
|
|
||||||
position_ids.extend(segment_position_ids)
|
|
||||||
prev_index = split_index
|
|
||||||
|
|
||||||
sample["sequence_len"] = position_ids[-1]
|
|
||||||
position_ids = torch.tensor(position_ids)
|
|
||||||
|
|
||||||
sample["position_ids"] = position_ids
|
|
||||||
sample["length"] = len(position_ids)
|
|
||||||
assert len(position_ids) == len(input_ids)
|
|
||||||
|
|
||||||
return sample
|
|
||||||
|
|
||||||
|
|
||||||
def add_length(sample):
|
def add_length(sample):
|
||||||
sample["length"] = len(sample["input_ids"])
|
sample["length"] = len(sample["input_ids"])
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
def drop_long_seq(sample, sequence_len=2048):
|
||||||
return (
|
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
|
||||||
len(sample["input_ids"]) <= sequence_len
|
|
||||||
and len(sample["input_ids"]) >= min_sequence_len
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||||
drop_long = partial(
|
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
||||||
drop_long_seq,
|
|
||||||
sequence_len=cfg.sequence_len,
|
|
||||||
min_sequence_len=cfg.min_sample_len or 2,
|
|
||||||
)
|
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
if cfg.is_preprocess:
|
if cfg.is_preprocess:
|
||||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
||||||
@@ -226,32 +153,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
desc="Group By Length",
|
desc="Group By Length",
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.use_pose:
|
if cfg.sample_packing:
|
||||||
pose_kwargs = {}
|
|
||||||
if cfg.pose_num_chunks is not None:
|
|
||||||
pose_kwargs["chunks"] = cfg.pose_num_chunks
|
|
||||||
pose_fn = partial(
|
|
||||||
add_pose_position_ids,
|
|
||||||
max_context_len=cfg.pose_max_context_len,
|
|
||||||
split_on_token_ids=cfg.pose_split_on_token_ids,
|
|
||||||
**pose_kwargs,
|
|
||||||
)
|
|
||||||
train_dataset = train_dataset.map(
|
|
||||||
pose_fn,
|
|
||||||
num_proc=cfg.dataset_processes,
|
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
|
||||||
desc="Add position_id column (PoSE)",
|
|
||||||
)
|
|
||||||
train_dataset = train_dataset.sort("sequence_len")
|
|
||||||
if cfg.eval_sample_packing is not False:
|
|
||||||
if eval_dataset:
|
|
||||||
eval_dataset = eval_dataset.map(
|
|
||||||
pose_fn,
|
|
||||||
num_proc=cfg.dataset_processes,
|
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
|
||||||
desc="Add position_id column (PoSE)",
|
|
||||||
)
|
|
||||||
elif cfg.sample_packing:
|
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
num_proc=cfg.dataset_processes,
|
num_proc=cfg.dataset_processes,
|
||||||
@@ -438,8 +340,8 @@ def prepare_optim_env(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard"]:
|
if cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
trainer_builder.peft_config = model[2]
|
trainer_builder.peft_config = model[2]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFDPOTrainerBuilder
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
|
|||||||
return load_model(cfg, tokenizer)
|
return load_model(cfg, tokenizer)
|
||||||
|
|
||||||
|
|
||||||
class TestHFRLTrainerBuilder:
|
class TestHFDPOTrainerBuilder:
|
||||||
"""
|
"""
|
||||||
TestCase class for DPO trainer builder
|
TestCase class for DPO trainer builder
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_build_training_arguments(self, cfg, model, tokenizer):
|
def test_build_training_arguments(self, cfg, model, tokenizer):
|
||||||
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
|
builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
|
||||||
training_arguments = builder.build_training_arguments(100)
|
training_arguments = builder.build_training_arguments(100)
|
||||||
assert training_arguments.adam_beta1 == 0.998
|
assert training_arguments.adam_beta1 == 0.998
|
||||||
assert training_arguments.adam_beta2 == 0.9
|
assert training_arguments.adam_beta2 == 0.9
|
||||||
|
|||||||
@@ -158,50 +158,3 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_orpo_lora(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "JackFram/llama-68m",
|
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"load_in_8bit": True,
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 64,
|
|
||||||
"lora_alpha": 32,
|
|
||||||
"lora_dropout": 0.1,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"special_tokens": {},
|
|
||||||
"rl": "orpo",
|
|
||||||
"orpo_alpha": 0.1,
|
|
||||||
"remove_unused_columns": False,
|
|
||||||
"chat_template": "chatml",
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "argilla/ultrafeedback-binarized-preferences-cleaned",
|
|
||||||
"type": "chat_template.argilla",
|
|
||||||
"split": "train",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 4,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "paged_adamw_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 20,
|
|
||||||
"save_steps": 10,
|
|
||||||
"warmup_steps": 5,
|
|
||||||
"gradient_checkpointing": True,
|
|
||||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
|
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
|
tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
|
||||||
self.dataset.save_to_disk(str(tmp_ds_name))
|
self.dataset.save_to_disk(tmp_ds_name)
|
||||||
|
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
|
|||||||
@@ -1067,52 +1067,18 @@ class TestValidation(BaseValidation):
|
|||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
|
def test_hub_model_id_save_value_warns(self, minimal_cfg):
|
||||||
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg
|
|
||||||
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
validate_config(cfg)
|
|
||||||
assert len(self._caplog.records) == 1
|
|
||||||
|
|
||||||
def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg):
|
|
||||||
cfg = (
|
|
||||||
DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
validate_config(cfg)
|
|
||||||
assert len(self._caplog.records) == 1
|
|
||||||
|
|
||||||
def test_hub_model_id_save_value_steps(self, minimal_cfg):
|
|
||||||
cfg = (
|
|
||||||
DictDefault({"hub_model_id": "test", "save_strategy": "steps"})
|
|
||||||
| minimal_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
validate_config(cfg)
|
|
||||||
assert len(self._caplog.records) == 0
|
|
||||||
|
|
||||||
def test_hub_model_id_save_value_epochs(self, minimal_cfg):
|
|
||||||
cfg = (
|
|
||||||
DictDefault({"hub_model_id": "test", "save_strategy": "epoch"})
|
|
||||||
| minimal_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
validate_config(cfg)
|
|
||||||
assert len(self._caplog.records) == 0
|
|
||||||
|
|
||||||
def test_hub_model_id_save_value_none(self, minimal_cfg):
|
|
||||||
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
|
|
||||||
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
validate_config(cfg)
|
|
||||||
assert len(self._caplog.records) == 0
|
|
||||||
|
|
||||||
def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
|
|
||||||
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
|
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
|
||||||
|
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert (
|
||||||
|
"set without any models being saved" in self._caplog.records[0].message
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_hub_model_id_save_value(self, minimal_cfg):
|
||||||
|
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg
|
||||||
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
with self._caplog.at_level(logging.WARNING):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
assert len(self._caplog.records) == 0
|
assert len(self._caplog.records) == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user