Compare commits
6 Commits
970b2a6f2f
...
transforme
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b5a9d1d88 | ||
|
|
eb59070040 | ||
|
|
9722aaf7d8 | ||
|
|
c5d20bbd79 | ||
|
|
7fbedbd300 | ||
|
|
145ffc9be1 |
@@ -2,21 +2,21 @@
|
|||||||
|
|
||||||
# START section of dependencies that don't install on Darwin/MacOS
|
# START section of dependencies that don't install on Darwin/MacOS
|
||||||
bitsandbytes==0.49.1
|
bitsandbytes==0.49.1
|
||||||
triton>=3.0.0
|
triton>=3.4.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
liger-kernel==0.6.4
|
liger-kernel==0.7.0
|
||||||
# END section
|
# END section
|
||||||
|
|
||||||
packaging==26.0
|
packaging==26.0
|
||||||
huggingface_hub>=1.1.7
|
huggingface_hub>=1.1.7
|
||||||
peft>=0.18.1
|
peft>=0.18.1
|
||||||
tokenizers>=0.22.1
|
tokenizers>=0.22.1
|
||||||
transformers==5.0.0
|
transformers @ git+https://github.com/winglian/transformers.git@refactor-inner-training-loop-reorder-only
|
||||||
accelerate==1.12.0
|
accelerate==1.12.0
|
||||||
datasets==4.5.0
|
datasets==4.5.0
|
||||||
deepspeed>=0.18.3
|
deepspeed>=0.18.3
|
||||||
trl==0.27.1
|
trl==0.28.0
|
||||||
hf_xet==1.2.0
|
hf_xet==1.2.0
|
||||||
kernels==0.11.5
|
kernels==0.11.5
|
||||||
|
|
||||||
@@ -63,7 +63,7 @@ langdetect==1.0.9
|
|||||||
immutabledict==4.2.0
|
immutabledict==4.2.0
|
||||||
antlr4-python3-runtime==4.13.2
|
antlr4-python3-runtime==4.13.2
|
||||||
|
|
||||||
torchao==0.13.0
|
torchao==0.16.0
|
||||||
openenv-core==0.1.0
|
openenv-core==0.1.0
|
||||||
schedulefree==1.4.1
|
schedulefree==1.4.1
|
||||||
|
|
||||||
|
|||||||
@@ -246,7 +246,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
ddp_find_unused_parameters
|
ddp_find_unused_parameters
|
||||||
)
|
)
|
||||||
|
|
||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
if self.cfg.group_by_length:
|
||||||
|
training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
|
||||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||||
|
|
||||||
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from axolotl.core.trainers import (
|
|||||||
)
|
)
|
||||||
from axolotl.core.trainers.dpo import DPOStrategy
|
from axolotl.core.trainers.dpo import DPOStrategy
|
||||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.loaders.utils import ensure_dtype
|
from axolotl.loaders.utils import ensure_dtype
|
||||||
from axolotl.utils.callbacks.qat import QATCallback
|
from axolotl.utils.callbacks.qat import QATCallback
|
||||||
@@ -53,6 +52,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
|
|
||||||
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
||||||
|
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||||
|
|
||||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||||
sequence_parallel=self.cfg.context_parallel_size > 1
|
sequence_parallel=self.cfg.context_parallel_size > 1
|
||||||
)
|
)
|
||||||
@@ -133,21 +134,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.cpo_alpha is not None:
|
if self.cfg.cpo_alpha is not None:
|
||||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||||
|
|
||||||
# Handle when max_prompt_length == max_length from defaults
|
blocklist_args_kwargs.append("max_prompt_length")
|
||||||
# CPOTrainer requires strictly less than
|
|
||||||
if (
|
|
||||||
training_args_kwargs["max_prompt_length"]
|
|
||||||
== training_args_kwargs["max_length"]
|
|
||||||
):
|
|
||||||
training_args_kwargs["max_prompt_length"] -= 1
|
|
||||||
|
|
||||||
elif self.cfg.rl is RLType.ORPO:
|
elif self.cfg.rl is RLType.ORPO:
|
||||||
training_args_cls = AxolotlORPOConfig
|
training_args_cls = AxolotlORPOConfig
|
||||||
|
|
||||||
|
blocklist_args_kwargs.append("max_prompt_length")
|
||||||
|
|
||||||
elif self.cfg.rl is RLType.KTO:
|
elif self.cfg.rl is RLType.KTO:
|
||||||
training_args_cls = AxolotlKTOConfig
|
training_args_cls = AxolotlKTOConfig
|
||||||
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
|
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
|
||||||
blocklist_args_kwargs = ["max_prompt_length"]
|
blocklist_args_kwargs.append("max_prompt_length")
|
||||||
|
|
||||||
training_args_kwargs["desirable_weight"] = (
|
training_args_kwargs["desirable_weight"] = (
|
||||||
self.cfg.kto_desirable_weight or 1.0
|
self.cfg.kto_desirable_weight or 1.0
|
||||||
@@ -157,6 +154,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
||||||
|
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||||
|
|
||||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||||
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||||
|
|||||||
@@ -57,16 +57,18 @@ class AxolotlDPOTrainer(
|
|||||||
def tokenize_row(
|
def tokenize_row(
|
||||||
features,
|
features,
|
||||||
processing_class,
|
processing_class,
|
||||||
max_prompt_length,
|
max_prompt_length: int | None = None,
|
||||||
max_completion_length,
|
max_completion_length: int | None = None,
|
||||||
add_special_tokens,
|
add_special_tokens: bool = True,
|
||||||
|
is_chat: bool = False,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
res = DPOTrainer.tokenize_row(
|
res = DPOTrainer.tokenize_row(
|
||||||
features,
|
features,
|
||||||
processing_class,
|
processing_class,
|
||||||
max_prompt_length,
|
max_prompt_length=max_prompt_length,
|
||||||
max_completion_length,
|
max_completion_length=max_completion_length,
|
||||||
add_special_tokens,
|
add_special_tokens=add_special_tokens,
|
||||||
|
is_chat=is_chat,
|
||||||
)
|
)
|
||||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
|
|||||||
|
|
||||||
return optimizer_grouped_parameters
|
return optimizer_grouped_parameters
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self, model=None):
|
||||||
if (
|
if (
|
||||||
self.args.loraplus_lr_ratio is None
|
self.args.loraplus_lr_ratio is None
|
||||||
and self.args.embedding_lr_scale is None
|
and self.args.embedding_lr_scale is None
|
||||||
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
|
|||||||
and self.args.lr_groups is None
|
and self.args.lr_groups is None
|
||||||
and self.optimizer_cls_and_kwargs is None
|
and self.optimizer_cls_and_kwargs is None
|
||||||
):
|
):
|
||||||
return super().create_optimizer()
|
return super().create_optimizer(model=model)
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
opt_model = self.model if model is None else model
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not self.optimizer
|
not self.optimizer
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from functools import cached_property
|
|||||||
import addict
|
import addict
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import PretrainedConfig, PreTrainedModel
|
from transformers import PretrainedConfig, PreTrainedModel
|
||||||
|
from transformers.modeling_flash_attention_utils import is_flash_attn_available
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.multipack import (
|
from axolotl.monkeypatch.multipack import (
|
||||||
@@ -500,6 +501,7 @@ class PatchManager:
|
|||||||
and not self.cfg.trust_remote_code
|
and not self.cfg.trust_remote_code
|
||||||
and not self.cfg.gptq
|
and not self.cfg.gptq
|
||||||
and self.cfg.flash_attention
|
and self.cfg.flash_attention
|
||||||
|
and is_flash_attn_available()
|
||||||
and not self.inference
|
and not self.inference
|
||||||
):
|
):
|
||||||
# TODO(MengqingCao): split these patches separately
|
# TODO(MengqingCao): split these patches separately
|
||||||
|
|||||||
@@ -59,7 +59,12 @@ class CPU_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
|
|||||||
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
||||||
hidden_states.requires_grad = True
|
hidden_states.requires_grad = True
|
||||||
with torch.enable_grad():
|
with torch.enable_grad():
|
||||||
(output,) = ctx.forward_function(hidden_states, *ctx.args)
|
output = ctx.forward_function(hidden_states, *ctx.args)
|
||||||
|
# Newer HF models (e.g. Qwen3MoE) using GradientCheckpointingLayer
|
||||||
|
# return a plain tensor, not a tuple. Older models return tuples
|
||||||
|
# like (hidden_states, present_kv, ...). Unwrap if needed.
|
||||||
|
if isinstance(output, (tuple, list)):
|
||||||
|
(output,) = output
|
||||||
torch.autograd.backward(output, dY)
|
torch.autograd.backward(output, dY)
|
||||||
return (
|
return (
|
||||||
None,
|
None,
|
||||||
|
|||||||
@@ -28,8 +28,12 @@ PATCHED_EVAL_CODE = {
|
|||||||
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
|
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
|
||||||
}
|
}
|
||||||
|
|
||||||
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
|
ORIGINAL_MAYBE_CODE = (
|
||||||
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
|
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()"
|
||||||
|
)
|
||||||
|
PATCHED_MAYBE_CODE = (
|
||||||
|
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).nanmean().item()"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_evaluation_loop_is_patchable() -> bool:
|
def check_evaluation_loop_is_patchable() -> bool:
|
||||||
|
|||||||
@@ -446,7 +446,16 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
unfrozen_parameters: list[str] | None = None
|
unfrozen_parameters: list[str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "List of regex patterns for parameter names to keep unfrozen. "
|
||||||
|
"All other parameters will be frozen via requires_grad=False. "
|
||||||
|
"Note: range-based patterns (e.g. embed_tokens.weight$[:32000]) use gradient "
|
||||||
|
"zeroing rather than a true freeze, so weight decay will still apply to the "
|
||||||
|
"frozen portion and optimizer states are allocated for the full parameter."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
sequence_len: int = Field(
|
sequence_len: int = Field(
|
||||||
default=512,
|
default=512,
|
||||||
|
|||||||
@@ -300,7 +300,6 @@ class TestHFRLTrainerBuilder:
|
|||||||
self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
|
self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
|
||||||
# ORPO specific
|
# ORPO specific
|
||||||
assert training_arguments.beta == 0.1 # maps from orpo_alpha
|
assert training_arguments.beta == 0.1 # maps from orpo_alpha
|
||||||
assert training_arguments.max_prompt_length == 512
|
|
||||||
|
|
||||||
def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
|
def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
|
||||||
builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
|
builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
|
||||||
|
|||||||
@@ -186,6 +186,7 @@ class TestFSDP1:
|
|||||||
|
|
||||||
verify_training_success(temp_dir)
|
verify_training_success(temp_dir)
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="slow test, deprecate fsdp1 asap")
|
||||||
def test_dpo_fft(self, temp_dir):
|
def test_dpo_fft(self, temp_dir):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -365,6 +365,7 @@ class TestFSDP2:
|
|||||||
|
|
||||||
verify_training_success(temp_dir)
|
verify_training_success(temp_dir)
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="slow test w cu129 + torch 2.9.1 + py3.12")
|
||||||
@require_torch_2_7_0
|
@require_torch_2_7_0
|
||||||
def test_dpo_fft(self, temp_dir):
|
def test_dpo_fft(self, temp_dir):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
|
|||||||
@@ -115,6 +115,9 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
|
|
||||||
def test_phi35(self, phi35_tokenizer, assistant_dataset):
|
def test_phi35(self, phi35_tokenizer, assistant_dataset):
|
||||||
LOG.info("Testing phi-3.5 with assistant dataset")
|
LOG.info("Testing phi-3.5 with assistant dataset")
|
||||||
|
assert "LlamaTokenizer" in phi35_tokenizer.__class__.__name__, (
|
||||||
|
"phi35 tokenizer should be a LlamaTokenizer"
|
||||||
|
)
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
phi35_tokenizer,
|
phi35_tokenizer,
|
||||||
@@ -140,13 +143,13 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
expected_input_ids = [
|
expected_input_ids = [
|
||||||
32010, # user
|
32010, # user
|
||||||
22172, 32007, # user eot
|
12199, 32007, # user eot
|
||||||
32001, # assistant
|
32001, # assistant
|
||||||
22172, 32007, # assistant eot
|
12199, 32007, # assistant eot
|
||||||
32010, # user
|
32010, # user
|
||||||
1781, 26966, 32007, # user eot
|
16773, 26966, 32007, # user eot
|
||||||
32001, # assistant
|
32001, # assistant
|
||||||
1781, 26966, 32007, # assistant eot
|
16773, 26966, 32007, # assistant eot
|
||||||
]
|
]
|
||||||
expected_labels = [
|
expected_labels = [
|
||||||
-100, # user
|
-100, # user
|
||||||
@@ -156,7 +159,7 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
-100, # user
|
-100, # user
|
||||||
-100, -100, -100, # user eot
|
-100, -100, -100, # user eot
|
||||||
-100, # assistant
|
-100, # assistant
|
||||||
1781, 26966, 32007, # assistant eot
|
16773, 26966, 32007, # assistant eot
|
||||||
]
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||||
|
|||||||
@@ -84,7 +84,8 @@ class TestTokenizers:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
|
assert "LlamaTokenizer" in tokenizer.__class__.__name__
|
||||||
|
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1792]
|
||||||
assert len(tokenizer) == 32001
|
assert len(tokenizer) == 32001
|
||||||
|
|
||||||
# ensure reloading the tokenizer again from cfg results in same vocab length
|
# ensure reloading the tokenizer again from cfg results in same vocab length
|
||||||
|
|||||||
Reference in New Issue
Block a user