Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
3ce9b0760b fix the lora yaml for l3 2024-04-19 07:28:07 -04:00
39 changed files with 163 additions and 1018 deletions

View File

@@ -32,11 +32,6 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.2.1 pytorch: 2.2.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v3

View File

@@ -30,11 +30,6 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.2.1 pytorch: 2.2.1
axolotl_extras: axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -91,11 +86,6 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.2.1 pytorch: 2.2.1
axolotl_extras: axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -29,11 +29,6 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.2.1 pytorch: 2.2.1
axolotl_extras: axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -91,11 +86,6 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.2.1 pytorch: 2.2.1
axolotl_extras: axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

1
.gitignore vendored
View File

@@ -133,7 +133,6 @@ venv/
ENV/ ENV/
env.bak/ env.bak/
venv.bak/ venv.bak/
venv3.10/
# Spyder project settings # Spyder project settings
.spyderproject .spyderproject

View File

@@ -34,7 +34,6 @@ Features:
- [Mac](#mac) - [Mac](#mac)
- [Google Colab](#google-colab) - [Google Colab](#google-colab)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot) - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
- [Dataset](#dataset) - [Dataset](#dataset)
- [Config](#config) - [Config](#config)
- [Train](#train) - [Train](#train)
@@ -293,42 +292,6 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
``` ```
#### Launching on public clouds via dstack
To launch on GPU instance (both on-demand and spot instances) on public clouds (GCP, AWS, Azure, Lambda Labs, TensorDock, Vast.ai, and CUDO), you can use [dstack](https://dstack.ai/).
Write a job description in YAML as below:
```yaml
# dstack.yaml
type: task
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.1
env:
- HUGGING_FACE_HUB_TOKEN
- WANDB_API_KEY
commands:
- accelerate launch -m axolotl.cli.train config.yaml
ports:
- 6006
resources:
gpu:
memory: 24GB..
count: 2
```
then, simply run the job with `dstack run` command. Append `--spot` option if you want spot instance. `dstack run` command will show you the instance with cheapest price across multi cloud services:
```bash
pip install dstack
HUGGING_FACE_HUB_TOKEN=xxx WANDB_API_KEY=xxx dstack run . -f dstack.yaml # --spot
```
For further and fine-grained use cases, please refer to the official [dstack documents](https://dstack.ai/docs/) and the detailed description of [axolotl example](https://github.com/dstackai/dstack/tree/master/examples/fine-tuning/axolotl) on the official repository.
### Dataset ### Dataset
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field. Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.

View File

@@ -227,12 +227,6 @@ lora_modules_to_save:
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
# LoRA+ hyperparameters
# For more details about the following options, see:
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.
loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6.
peft: peft:
# Configuration options for loftq initialization for LoRA # Configuration options for loftq initialization for LoRA
# https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization # https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
@@ -274,7 +268,6 @@ torch_compile_backend: # Optional[str]
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps. # If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
# The number of samples to include in each batch. This is the number of samples sent to each GPU. # The number of samples to include in each batch. This is the number of samples sent to each GPU.
# Batch size per gpu = micro_batch_size * gradient_accumulation_steps
micro_batch_size: 2 micro_batch_size: 2
eval_batch_size: eval_batch_size:
num_epochs: 4 num_epochs: 4

View File

@@ -49,7 +49,7 @@ remove_unused_columns: false
chat_template: chatml chat_template: chatml
datasets: datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned - path: argilla/ultrafeedback-binarized-preferences-cleaned
type: chat_template.argilla type: orpo.chat_template
``` ```
#### Using local dataset files #### Using local dataset files

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3-8B base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
@@ -64,4 +64,4 @@ weight_decay: 0.0
fsdp: fsdp:
fsdp_config: fsdp_config:
special_tokens: special_tokens:
pad_token: <|end_of_text|> pad_token: <|end_of_text|>

View File

@@ -1,82 +0,0 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false
chat_template: chatml
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned
type: chat_template.argilla
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./mistral-qlora-orpo-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -11,7 +11,7 @@ addict
fire fire
PyYAML>=6.0 PyYAML>=6.0
requests requests
datasets==2.15.0 datasets>=2.15.0
flash-attn==2.5.5 flash-attn==2.5.5
sentencepiece sentencepiece
wandb wandb
@@ -28,7 +28,7 @@ scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml pynvml
art art
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe fschat==0.2.36
gradio==3.50.2 gradio==3.50.2
tensorboard tensorboard
@@ -39,6 +39,6 @@ s3fs
gcsfs gcsfs
# adlfs # adlfs
trl==0.8.5 trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
zstandard==0.22.0 zstandard==0.22.0
fastcore fastcore

View File

@@ -33,7 +33,7 @@ fi
if [ "$JUPYTER_DISABLE" != "1" ]; then if [ "$JUPYTER_DISABLE" != "1" ]; then
# Run Jupyter Lab in the background # Run Jupyter Lab in the background
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* & jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
fi fi
# Execute the passed arguments (CMD) # Execute the passed arguments (CMD)

View File

@@ -264,8 +264,8 @@ def do_inference_gradio(
with torch.no_grad(): with torch.no_grad():
generation_config = GenerationConfig( generation_config = GenerationConfig(
repetition_penalty=1.1, repetition_penalty=1.1,
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024), max_new_tokens=1024,
temperature=cfg.get("gradio_temperature", 0.9), temperature=0.9,
top_p=0.95, top_p=0.95,
top_k=40, top_k=40,
bos_token_id=tokenizer.bos_token_id, bos_token_id=tokenizer.bos_token_id,
@@ -300,13 +300,7 @@ def do_inference_gradio(
outputs="text", outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"), title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
) )
demo.queue().launch(show_api=False, share=True)
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),
)
def choose_config(path: Path): def choose_config(path: Path):
@@ -439,23 +433,6 @@ def load_rl_datasets(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
) )
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
tokenizer = load_tokenizer(cfg)
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
return TrainDatasetMeta( return TrainDatasetMeta(
train_dataset=train_dataset, train_dataset=train_dataset,
eval_dataset=eval_dataset, eval_dataset=eval_dataset,

View File

@@ -25,8 +25,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
load_in_8bit=False, load_in_8bit=False,
load_in_4bit=False, load_in_4bit=False,
flash_attention=False, flash_attention=False,
deepspeed=None,
fsdp=None,
**kwargs, **kwargs,
) )
@@ -42,7 +40,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
parsed_cfg.flash_attention = False parsed_cfg.flash_attention = False
parsed_cfg.deepspeed = None parsed_cfg.deepspeed = None
parsed_cfg.fsdp = None parsed_cfg.fsdp = None
parsed_cfg.fsdp_config = None
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)

View File

@@ -19,10 +19,7 @@ from axolotl.cli import (
) )
from axolotl.common.cli import PreprocessCliArgs from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.sharegpt import ( from axolotl.prompt_strategies.sharegpt import register_chatml_template
register_chatml_template,
register_llama3_template,
)
LOG = logging.getLogger("axolotl.cli.preprocess") LOG = logging.getLogger("axolotl.cli.preprocess")
@@ -39,22 +36,13 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
return_remaining_strings=True return_remaining_strings=True
) )
if parsed_cfg.chat_template == "chatml": if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
if parsed_cfg.default_system_message: LOG.info(
LOG.info( f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}" )
) register_chatml_template(parsed_cfg.default_system_message)
register_chatml_template(parsed_cfg.default_system_message) else:
else: register_chatml_template()
register_chatml_template()
elif parsed_cfg.chat_template == "llama3":
if parsed_cfg.default_system_message:
LOG.info(
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
)
register_llama3_template(parsed_cfg.default_system_message)
else:
register_llama3_template()
if not parsed_cfg.dataset_prepared_path: if not parsed_cfg.dataset_prepared_path:
msg = ( msg = (
@@ -66,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.warning(msg) LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo": if parsed_cfg.rl and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else: else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)

View File

@@ -19,10 +19,7 @@ from axolotl.cli import (
print_axolotl_text_art, print_axolotl_text_art,
) )
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.prompt_strategies.sharegpt import ( from axolotl.prompt_strategies.sharegpt import register_chatml_template
register_chatml_template,
register_llama3_template,
)
from axolotl.train import train from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train") LOG = logging.getLogger("axolotl.cli.train")
@@ -50,15 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
else: else:
register_chatml_template() register_chatml_template()
if cfg.chat_template == "llama3" and cfg.default_system_message: if cfg.rl and cfg.rl != "orpo":
LOG.info(
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
)
register_llama3_template(cfg.default_system_message)
else:
register_llama3_template()
if cfg.rl: # and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -30,7 +30,7 @@ from transformers import (
) )
from transformers.trainer_utils import seed_worker from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer, ORPOConfig, ORPOTrainer from trl import DPOTrainer
from trl.trainer.utils import pad_to_length from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer from axolotl.loraplus import create_loraplus_optimizer
@@ -43,7 +43,6 @@ from axolotl.utils.callbacks import (
LossWatchDogCallback, LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback, SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
SaveModelOnTrainEndCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory, causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory, log_prediction_callback_factory,
@@ -55,7 +54,6 @@ from axolotl.utils.collators import (
MambaDataCollator, MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq,
) )
from axolotl.utils.models import ensure_dtype
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import ( from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr, get_cosine_schedule_with_min_lr,
@@ -213,10 +211,6 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None, default=None,
metadata={"help": "path under the model to access the layers"}, metadata={"help": "path under the model to access the layers"},
) )
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
class AxolotlTrainer(Trainer): class AxolotlTrainer(Trainer):
@@ -352,8 +346,6 @@ class AxolotlTrainer(Trainer):
lengths=get_dataset_lengths(self.train_dataset), lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
) )
if self.args.curriculum_sampling:
return SequentialSampler(self.train_dataset)
return super()._get_train_sampler() return super()._get_train_sampler()
def _get_eval_sampler( def _get_eval_sampler(
@@ -818,14 +810,6 @@ class AxolotlDPOTrainer(DPOTrainer):
return res return res
class AxolotlORPOTrainer(ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
""" """
Base class for trainer builder Base class for trainer builder
@@ -889,14 +873,6 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append( callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
) )
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
return callbacks return callbacks
@@ -942,11 +918,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
): ):
callbacks.append(SaveBetterTransformerModelCallback()) callbacks.append(SaveBetterTransformerModelCallback())
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
if self.cfg.loss_watchdog_threshold is not None: if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg)) callbacks.append(LossWatchDogCallback(self.cfg))
callbacks.append(SaveModelOnTrainEndCallback())
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
@@ -993,9 +976,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return ReLoRATrainer return ReLoRATrainer
if self.cfg.model_config_type == "mamba": if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer return AxolotlMambaTrainer
if self.cfg.custom_trainer_cls:
_module, _cls = self.cfg.custom_trainer_cls.rsplit(".", 1)
return importlib.import_module(_module, _cls)
return AxolotlTrainer return AxolotlTrainer
def build(self, total_num_steps): def build(self, total_num_steps):
@@ -1204,7 +1184,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
False if self.cfg.ddp else None False if self.cfg.ddp else None
) )
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
report_to = None report_to = None
if self.cfg.use_wandb: if self.cfg.use_wandb:
report_to = "wandb" report_to = "wandb"
@@ -1425,15 +1404,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
) )
class HFRLTrainerBuilder(TrainerBuilderBase): class HFDPOTrainerBuilder(TrainerBuilderBase):
""" """
Trainer factory class for DPO Trainer Trainer factory class for DPO Trainer
""" """
def get_callbacks(self): def get_callbacks(self):
callbacks = super().get_callbacks() callbacks = super().get_callbacks()
callbacks.append(SaveModelOnTrainEndCallback())
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
@@ -1469,7 +1446,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["eval_steps"] = self.cfg.eval_steps training_args_kwargs["eval_steps"] = self.cfg.eval_steps
else: else:
training_args_kwargs["evaluation_strategy"] = "no" training_args_kwargs["evaluation_strategy"] = "no"
if self.cfg.bf16 or self.cfg.bfloat16: if self.cfg.bf16 or self.cfg.bfloat16:
training_args_kwargs["bf16"] = True training_args_kwargs["bf16"] = True
@@ -1521,19 +1497,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined # default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch" training_args_kwargs["save_strategy"] = "epoch"
if self.cfg.orpo_alpha: training_args = TrainingArguments(
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
training_args_cls = TrainingArguments
if self.cfg.rl == "orpo":
training_args_cls = ORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
training_args = training_args_cls(
per_device_train_batch_size=self.cfg.micro_batch_size, per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=self.cfg.max_steps or total_num_steps, max_steps=self.cfg.max_steps or total_num_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
@@ -1566,34 +1530,20 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[ dpo_trainer_kwargs[
"precompute_ref_log_probs" "precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs ] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]: dpo_trainer = AxolotlDPOTrainer(
trainer_cls = AxolotlDPOTrainer self.model,
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1 self.model_ref,
trainer_cls_args = [self.model, self.model_ref]
# these aren't used for the ORPO trainer
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = True
if self.cfg.rl == "dpo":
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args, args=training_args,
beta=self.cfg.dpo_beta or 0.1,
train_dataset=self.train_dataset, train_dataset=self.train_dataset,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
max_length=self.cfg.sequence_len,
max_target_length=None,
max_prompt_length=self.cfg.sequence_len,
generate_during_eval=True,
callbacks=self.get_callbacks(), callbacks=self.get_callbacks(),
**dpo_trainer_kwargs, **dpo_trainer_kwargs,
) )
if self.cfg.fsdp:
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer) dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
for callback in self.get_post_trainer_create_callbacks(dpo_trainer): for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
dpo_trainer.add_callback(callback) dpo_trainer.add_callback(callback)

View File

@@ -123,25 +123,6 @@ def get_turns( # pylint: disable=too-many-return-statements
else: else:
yield role, "" yield role, ""
return return
if self.sep_style == SeparatorStyle.LLAMA3:
if self.system_message:
# For llama3, the system message is NOT incorporated into the first human instruction
# All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
else:
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
return
if self.sep_style == SeparatorStyle.GEMMA:
if self.system_message:
raise ValueError("Gemma chat template does not support system messages")
for i, (role, message) in enumerate(self.messages):
prefix = "<bos>" if i == 0 else ""
message_str = message if message else ""
yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
return
if self.sep_style == SeparatorStyle.CHATGLM: if self.sep_style == SeparatorStyle.CHATGLM:
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926

View File

@@ -1,133 +0,0 @@
"""
DPO strategies for llama-3 chat template
"""
def argilla(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
return sample
return transform_fn
def argilla_chat(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
for argilla/dpo-mix-7k conversations
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample
return transform_fn
def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
chatml transforms for datasets with system, input, chosen, rejected
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample
return transform_fn
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
For Intel Orca DPO Pairs
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample
return transform_fn
def prompt_pairs(
cfg, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample
return transform_fn
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
for ultrafeedback binarized conversations
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample
return transform_fn

View File

@@ -6,4 +6,4 @@ from functools import partial
from ..base import load as load_base from ..base import load as load_base
load = partial(load_base, module_base="axolotl.prompt_strategies.orpo") load = partial(load_base, module="axolotl.prompt_strategies.orpo")

View File

@@ -78,57 +78,6 @@ class ORPODatasetParsingStrategy:
) )
return MessageList(messages=messages) return MessageList(messages=messages)
def get_prompt(self, prompt) -> MessageList:
"""Map the data to extract everything up to the last turn"""
total_msg_len = len(prompt["chosen"])
total_msg_turns, remainder = divmod(total_msg_len, 2)
assert remainder == 0, "invalid number of turns"
messages: List[Message] = []
if system := prompt.get("system", None):
messages.append(Message(role="system", content=system, label=False))
for i in range(total_msg_turns):
if "prompt" in prompt:
messages.append(
Message(role="user", content=prompt["prompt"], label=False)
)
else:
messages.append(
Message(
role="user",
content=prompt["chosen"][i * 2]["content"],
label=False,
)
)
if i < total_msg_turns - 1:
messages.append(
Message(
role="assistant",
content=prompt["chosen"][i * 2 + 1]["content"],
label=False,
)
)
return MessageList(messages=messages)
def get_chosen(self, prompt) -> MessageList:
res = self.get_prompt(prompt)
res.messages.append(
Message(
role="assistant", content=prompt["chosen"][-1]["content"], label=True
)
)
return res
def get_rejected(self, prompt) -> MessageList:
res = self.get_prompt(prompt)
res.messages.append(
Message(
role="assistant", content=prompt["rejected"][-1]["content"], label=True
)
)
return res
class ORPOTokenizingStrategy(PromptTokenizingStrategy): class ORPOTokenizingStrategy(PromptTokenizingStrategy):
""" """
@@ -237,36 +186,3 @@ class ORPOPrompter(Prompter):
chat_template=self.chat_template, chat_template=self.chat_template,
tokenize=False, tokenize=False,
), True ), True
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
dataset_parser = ORPODatasetParsingStrategy()
chat_template_str = chat_templates(cfg.chat_template)
def transform_fn(sample, tokenizer=None):
res = {}
res["prompt"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=False,
)
prompt_str_len = len(res["prompt"])
res["chosen"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)[prompt_str_len:]
res["rejected"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)[prompt_str_len:]
return res
return transform_fn

View File

@@ -1,7 +1,7 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class""" """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import logging import logging
from typing import Any, Dict, Optional, Type from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -22,7 +22,7 @@ def register_chatml_template(system_message=None):
name="chatml", name="chatml",
system_template="<|im_start|>system\n{system_message}", system_template="<|im_start|>system\n{system_message}",
system_message=system_message, system_message=system_message,
roles=("<|im_start|>user", "<|im_start|>assistant"), roles=["<|im_start|>user", "<|im_start|>assistant"],
sep_style=SeparatorStyle.CHATML, sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>", sep="<|im_end|>",
) )
@@ -32,63 +32,83 @@ def register_chatml_template(system_message=None):
name="chatml_glaive", name="chatml_glaive",
system_template="<|im_start|>system\n{system_message}", system_template="<|im_start|>system\n{system_message}",
system_message=system_message, system_message=system_message,
roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"), roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
sep_style=SeparatorStyle.CHATML, sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>", sep="<|im_end|>",
) )
) )
def register_llama3_template(system_message=None): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
system_message = system_message or "You are a helpful assistant." conversation = (
register_conv_template( ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
Conversation( )
name="llama3", field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
system_message=system_message, roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
roles=("user", "assistant"), strategy = SimpleShareGPTPromptTokenizingStrategy(
sep_style=SeparatorStyle.LLAMA3, ShareGPTPrompterV2(
sep="", conversation=conversation,
stop_str="<|eot_id|>", role_key_model=field_model,
stop_token_ids=[128001, 128009], role_key_human=field_human,
) roles=roles,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
strategy = UltrachatShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )
def build_loader( def load_guanaco(tokenizer, cfg):
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"], return GuanacoShareGPTPromptTokenizingStrategy(
prompter_cls: Type["ShareGPTPrompterV2"], ShareGPTPrompterV2(),
default_conversation: Optional[str] = None, tokenizer,
): cfg.train_on_inputs,
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): cfg.sequence_len,
conversation = ( )
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else default_conversation
)
field_human = (
ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
)
field_model = (
ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
)
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
strategy = tokenization_strategy_cls(
prompter_cls(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
strategy.strict = ds_cfg["strict"]
return strategy
return _load
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else "chatml_glaive"
)
return GlaiveShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(conversation=conversation),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -138,9 +158,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
return turns return turns
class SimpleRoleShareGPTPromptTokenizingStrategy( class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
SimpleShareGPTPromptTokenizingStrategy
):
""" """
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
""" """
@@ -191,16 +209,3 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
conversation = merge_consecutive_messages(conversation) conversation = merge_consecutive_messages(conversation)
return conversation return conversation
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_ultrachat = build_loader(
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
)
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_glaive = build_loader(
GlaiveShareGPTPromptTokenizingStrategy,
ShareGPTPrompterV2,
default_conversation="chatml_glaive",
)

View File

@@ -263,7 +263,6 @@ CONVERSATION_ROLE_FORMAT = {
"chatml": "<|im_start|>{ROLE}", "chatml": "<|im_start|>{ROLE}",
"zephyr": "<|{ROLE}|>", "zephyr": "<|{ROLE}|>",
"vicuna_v1.1": "{ROLE}", "vicuna_v1.1": "{ROLE}",
"llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
} }
@@ -349,10 +348,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
) )
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])): if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
if ( LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
role != "assistant"
): # back to back assistant calls may be okay for tool calls
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])

