Compare commits
6 Commits
patch_lora
...
grpo-ref-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a9ebff087c | ||
|
|
b53a41372f | ||
|
|
02f45e94be | ||
|
|
954e192f38 | ||
|
|
8dfadc2b3c | ||
|
|
23a9fcb0a7 |
@@ -12,6 +12,7 @@ to leverage operator fusion and tensor re-use in order to improve speed and redu
|
||||
memory usage during the forward and backward passes of these calculations.
|
||||
|
||||
We currently support several common model architectures, including (but not limited to):
|
||||
|
||||
- `llama`
|
||||
- `mistral`
|
||||
- `qwen2`
|
||||
|
||||
@@ -13,7 +13,7 @@ liger-kernel==0.5.2
|
||||
packaging==23.2
|
||||
|
||||
peft==0.14.0
|
||||
transformers==4.48.3
|
||||
transformers==4.49.0
|
||||
tokenizers>=0.21.0
|
||||
accelerate==1.3.0
|
||||
datasets==3.2.0
|
||||
|
||||
@@ -39,6 +39,15 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
||||
# pylint: enable=access-member-before-definition
|
||||
|
||||
# cleanup the ref_model if we have a peft model passed in
|
||||
# TODO remove this after next major trl release
|
||||
if (
|
||||
self.ref_model # pylint: disable=access-member-before-definition
|
||||
and is_peft_model(self.model)
|
||||
):
|
||||
del self.ref_model
|
||||
self.ref_model = None
|
||||
|
||||
def _enable_gradient_checkpointing(
|
||||
self, model: PreTrainedModel, args: GRPOConfig
|
||||
) -> PreTrainedModel:
|
||||
|
||||
@@ -127,6 +127,8 @@ class ReLoRACallback(TrainerCallback):
|
||||
optimizer: torch.optim.Optimizer,
|
||||
**_kwargs,
|
||||
):
|
||||
if not optimizer:
|
||||
optimizer = state.optimizer
|
||||
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
|
||||
@@ -272,8 +272,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
dict(zip(feature_names, row))
|
||||
)
|
||||
for key, val in tokenized_prompt.items():
|
||||
for i in range(0, len(val), self.sequence_len):
|
||||
res[key].append(val[i : i + self.sequence_len])
|
||||
res[key].append(val)
|
||||
|
||||
# If there are no examples left, return an empty dictionary
|
||||
if not res:
|
||||
|
||||
@@ -172,10 +172,11 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
||||
)
|
||||
|
||||
try:
|
||||
min_input_len = np.min(get_dataset_lengths(dataset))
|
||||
LOG.debug(f"min_input_len: {min_input_len}")
|
||||
max_input_len = np.max(get_dataset_lengths(dataset))
|
||||
LOG.debug(f"max_input_len: {max_input_len}")
|
||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||
min_input_len = np.min(ds_lengths)
|
||||
LOG.info(f"min_input_len: {min_input_len}")
|
||||
max_input_len = np.max(ds_lengths)
|
||||
LOG.info(f"max_input_len: {max_input_len}")
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -4,13 +4,17 @@ helper util to calculate dataset lengths
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_dataset_lengths(dataset):
|
||||
if "length" in dataset.data.column_names:
|
||||
lengths = np.array(dataset.data.column("length"))
|
||||
elif "position_ids" in dataset.data.column_names:
|
||||
position_ids = dataset.data.column("position_ids")
|
||||
def get_dataset_lengths(dataset, from_arrow=False):
|
||||
if "length" in dataset.column_names:
|
||||
lengths = np.array(dataset["length"])
|
||||
elif "position_ids" in dataset.column_names:
|
||||
position_ids = dataset["position_ids"]
|
||||
lengths = np.array([x[-1] + 1 for x in position_ids])
|
||||
else:
|
||||
input_ids = dataset.data.column("input_ids")
|
||||
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
||||
if from_arrow:
|
||||
input_ids = dataset.data.column("input_ids")
|
||||
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
||||
else:
|
||||
input_ids = dataset["input_ids"]
|
||||
lengths = np.array([len(seq) for seq in input_ids])
|
||||
return lengths
|
||||
|
||||
@@ -125,6 +125,12 @@ def fixture_llama3_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
|
||||
def fixture_smollm2_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
|
||||
def fixture_mistralv03_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
|
||||
61
tests/prompt_strategies/test_dpo_chatml.py
Normal file
61
tests/prompt_strategies/test_dpo_chatml.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Tests for loading DPO preference datasets with chatml formatting
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="minimal_dpo_cfg")
|
||||
def fixture_cfg():
|
||||
return DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||
"rl": "dpo",
|
||||
"learning_rate": 0.000001,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"sequence_len": 2048,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestDPOChatml:
|
||||
"""
|
||||
Test loading DPO preference datasets with chatml formatting
|
||||
"""
|
||||
|
||||
def test_default(self, minimal_dpo_cfg):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"datasets": [
|
||||
{
|
||||
"path": "argilla/distilabel-intel-orca-dpo-pairs",
|
||||
"type": "chatml",
|
||||
"split": "train[:1%]",
|
||||
}
|
||||
]
|
||||
}
|
||||
| minimal_dpo_cfg
|
||||
)
|
||||
|
||||
# test that dpo.load works
|
||||
load_dpo("chatml", cfg)
|
||||
# now actually load the datasets with the strategy
|
||||
train_ds, _ = load_prepare_preference_datasets(cfg)
|
||||
assert train_ds[0]["prompt"].startswith("<|im_start|>")
|
||||
assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
|
||||
assert "chosen" in train_ds[0]
|
||||
assert "rejected" in train_ds[0]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.completion import load
|
||||
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.data.utils import drop_long_seq_in_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
@@ -18,11 +19,6 @@ def fixture_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="max_seq_length")
|
||||
def fixture_max_seq_length():
|
||||
return 4096
|
||||
|
||||
|
||||
class TestBatchedSamplerPacking:
|
||||
"""
|
||||
Test class for packing streaming dataset sequences
|
||||
@@ -37,6 +33,7 @@ class TestBatchedSamplerPacking:
|
||||
(2, 2),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("max_seq_length", [4096, 512])
|
||||
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
@@ -62,6 +59,9 @@ class TestBatchedSamplerPacking:
|
||||
dataset,
|
||||
)
|
||||
train_dataset = concatenate_datasets([dataset_wrapper])
|
||||
|
||||
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
|
||||
|
||||
lengths = get_dataset_lengths(train_dataset)
|
||||
batch_sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
@@ -90,7 +90,7 @@ class TestBatchedSamplerPacking:
|
||||
batch_idxs.extend(pack)
|
||||
|
||||
for batch in loader:
|
||||
assert len(batch["input_ids"]) <= batch_size * max_seq_length
|
||||
assert batch["input_ids"].numel() <= batch_size * max_seq_length
|
||||
assert batch["input_ids"].shape[1] == max_seq_length
|
||||
|
||||
original_idxs = set(range(len(train_dataset)))
|
||||
|
||||
Reference in New Issue
Block a user