Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
d260eeb57d match protected method 2026-02-15 07:55:55 -05:00
Wing Lian
5a7f007d20 cleanup ao fp8 patching 2026-02-13 17:02:23 -05:00
23 changed files with 70 additions and 280 deletions

View File

@@ -210,8 +210,6 @@ axolotl lm-eval config.yml
Configuration options: Configuration options:
```yaml ```yaml
lm_eval_model: # model to evaluate (local or hf path)
# List of tasks to evaluate # List of tasks to evaluate
lm_eval_tasks: lm_eval_tasks:
- arc_challenge - arc_challenge
@@ -220,7 +218,7 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results output_dir: # Directory to save evaluation results
``` ```
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details. See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
### delinearize-llama4 ### delinearize-llama4

View File

@@ -2,21 +2,21 @@
# START section of dependencies that don't install on Darwin/MacOS # START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.49.1 bitsandbytes==0.49.1
triton>=3.4.0 triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
liger-kernel==0.7.0 liger-kernel==0.6.4
# END section # END section
packaging==26.0 packaging==26.0
huggingface_hub>=1.1.7 huggingface_hub>=1.1.7
peft>=0.18.1 peft>=0.18.1
tokenizers>=0.22.1 tokenizers>=0.22.1
transformers @ git+https://github.com/winglian/transformers.git@refactor-inner-training-loop-reorder-only transformers==5.0.0
accelerate==1.12.0 accelerate==1.12.0
datasets==4.5.0 datasets==4.5.0
deepspeed>=0.18.3 deepspeed>=0.18.3
trl==0.28.0 trl==0.27.1
hf_xet==1.2.0 hf_xet==1.2.0
kernels==0.11.5 kernels==0.11.5
@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0 immutabledict==4.2.0
antlr4-python3-runtime==4.13.2 antlr4-python3-runtime==4.13.2
torchao==0.16.0 torchao==0.13.0
openenv-core==0.1.0 openenv-core==0.1.0
schedulefree==1.4.1 schedulefree==1.4.1

View File

@@ -258,6 +258,11 @@ class TrainerBuilderBase(abc.ABC):
bf16 = bf16 if bf16 is not None else False bf16 = bf16 if bf16 is not None else False
training_args_kwargs["bf16"] = bf16 training_args_kwargs["bf16"] = bf16
if self.cfg.fp8:
training_args_kwargs["fp8"] = True
if self.cfg.fp8_enable_fsdp_float8_all_gather:
training_args_kwargs["enable_fsdp_float8_all_gather"] = True
def _configure_scheduler(self, training_args_kwargs: dict): def _configure_scheduler(self, training_args_kwargs: dict):
if self.cfg.lr_scheduler in ["one_cycle", "rex"]: if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
training_args_kwargs["lr_scheduler_type"] = "cosine" training_args_kwargs["lr_scheduler_type"] = "cosine"

View File

@@ -246,8 +246,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ddp_find_unused_parameters ddp_find_unused_parameters
) )
if self.cfg.group_by_length: training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)

View File

@@ -11,6 +11,7 @@ from axolotl.core.trainers import (
) )
from axolotl.core.trainers.dpo import DPOStrategy from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback from axolotl.utils.callbacks.qat import QATCallback
@@ -52,8 +53,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}: if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
trainer_cls = GRPOStrategy.get_trainer_class( trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1 sequence_parallel=self.cfg.context_parallel_size > 1
) )
@@ -134,17 +133,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None: if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
blocklist_args_kwargs.append("max_prompt_length") # Handle when max_prompt_length == max_length from defaults
# CPOTrainer requires strictly less than
if (
training_args_kwargs["max_prompt_length"]
== training_args_kwargs["max_length"]
):
training_args_kwargs["max_prompt_length"] -= 1
elif self.cfg.rl is RLType.ORPO: elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig training_args_cls = AxolotlORPOConfig
blocklist_args_kwargs.append("max_prompt_length")
elif self.cfg.rl is RLType.KTO: elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig training_args_cls = AxolotlKTOConfig
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length # KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
blocklist_args_kwargs.append("max_prompt_length") blocklist_args_kwargs = ["max_prompt_length"]
training_args_kwargs["desirable_weight"] = ( training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0 self.cfg.kto_desirable_weight or 1.0
@@ -154,8 +157,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
) )
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}: elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
training_args_cls = GRPOStrategy.get_training_args_class() training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()