View File

@@ -3,7 +3,6 @@
import os import os
import signal import signal
import sys import sys
import weakref
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
@@ -128,20 +127,14 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0: if cfg.local_rank == 0:
def terminate_handler(_, __, model_weakref): def terminate_handler(_, __, model):
if model_weakref() is not None: if cfg.flash_optimum and BetterTransformer:
_model = model_weakref() model = BetterTransformer.reverse(model)
if cfg.flash_optimum and BetterTransformer: model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
_model = BetterTransformer.reverse(_model)
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
sys.exit(0) sys.exit(0)
_model_weakref = weakref.ref(model)
signal.signal( signal.signal(
signal.SIGINT, signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
) )
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)""" badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
@@ -212,10 +205,6 @@ def train(
if cfg.flash_optimum and BetterTransformer: if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id: if not cfg.hub_model_id:

View File

@@ -773,24 +773,3 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
except (FileNotFoundError, ConnectionError) as err: except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}") LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control return control
class SaveModelOnTrainEndCallback(TrainerCallback):
"""Callback to save model on train end"""
def on_step_end( # pylint: disable=unused-argument
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# Save
if state.global_step >= state.max_steps:
control.should_save = True
def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs
):
control.should_save = True
return control

View File

@@ -24,7 +24,6 @@ def chat_templates(user_choice: str):
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}", "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ eos_token }}{% endif %}",
} }
if user_choice in templates: if user_choice in templates:

View File

@@ -229,8 +229,9 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if feature == "attention_mask": if feature == "attention_mask":
if self.multipack_attn: if self.multipack_attn:
arrays = [ arrays = [
(i + 1) * np.array(item) (i + 1) * np.array(item[feature])
for i, item in enumerate(features[feature]) for i, item in enumerate(features[feature])
if feature in item
] ]
else: else:
arrays = [(1) * np.array(item) for item in features[feature]] arrays = [(1) * np.array(item) for item in features[feature]]

View File

@@ -383,9 +383,9 @@ def legacy_validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead." "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
) )
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]: if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
LOG.warning( LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty." "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
) )
if cfg.gptq and cfg.revision_of_model: if cfg.gptq and cfg.revision_of_model:
@@ -448,14 +448,10 @@ def legacy_validate_config(cfg):
raise ValueError( raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together." "save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
) )
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps": if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
raise ValueError( raise ValueError(
"save_strategy must be empty or set to `steps` when used with saves_per_epoch." "save_strategy must be empty or set to `steps` when used with saves_per_epoch."
) )
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if cfg.evals_per_epoch and cfg.eval_steps: if cfg.evals_per_epoch and cfg.eval_steps:
raise ValueError( raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together." "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
@@ -468,6 +464,11 @@ def legacy_validate_config(cfg):
raise ValueError( raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch." "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
) )
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if ( if (
cfg.evaluation_strategy cfg.evaluation_strategy
and cfg.eval_steps and cfg.eval_steps

View File

@@ -143,7 +143,6 @@ class ChatTemplate(str, Enum):
inst = "inst" # pylint: disable=invalid-name inst = "inst" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
class LoftQConfig(BaseModel): class LoftQConfig(BaseModel):
@@ -410,17 +409,6 @@ class WandbConfig(BaseModel):
return data return data
class GradioConfig(BaseModel):
"""Gradio configuration subset"""
gradio_title: Optional[str] = None
gradio_share: Optional[bool] = None
gradio_server_name: Optional[str] = None
gradio_server_port: Optional[int] = None
gradio_max_new_tokens: Optional[int] = None
gradio_temperature: Optional[float] = None
# pylint: disable=too-many-public-methods,too-many-ancestors # pylint: disable=too-many-public-methods,too-many-ancestors
class AxolotlInputConfig( class AxolotlInputConfig(
ModelInputConfig, ModelInputConfig,
@@ -431,7 +419,6 @@ class AxolotlInputConfig(
WandbConfig, WandbConfig,
MLFlowConfig, MLFlowConfig,
LISAConfig, LISAConfig,
GradioConfig,
RemappedParameters, RemappedParameters,
DeprecatedParameters, DeprecatedParameters,
BaseModel, BaseModel,
@@ -516,20 +503,9 @@ class AxolotlInputConfig(
unfrozen_parameters: Optional[List[str]] = None unfrozen_parameters: Optional[List[str]] = None
sequence_len: int = Field(default=512) sequence_len: int = Field(default=512)
min_sample_len: Optional[int] = None
max_prompt_len: int = Field(
default=512, metadata={"help": "maximum prompt length for RL training"}
)
sample_packing: Optional[bool] = None sample_packing: Optional[bool] = None
eval_sample_packing: Optional[bool] = None eval_sample_packing: Optional[bool] = None
pad_to_sequence_len: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None
# for PoSE context length extension
use_pose: Optional[bool] = None
pose_split_on_token_ids: Optional[List[int]] = None
pose_max_context_len: Optional[int] = None
pose_num_chunks: Optional[int] = None
pretrain_multipack_buffer_size: Optional[int] = 10_000 pretrain_multipack_buffer_size: Optional[int] = 10_000
pretrain_multipack_attn: Optional[bool] = Field( pretrain_multipack_attn: Optional[bool] = Field(
@@ -561,8 +537,6 @@ class AxolotlInputConfig(
torch_compile: Optional[bool] = None torch_compile: Optional[bool] = None
torch_compile_backend: Optional[str] = None torch_compile_backend: Optional[str] = None
custom_trainer_cls: Optional[str] = None
max_steps: Optional[int] = None max_steps: Optional[int] = None
warmup_steps: Optional[int] = None warmup_steps: Optional[int] = None
warmup_ratio: Optional[float] = None warmup_ratio: Optional[float] = None
@@ -798,11 +772,11 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_push_save(cls, data): def check_push_save(cls, data):
if data.get("hub_model_id") and ( if data.get("hub_model_id") and not (
data.get("save_strategy") not in ["steps", "epoch", None] data.get("save_steps") or data.get("saves_per_epoch")
): ):
LOG.warning( LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set save_strategy." "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
) )
return data return data

View File

@@ -1,11 +1,11 @@
""" """
Data processing modules Data processing modules
""" """
from axolotl.utils.data.dpo import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.pretraining import ( # noqa: F401 from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining, encode_pretraining,
wrap_pretraining_dataset, wrap_pretraining_dataset,
) )
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401 from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper, get_dataset_wrapper,
load_prepare_datasets, load_prepare_datasets,

View File

@@ -1,20 +1,17 @@
"""data handling specific to DPO""" """data handling specific to DPO"""
import inspect
import logging import logging
from functools import partial
from pathlib import Path from pathlib import Path
from typing import Any, List from typing import Any, List
import yaml import yaml
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk from datasets import concatenate_datasets, load_dataset, load_from_disk
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.dpo import load as load_dpo from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.utils import md5 from axolotl.utils.data.utils import md5
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -75,29 +72,16 @@ def load_prepare_dpo_datasets(cfg):
) )
split_datasets.insert(i, ds) split_datasets.insert(i, ds)
tokenizer = None
for i, data_set in enumerate(split_datasets): for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"] _type = dataset_cfgs[i]["type"]
if _type: if _type:
if isinstance(_type, DictDefault): if isinstance(_type, DictDefault):
_type = "user_defined.default" _type = "user_defined.default"
if _cfg.rl == "orpo": ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i) split_datasets[i] = data_set.map(
else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
sig = inspect.signature(ds_transform_fn)
if "tokenizer" in sig.parameters:
if not tokenizer:
tokenizer = load_tokenizer(_cfg)
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
data_set = data_set.map(
ds_transform_fn, ds_transform_fn,
desc="Mapping RL Dataset", desc="Mapping RL Dataset",
) )
if isinstance(data_set, DatasetDict):
data_set = data_set["train"]
split_datasets[i] = data_set
else: else:
# If no `type` is provided, assume the dataset is already in the expected format with # If no `type` is provided, assume the dataset is already in the expected format with
# "prompt", "chosen" and "rejected" already preprocessed # "prompt", "chosen" and "rejected" already preprocessed

View File

@@ -421,7 +421,7 @@ def load_tokenized_prepared_datasets(
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(str(prepared_ds_path)) dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
LOG.info( LOG.info(
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"

View File

@@ -1,5 +1,4 @@
"""Module for models and model loading""" """Module for models and model loading"""
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
import logging import logging
@@ -505,9 +504,6 @@ def load_model(
bnb_config = { bnb_config = {
"load_in_8bit": True, "load_in_8bit": True,
} }
# Exclude mamba blocks from int8 quantization for jamba
if cfg.model_config_type == "jamba":
bnb_config["llm_int8_skip_modules"] = ["mamba"]
model_kwargs["quantization_config"] = BitsAndBytesConfig( model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config, **bnb_config,
) )
@@ -997,13 +993,3 @@ def load_lora(model, cfg, inference=False, config_only=False):
setup_quantized_peft_meta_for_training(model) setup_quantized_peft_meta_for_training(model)
return model, lora_config return model, lora_config
def ensure_dtype(model, dtype=torch.bfloat16):
for name, module in model.named_modules():
try:
if module.weight.dtype != dtype:
print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
module.to(dtype)
except AttributeError:
pass

View File

@@ -1,5 +1,6 @@
"""Module for tokenization utilities""" """Module for tokenization utilities"""
import logging import logging
import re import re
from typing import Dict, List from typing import Dict, List
@@ -9,19 +10,10 @@ from termcolor import colored
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def check_dataset_labels( def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
dataset,
tokenizer,
num_examples=5,
text_only=False,
rl_mode=False,
):
# the dataset is already shuffled, so let's just check the first 5 elements # the dataset is already shuffled, so let's just check the first 5 elements
for idx in range(num_examples): for idx in range(num_examples):
if not rl_mode: check_example_labels(dataset[idx], tokenizer, text_only=text_only)
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
else:
check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)
def check_example_labels(example, tokenizer, text_only=False): def check_example_labels(example, tokenizer, text_only=False):
@@ -48,53 +40,6 @@ def check_example_labels(example, tokenizer, text_only=False):
return " ".join(colored_tokens) return " ".join(colored_tokens)
def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
"""Helper function to color tokens based on their type."""
colored_text = colored(decoded_token, color)
return (
colored_text
if text_only
else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
)
def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
"""Helper function to process and color tokens."""
colored_tokens = [
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
for token in tokenizer.encode(tokens)
]
return colored_tokens
def check_rl_example_labels(example, tokenizer, text_only=False):
field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
input_tokens = example[field_prompt]
labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
# Process and color each type of token
colored_tokens = process_tokens_for_rl_debug(
input_tokens, "yellow", tokenizer, text_only
)
colored_chosens = process_tokens_for_rl_debug(
labels_chosen, "green", tokenizer, text_only
)
colored_rejecteds = process_tokens_for_rl_debug(
labels_rejected, "red", tokenizer, text_only
)
# Create a delimiter based on text_only flag
delimiter = "" if text_only else " "
# Logging information
LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
return delimiter.join(colored_tokens)
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"] GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
GLAIVE_TO_SHAREGPT_ROLE = { GLAIVE_TO_SHAREGPT_ROLE = {
"SYSTEM": "system", "SYSTEM": "system",

View File

@@ -1,10 +1,9 @@
"""Module containing the Trainer class and related functions""" """Module containing the Trainer class and related functions"""
import math import math
import os import os
import random
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
from typing import List, Optional from typing import List
import numpy as np import numpy as np
import torch import torch
@@ -14,7 +13,7 @@ from datasets import set_caching_enabled
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -99,89 +98,17 @@ def add_position_ids(sample):
return sample return sample
def add_pose_position_ids(
sample,
max_context_len=32768,
split_on_token_ids: Optional[List[int]] = None,
chunks: int = 2,
):
"""
use the PoSE technique to extend the context length by randomly skipping
positions in the context. We only want to skip right before tokens in
the split_on_token_ids list. We should attempt to randomly distribute
the skips, but we don't need the final position_ids to be the full
context_len. There may be multiple turns in the context, so we want to
make sure we take into account the maximum possible number of skips
remaining in each sample.
"""
input_ids = sample["input_ids"]
sample_len = len(input_ids)
max_skips = max_context_len - sample_len
if split_on_token_ids is None:
split_on_token_ids = []
if split_on_token_ids:
split_indices = [
i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
]
else:
chunk_len = sample_len // chunks
split_indices = [i * chunk_len for i in range(1, chunks)]
split_indices.append(len(input_ids)) # make sure we go to the end of the sample
if split_indices[0] < 2:
# drop the first split index if it's too close to the beginning
split_indices = split_indices[1:]
position_ids = []
prev_index = 0
total_skips = 0
for split_index in split_indices:
num_skips = (
random.randint(0, max_skips) # nosec B311
if prev_index != 0 and max_skips
else 0
)
max_skips -= num_skips
total_skips += num_skips
segment_position_ids = list(
range(prev_index + total_skips, split_index + total_skips)
)
position_ids.extend(segment_position_ids)
prev_index = split_index
sample["sequence_len"] = position_ids[-1]
position_ids = torch.tensor(position_ids)
sample["position_ids"] = position_ids
sample["length"] = len(position_ids)
assert len(position_ids) == len(input_ids)
return sample
def add_length(sample): def add_length(sample):
sample["length"] = len(sample["input_ids"]) sample["length"] = len(sample["input_ids"])
return sample return sample
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2): def drop_long_seq(sample, sequence_len=2048):
return ( return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
len(sample["input_ids"]) <= sequence_len
and len(sample["input_ids"]) >= min_sequence_len
)
def process_datasets_for_packing(cfg, train_dataset, eval_dataset): def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long = partial( drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len or 2,
)
with zero_first(is_main_process()): with zero_first(is_main_process()):
if cfg.is_preprocess: if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset)) min_input_len = np.min(get_dataset_lengths(train_dataset))
@@ -226,32 +153,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
desc="Group By Length", desc="Group By Length",
) )
if cfg.use_pose: if cfg.sample_packing:
pose_kwargs = {}
if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial(
add_pose_position_ids,
max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs,
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
train_dataset = train_dataset.sort("sequence_len")
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
add_position_ids, add_position_ids,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
@@ -438,8 +340,8 @@ def prepare_optim_env(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]: if cfg.rl in ["dpo", "ipo", "kto_pair"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1] trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2] trainer_builder.peft_config = model[2]
else: else:

View File

@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder
import pytest import pytest
from axolotl.core.trainer_builder import HFRLTrainerBuilder from axolotl.core.trainer_builder import HFDPOTrainerBuilder
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
return load_model(cfg, tokenizer) return load_model(cfg, tokenizer)
class TestHFRLTrainerBuilder: class TestHFDPOTrainerBuilder:
""" """
TestCase class for DPO trainer builder TestCase class for DPO trainer builder
""" """
def test_build_training_arguments(self, cfg, model, tokenizer): def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFRLTrainerBuilder(cfg, model, tokenizer) builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
training_arguments = builder.build_training_arguments(100) training_arguments = builder.build_training_arguments(100)
assert training_arguments.adam_beta1 == 0.998 assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9 assert training_arguments.adam_beta2 == 0.9

View File

@@ -158,50 +158,3 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@with_temp_dir
def test_orpo_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {},
"rl": "orpo",
"orpo_alpha": 0.1,
"remove_unused_columns": False,
"chat_template": "chatml",
"datasets": [
{
"path": "argilla/ultrafeedback-binarized-preferences-cleaned",
"type": "chat_template.argilla",
"split": "train",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

View File

@@ -12,12 +12,10 @@ from axolotl.prompt_strategies.sharegpt import (
GlaiveShareGPTPromptTokenizingStrategy, GlaiveShareGPTPromptTokenizingStrategy,
SimpleShareGPTPromptTokenizingStrategy, SimpleShareGPTPromptTokenizingStrategy,
register_chatml_template, register_chatml_template,
register_llama3_template,
) )
from axolotl.prompters import ShareGPTPrompterV2 from axolotl.prompters import ShareGPTPrompterV2
register_chatml_template() register_chatml_template()
register_llama3_template()
@pytest.fixture(name="sharegpt_dataset") @pytest.fixture(name="sharegpt_dataset")
@@ -117,53 +115,7 @@ def fixture_tokenizer():
return tokenizer return tokenizer
@pytest.fixture(name="llama3_tokenizer") class TestSharegpt:
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
tokenizer.eos_token = "<|eot_id|>"
return tokenizer
class TestSharegptLlama3:
"""Test class for ShareGPT style datasets with llama-3 prompts"""
def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation="llama3",
role_key_model=None,
role_key_human=None,
),
llama3_tokenizer,
False, # train_on_inputs
2048, # sequence_len
)
dataset_wrapper = TokenizedPromptDataset(
strategy, sharegpt_dataset, process_count=1
)
input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off
assert input_ids == [
128000, # bos
128006, 9125, 128007, # system header
271, 31724, 128009, # sys prompt, eot
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
# fmt: on
class TestSharegptChatML:
""" """
Test class for sharegpt prompter Test class for sharegpt prompter
""" """

View File

@@ -110,7 +110,7 @@ class TestDatasetPreparation(unittest.TestCase):
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded.""" """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_name = Path(tmp_dir) / "tmp_dataset" tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
self.dataset.save_to_disk(str(tmp_ds_name)) self.dataset.save_to_disk(tmp_ds_name)
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(

View File

@@ -1067,52 +1067,18 @@ class TestValidation(BaseValidation):
): ):
validate_config(cfg) validate_config(cfg)
def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg): def test_hub_model_id_save_value_warns(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 1
def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 1
def test_hub_model_id_save_value_steps(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "steps"})
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_epochs(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "epoch"})
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_none(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert (
"set without any models being saved" in self._caplog.records[0].message
)
def test_hub_model_id_save_value(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg
with self._caplog.at_level(logging.WARNING): with self._caplog.at_level(logging.WARNING):
validate_config(cfg) validate_config(cfg)
assert len(self._caplog.records) == 0 assert len(self._caplog.records) == 0