upgrade transformers to 5.2.0 and torchao to 0.16.0 (#3407)
* upgrade transformers to 5.1.0 and torchao to 0.16.0
* upgrade trl for parity
* handle trl api changes
* orpo doesn't have max_prompt_len to check anymore
* cpoconfig doesn't take max_prompt_length and fix cpu offload
* slow fsdp1 test
* triton min 3.4.0 and liger to 0.7.0
* use transformers main for now for zero3 fix
* handle group_by_length change
* fix changes upstream
* mark skip flaky test
* use transformers latest release 5.2.0
```diff
@@ -246,7 +246,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 ddp_find_unused_parameters
             )

-        training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
+        if self.cfg.group_by_length:
+            training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
         training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling

         training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
```
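Transformers 5.x drops the boolean `group_by_length` flag on `TrainingArguments` in favor of a named `train_sampling_strategy`, which is what this hunk tracks. A minimal sketch of the translation, assuming only the two kwargs visible in the diff (the helper function and its arguments are illustrative, not repo code):

```python
# Sketch: translate an axolotl-style config flag into TrainingArguments
# kwargs. Everything except the two kwargs shown in the hunk is a stand-in.
def build_sampling_kwargs(group_by_length: bool, transformers_major: int) -> dict:
    kwargs = {}
    if not group_by_length:
        return kwargs
    if transformers_major >= 5:
        # transformers 5.x: select the sampler by name
        kwargs["train_sampling_strategy"] = "group_by_length"
    else:
        # transformers 4.x: boolean flag
        kwargs["group_by_length"] = True
    return kwargs

print(build_sampling_kwargs(True, 5))  # {'train_sampling_strategy': 'group_by_length'}
print(build_sampling_kwargs(True, 4))  # {'group_by_length': True}
```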
```diff
@@ -11,7 +11,6 @@ from axolotl.core.trainers import (
 )
 from axolotl.core.trainers.dpo import DPOStrategy
 from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
-from axolotl.core.trainers.grpo import GRPOStrategy
 from axolotl.integrations.base import PluginManager
 from axolotl.loaders.utils import ensure_dtype
 from axolotl.utils.callbacks.qat import QATCallback
```
```diff
@@ -53,6 +52,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         trainer_cls_args = [self.model]

         if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
+            from axolotl.core.trainers.grpo import GRPOStrategy
+
             trainer_cls = GRPOStrategy.get_trainer_class(
                 sequence_parallel=self.cfg.context_parallel_size > 1
             )
```
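Paired with the module-level import removed in the previous hunk, `GRPOStrategy` is now imported lazily inside the branch that needs it, so GRPO's dependencies are only loaded (and their import errors only surfaced) for GRPO/GDPO runs. The standalone function below restates the pattern; it mirrors the diff but is an illustration, not repo code:

```python
# Deferred-import sketch: names come from the diff, the wrapper is assumed.
def resolve_grpo_trainer(context_parallel_size: int = 1):
    # Import inside the function so the heavy optional dependency is
    # pulled in only when a GRPO trainer is actually requested.
    from axolotl.core.trainers.grpo import GRPOStrategy

    return GRPOStrategy.get_trainer_class(
        sequence_parallel=context_parallel_size > 1
    )
```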
```diff
@@ -133,21 +134,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             if self.cfg.cpo_alpha is not None:
                 training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha

-            # Handle when max_prompt_length == max_length from defaults
-            # CPOTrainer requires strictly less than
-            if (
-                training_args_kwargs["max_prompt_length"]
-                == training_args_kwargs["max_length"]
-            ):
-                training_args_kwargs["max_prompt_length"] -= 1
+            blocklist_args_kwargs.append("max_prompt_length")

         elif self.cfg.rl is RLType.ORPO:
             training_args_cls = AxolotlORPOConfig

+            blocklist_args_kwargs.append("max_prompt_length")
+
         elif self.cfg.rl is RLType.KTO:
             training_args_cls = AxolotlKTOConfig
-            # KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
-            blocklist_args_kwargs = ["max_prompt_length"]
+            blocklist_args_kwargs.append("max_prompt_length")

             training_args_kwargs["desirable_weight"] = (
                 self.cfg.kto_desirable_weight or 1.0
```
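All three branches now funnel `max_prompt_length` into `blocklist_args_kwargs` instead of special-casing it: the newer TRL config classes (CPOConfig, ORPOConfig, KTOConfig) no longer accept the argument, so it must be stripped before the config object is built. A hedged sketch of how such a blocklist is typically consumed downstream (the helper is an assumption, not the repo's exact code):

```python
def build_training_args(training_args_cls, training_args_kwargs: dict,
                        blocklist_args_kwargs: list[str]):
    # Drop every blocklisted key before instantiating the TRL config
    # class, so kwargs a newer TRL release no longer accepts can't
    # crash its __init__.
    for key in blocklist_args_kwargs:
        training_args_kwargs.pop(key, None)
    return training_args_cls(**training_args_kwargs)
```

For example, `build_training_args(AxolotlKTOConfig, {"max_prompt_length": 512, ...}, ["max_prompt_length"])` would silently discard the retired kwarg.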
```diff
@@ -157,6 +154,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             )

         elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
+            from axolotl.core.trainers.grpo import GRPOStrategy
+
             training_args_cls = GRPOStrategy.get_training_args_class()
             training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
             blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
```
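The GRPO/GDPO branch drives everything through classmethod hooks on `GRPOStrategy`. The hook names come straight from this diff; the `Protocol` below is a reconstructed sketch of that contract, not the repo's actual declaration:

```python
from typing import Any, Protocol

class RLStrategy(Protocol):
    """Illustrative contract inferred from the calls in this diff."""

    @classmethod
    def get_trainer_class(cls, sequence_parallel: bool = False) -> type: ...

    @classmethod
    def get_training_args_class(cls) -> type: ...

    @classmethod
    def set_training_args_kwargs(cls, cfg: Any) -> dict: ...

    @classmethod
    def get_blocklist_args_kwargs(cls) -> list[str]: ...
```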
```diff
@@ -57,16 +57,18 @@ class AxolotlDPOTrainer(
     def tokenize_row(
         features,
         processing_class,
-        max_prompt_length,
-        max_completion_length,
-        add_special_tokens,
+        max_prompt_length: int | None = None,
+        max_completion_length: int | None = None,
+        add_special_tokens: bool = True,
+        is_chat: bool = False,
     ) -> Dict:
         res = DPOTrainer.tokenize_row(
             features,
             processing_class,
-            max_prompt_length,
-            max_completion_length,
-            add_special_tokens,
+            max_prompt_length=max_prompt_length,
+            max_completion_length=max_completion_length,
+            add_special_tokens=add_special_tokens,
+            is_chat=is_chat,
         )
         # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
         if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
```
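The override now forwards everything to `DPOTrainer.tokenize_row` by keyword and gives the parameters defaults, which keeps the call site stable as TRL adds parameters such as `is_chat`. If you need to straddle several TRL versions at once, a feature probe along these lines is one option (this wrapper is an illustration, not part of the commit):

```python
import inspect

from trl import DPOTrainer

def tokenize_row_compat(features, processing_class, **kwargs):
    # Forward only the keyword arguments the installed TRL actually
    # accepts, e.g. drop `is_chat` on releases that predate it.
    accepted = inspect.signature(DPOTrainer.tokenize_row).parameters
    supported = {k: v for k, v in kwargs.items() if k in accepted}
    return DPOTrainer.tokenize_row(features, processing_class, **supported)
```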
```diff
@@ -10,6 +10,7 @@ from functools import cached_property
 import addict
 import transformers
 from transformers import PretrainedConfig, PreTrainedModel
+from transformers.modeling_flash_attention_utils import is_flash_attn_available

 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.multipack import (
```
```diff
@@ -500,6 +501,7 @@ class PatchManager:
             and not self.cfg.trust_remote_code
             and not self.cfg.gptq
             and self.cfg.flash_attention
+            and is_flash_attn_available()
             and not self.inference
         ):
             # TODO(MengqingCao): split these patches separately
```
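The new `is_flash_attn_available()` clause makes the patch path depend on the flash-attn kernels actually being importable rather than on config alone, so CPU-only environments skip the monkeypatch instead of failing later. A sketch of the guard in isolation (the import path is the one this diff adds; the helper name is assumed):

```python
from transformers.modeling_flash_attention_utils import is_flash_attn_available

def should_apply_flash_patches(cfg) -> bool:
    # Mirror of the condition in this hunk: the user asked for flash
    # attention AND the flash-attn kernels are importable here.
    return bool(cfg.flash_attention) and is_flash_attn_available()
```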
```diff
@@ -59,7 +59,12 @@ class CPU_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
         hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
         hidden_states.requires_grad = True
         with torch.enable_grad():
-            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+            output = ctx.forward_function(hidden_states, *ctx.args)
+            # Newer HF models (e.g. Qwen3MoE) using GradientCheckpointingLayer
+            # return a plain tensor, not a tuple. Older models return tuples
+            # like (hidden_states, present_kv, ...). Unwrap if needed.
+            if isinstance(output, (tuple, list)):
+                (output,) = output
         torch.autograd.backward(output, dY)
         return (
             None,
```
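This fixes the CPU-offload checkpointer for transformers' newer `GradientCheckpointingLayer` models, whose layers return a bare tensor where older decoder layers returned a tuple; the old strict `(output,) = ...` unpack cannot handle both shapes. The helper below restates the normalization outside the autograd function (illustration only, not repo code):

```python
import torch

def unwrap_layer_output(output):
    """Normalize a decoder layer's return value to its hidden states.

    Newer HF layers return a bare tensor; older ones return a 1-tuple
    like (hidden_states,). Same strict unpack as the patched code."""
    if isinstance(output, (tuple, list)):
        (output,) = output
    return output

# Both call styles normalize to the same tensor:
t = torch.zeros(2, 3)
assert unwrap_layer_output(t) is t
assert unwrap_layer_output((t,)) is t
```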
```diff
@@ -28,8 +28,12 @@ PATCHED_EVAL_CODE = {
     "array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
 }

-ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
-PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
+ORIGINAL_MAYBE_CODE = (
+    "tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()"
+)
+PATCHED_MAYBE_CODE = (
+    "tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).nanmean().item()"
+)


 def check_evaluation_loop_is_patchable() -> bool:
```
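The patched string tracks upstream's move from a `self._nested_gather(...)` method call to a module-level `nested_gather(...)` helper that takes the parallel mode. Patchability checks of this kind verify that the exact original source line is still present before attempting a textual replacement; a minimal sketch, assuming the target is `Trainer._maybe_log_save_evaluate` (the transformers method that computes `tr_loss_scalar`) and using a hypothetical check name:

```python
import inspect

from transformers import Trainer

ORIGINAL_MAYBE_CODE = (
    "tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()"
)

def check_maybe_log_is_patchable() -> bool:
    # String-level guard: refuse to patch if upstream source has
    # drifted away from the exact line we expect to replace.
    source = inspect.getsource(Trainer._maybe_log_save_evaluate)
    return ORIGINAL_MAYBE_CODE in source
```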