View File

@@ -584,11 +584,9 @@ class AxolotlTrainer(
super().create_accelerator_and_postprocess() super().create_accelerator_and_postprocess()
def additional_accelerator_args( def build_fp8_accelerator_args(self) -> dict[str, Any]:
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs args = {}
) -> dict[str, Any]: if self.args.fp8:
ret_kwargs = {}
if fp8:
from accelerate.utils import AORecipeKwargs from accelerate.utils import AORecipeKwargs
from torchao.float8 import Float8LinearConfig from torchao.float8 import Float8LinearConfig
@@ -596,15 +594,22 @@ class AxolotlTrainer(
# scaling strategy. See more details here: # scaling strategy. See more details here:
# https://github.com/pytorch/ao/tree/main/torchao/float8. # https://github.com/pytorch/ao/tree/main/torchao/float8.
config = Float8LinearConfig( config = Float8LinearConfig(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, enable_fsdp_float8_all_gather=self.args.enable_fsdp_float8_all_gather,
force_recompute_fp8_weight_in_bwd=enable_fsdp_float8_all_gather is True, force_recompute_fp8_weight_in_bwd=self.args.enable_fsdp_float8_all_gather
is True,
) )
ret_kwargs["mixed_precision"] = "fp8" args["mixed_precision"] = "fp8"
ret_kwargs["kwargs_handlers"] = [AORecipeKwargs(config=config)] # type: ignore args["kwargs_handlers"] = [AORecipeKwargs(config=config)] # type: ignore
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8" os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
return ret_kwargs return args
def _build_accelerator_args(self, **kwargs) -> dict[str, Any]:
args = super()._build_accelerator_args(**kwargs)
fp8_args = self.build_fp8_accelerator_args()
args.update(fp8_args)
return args
def log(self, logs: dict[str, float], start_time: float | None = None) -> None: def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
""" """

View File

@@ -57,18 +57,16 @@ class AxolotlDPOTrainer(
def tokenize_row( def tokenize_row(
features, features,
processing_class, processing_class,
max_prompt_length: int | None = None, max_prompt_length,
max_completion_length: int | None = None, max_completion_length,
add_special_tokens: bool = True, add_special_tokens,
is_chat: bool = False,
) -> Dict: ) -> Dict:
res = DPOTrainer.tokenize_row( res = DPOTrainer.tokenize_row(
features, features,
processing_class, processing_class,
max_prompt_length=max_prompt_length, max_prompt_length,
max_completion_length=max_completion_length, max_completion_length,
add_special_tokens=add_special_tokens, add_special_tokens,
is_chat=is_chat,
) )
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None: if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:

View File

@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
return optimizer_grouped_parameters return optimizer_grouped_parameters
def create_optimizer(self, model=None): def create_optimizer(self):
if ( if (
self.args.loraplus_lr_ratio is None self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None and self.args.embedding_lr_scale is None
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
and self.args.lr_groups is None and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None and self.optimizer_cls_and_kwargs is None
): ):
return super().create_optimizer(model=model) return super().create_optimizer()
opt_model = self.model if model is None else model opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if ( if (
not self.optimizer not self.optimizer

View File

@@ -263,3 +263,13 @@ class AxolotlTrainingMixins:
dion_rank_multiple_of: int | None = field( dion_rank_multiple_of: int | None = field(
default=None, default=None,
) )
fp8: bool | None = field(
default=None,
metadata={"help": "Whether to use FP8 precision for training"},
)
enable_fsdp_float8_all_gather: bool | None = field(
default=None,
metadata={"help": "Whether to use FSDP with FP8 precision for all_gather"},
)

View File

@@ -1,44 +0,0 @@
# Kernels Integration
MoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. In transformers v5, `batched_mm` and `grouped_mm` were integrated as built-in options via the `experts_implementation` config kwarg:
```python
class ExpertsInterface(GeneralInterface):
_global_mapping = {
"batched_mm": batched_mm_experts_forward,
"grouped_mm": grouped_mm_experts_forward,
}
```
In our custom integration, we add support for **ScatterMoE**, which is even more efficient and faster than `grouped_mm`.
## Usage
Add the following to your axolotl YAML config:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
```
**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
## How It Works
The `KernelsPlugin` runs before model loading and:
1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
## Limitations
ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
## Note on MegaBlocks
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.

View File

@@ -6,12 +6,6 @@ See https://github.com/EleutherAI/lm-evaluation-harness
## Usage ## Usage
There are two ways to use the LM Eval integration:
### 1. Post-Training Evaluation
When training with the plugin enabled, evaluation runs automatically after training completes:
```yaml ```yaml
plugins: plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin - axolotl.integrations.lm_eval.LMEvalPlugin
@@ -22,50 +16,9 @@ lm_eval_tasks:
- arc_easy - arc_easy
lm_eval_batch_size: # Batch size for evaluation lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
# Directory to save evaluation results.
# The final model is loaded from this directory
# unless specified otherwise (see below)
output_dir:
``` ```
Run training as usual:
```bash
axolotl train config.yml
```
### 2. Standalone CLI Evaluation
Evaluate any model directly without training:
```yaml
lm_eval_model: meta-llama/Llama-2-7b-hf
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
lm_eval_batch_size: 8
output_dir: ./outputs
```
Run evaluation:
```bash
axolotl lm-eval config.yml
```
## Model Selection Priority
The model to evaluate is selected in the following priority order:
1. **`lm_eval_model`** - Explicit model path or HuggingFace repo (highest priority)
2. **`hub_model_id`** - Trained model pushed to HuggingFace Hub
3. **`output_dir`** - Local checkpoint directory containing trained model weights
## Citation ## Citation
```bib ```bib

View File

@@ -5,7 +5,7 @@ Module for the Plugin for LM Eval Harness
import subprocess # nosec import subprocess # nosec
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command, get_model_path from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from .args import LMEvalArgs as LMEvalArgs from .args import LMEvalArgs as LMEvalArgs
@@ -29,7 +29,7 @@ class LMEvalPlugin(BasePlugin):
wandb_project=cfg.wandb_project, wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity, wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name, wandb_name=cfg.wandb_name,
model=get_model_path(cfg), model=cfg.lm_eval_model or cfg.hub_model_id,
): ):
subprocess.run( # nosec subprocess.run( # nosec
lm_eval_args, lm_eval_args,

View File

@@ -13,21 +13,6 @@ import yaml
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
def get_model_path(cfg: DictDefault) -> str | None:
"""
Determine which model path to use for evaluation.
Priority order (highest to lowest):
1. lm_eval_model - Explicit model path override
2. hub_model_id - Model pushed to HuggingFace Hub
3. None - Falls back to output_dir in build_lm_eval_command
Returns:
Model path string or None to use output_dir fallback
"""
return cfg.lm_eval_model or cfg.hub_model_id or None
def build_lm_eval_command( def build_lm_eval_command(
tasks: list[str], tasks: list[str],
bfloat16=True, bfloat16=True,
@@ -123,7 +108,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
wandb_project=cfg.wandb_project, wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity, wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name, wandb_name=cfg.wandb_name,
model=get_model_path(cfg), model=cfg.lm_eval_model or cfg.hub_model_id,
revision=cfg.revision, revision=cfg.revision,
apply_chat_template=cfg.apply_chat_template, apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn, fewshot_as_multiturn=cfg.fewshot_as_multiturn,

View File

@@ -10,7 +10,6 @@ from functools import cached_property
import addict import addict
import transformers import transformers
from transformers import PretrainedConfig, PreTrainedModel from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_flash_attention_utils import is_flash_attn_available
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import ( from axolotl.monkeypatch.multipack import (
@@ -101,7 +100,6 @@ class PatchManager:
self._apply_fsdp_patches() self._apply_fsdp_patches()
self._apply_adapter_patches() self._apply_adapter_patches()
self._apply_model_specific_patches() self._apply_model_specific_patches()
self._apply_fp8_patches()
self._apply_flash_attention_peft_patches() self._apply_flash_attention_peft_patches()
self._apply_gradient_checkpointing_patches() self._apply_gradient_checkpointing_patches()
self._patch_attention() self._patch_attention()
@@ -236,17 +234,6 @@ class PatchManager:
patch_kimi_model() patch_kimi_model()
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:
from axolotl.monkeypatch.trainer_accelerator_args import (
patch_create_accelerate_code_for_fp8,
)
patch_create_accelerate_code_for_fp8(
self.cfg.fp8_enable_fsdp_float8_all_gather
)
def _apply_flash_attention_peft_patches(self): def _apply_flash_attention_peft_patches(self):
"""Apply patches for Flash Attention with PEFT.""" """Apply patches for Flash Attention with PEFT."""
if self.cfg.adapter: if self.cfg.adapter:
@@ -501,7 +488,6 @@ class PatchManager:
and not self.cfg.trust_remote_code and not self.cfg.trust_remote_code
and not self.cfg.gptq and not self.cfg.gptq
and self.cfg.flash_attention and self.cfg.flash_attention
and is_flash_attn_available()
and not self.inference and not self.inference
): ):
# TODO(MengqingCao): split these patches separately # TODO(MengqingCao): split these patches separately

View File

@@ -59,12 +59,7 @@ class CPU_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
hidden_states = hidden_states.to("cuda", non_blocking=True).detach() hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True hidden_states.requires_grad = True
with torch.enable_grad(): with torch.enable_grad():
output = ctx.forward_function(hidden_states, *ctx.args) (output,) = ctx.forward_function(hidden_states, *ctx.args)
# Newer HF models (e.g. Qwen3MoE) using GradientCheckpointingLayer
# return a plain tensor, not a tuple. Older models return tuples
# like (hidden_states, present_kv, ...). Unwrap if needed.
if isinstance(output, (tuple, list)):
(output,) = output
torch.autograd.backward(output, dY) torch.autograd.backward(output, dY)
return ( return (
None, None,

View File

@@ -1,83 +0,0 @@
"""
allow adding additional kwargs to Accelerator init
"""
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
ORIGINAL_TRAINER_CODE = """
# create accelerator object
self.accelerator = Accelerator(**args)
"""
PATCHED_TRAINER_CODE = """
if hasattr(self, "additional_accelerator_args"):
additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather={enable_fsdp_float8_all_gather}, **args)
if additional_args:
args.update(additional_args)
# create accelerator object
self.accelerator = Accelerator(**args)
"""
def get_create_accelerate_code() -> str:
training_loop = inspect.getsource(Trainer.create_accelerator_and_postprocess)
return training_loop
def check_create_accelerate_code_is_patchable() -> bool:
create_code = get_create_accelerate_code()
create_code, _ = detab_code(create_code)
return ORIGINAL_TRAINER_CODE in create_code
def patch_create_accelerate_code_for_fp8(enable_fsdp_float8_all_gather: bool):
"""
Monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs.
"""
try:
create_code = get_create_accelerate_code()
except OSError:
return
Trainer._original_create_accelerator_and_postprocess = create_code
create_code, _ = detab_code(create_code)
if ORIGINAL_TRAINER_CODE not in create_code:
return
patched_trainer_code = PATCHED_TRAINER_CODE.format(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather
)
create_code = create_code.replace(ORIGINAL_TRAINER_CODE, patched_trainer_code)
create_code = create_code.replace(
"def create_accelerator_and_postprocess(",
"def fixed_create_accelerator_and_postprocess(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in create_code:
items_to_import.append(item)
exec(
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(create_code, globals())
LOG.info("patching create_accelerator_and_postprocess to allow for overrides")
Trainer.create_accelerator_and_postprocess = (
fixed_create_accelerator_and_postprocess
)

View File

@@ -28,12 +28,8 @@ PATCHED_EVAL_CODE = {
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()', "array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
} }
ORIGINAL_MAYBE_CODE = ( ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()" PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
)
PATCHED_MAYBE_CODE = (
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).nanmean().item()"
)
def check_evaluation_loop_is_patchable() -> bool: def check_evaluation_loop_is_patchable() -> bool:

View File

@@ -446,16 +446,7 @@ class AxolotlInputConfig(
}, },
) )
unfrozen_parameters: list[str] | None = Field( unfrozen_parameters: list[str] | None = None
default=None,
json_schema_extra={
"description": "List of regex patterns for parameter names to keep unfrozen. "
"All other parameters will be frozen via requires_grad=False. "
"Note: range-based patterns (e.g. embed_tokens.weight$[:32000]) use gradient "
"zeroing rather than a true freeze, so weight decay will still apply to the "
"frozen portion and optimizer states are allocated for the full parameter."
},
)
sequence_len: int = Field( sequence_len: int = Field(
default=512, default=512,

View File

@@ -300,6 +300,7 @@ class TestHFRLTrainerBuilder:
self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl) self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
# ORPO specific # ORPO specific
assert training_arguments.beta == 0.1 # maps from orpo_alpha assert training_arguments.beta == 0.1 # maps from orpo_alpha
assert training_arguments.max_prompt_length == 512
def test_kto_training_arguments(self, kto_cfg, model, tokenizer): def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer) builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)

View File

@@ -186,7 +186,6 @@ class TestFSDP1:
verify_training_success(temp_dir) verify_training_success(temp_dir)
@pytest.mark.skip(reason="slow test, deprecate fsdp1 asap")
def test_dpo_fft(self, temp_dir): def test_dpo_fft(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(
{ {

View File

@@ -365,7 +365,6 @@ class TestFSDP2:
verify_training_success(temp_dir) verify_training_success(temp_dir)
@pytest.mark.skip(reason="slow test w cu129 + torch 2.9.1 + py3.12")
@require_torch_2_7_0 @require_torch_2_7_0
def test_dpo_fft(self, temp_dir): def test_dpo_fft(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(

View File

@@ -115,9 +115,6 @@ class TestAssistantChatTemplateLlama3:
def test_phi35(self, phi35_tokenizer, assistant_dataset): def test_phi35(self, phi35_tokenizer, assistant_dataset):
LOG.info("Testing phi-3.5 with assistant dataset") LOG.info("Testing phi-3.5 with assistant dataset")
assert "LlamaTokenizer" in phi35_tokenizer.__class__.__name__, (
"phi35 tokenizer should be a LlamaTokenizer"
)
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
phi35_tokenizer, phi35_tokenizer,
@@ -143,13 +140,13 @@ class TestAssistantChatTemplateLlama3:
# fmt: off # fmt: off
expected_input_ids = [ expected_input_ids = [
32010, # user 32010, # user
12199, 32007, # user eot 22172, 32007, # user eot
32001, # assistant 32001, # assistant
12199, 32007, # assistant eot 22172, 32007, # assistant eot
32010, # user 32010, # user
16773, 26966, 32007, # user eot 1781, 26966, 32007, # user eot
32001, # assistant 32001, # assistant
16773, 26966, 32007, # assistant eot 1781, 26966, 32007, # assistant eot
] ]
expected_labels = [ expected_labels = [
-100, # user -100, # user
@@ -159,7 +156,7 @@ class TestAssistantChatTemplateLlama3:
-100, # user -100, # user
-100, -100, -100, # user eot -100, -100, -100, # user eot
-100, # assistant -100, # assistant
16773, 26966, 32007, # assistant eot 1781, 26966, 32007, # assistant eot
] ]
# fmt: on # fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}") LOG.debug(f"Expected input_ids: {expected_input_ids}")

View File

@@ -84,8 +84,7 @@ class TestTokenizers:
} }
) )
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
assert "LlamaTokenizer" in tokenizer.__class__.__name__ assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1792]
assert len(tokenizer) == 32001 assert len(tokenizer) == 32001
# ensure reloading the tokenizer again from cfg results in same vocab length # ensure reloading the tokenizer again from cfg results in same vocab length