From 9394d17f28e32e3430624d714fb3d1bcaaa51eeb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 9 Mar 2026 21:22:35 -0400 Subject: [PATCH] fix liger kernel setup --- src/axolotl/core/trainers/grpo/__init__.py | 2 +- src/axolotl/core/trainers/grpo/fast_async_trainer.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 1fa0899b7..c4da5ee9f 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -137,7 +137,7 @@ class GRPOStrategy: grpo_args_kwargs["epsilon_high"] = trl.epsilon_high if trl.use_liger_loss is not None: - grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss + grpo_args_kwargs["use_liger_kernel"] = trl.use_liger_loss if trl.multi_objective_aggregation is not None: grpo_args_kwargs["multi_objective_aggregation"] = ( diff --git a/src/axolotl/core/trainers/grpo/fast_async_trainer.py b/src/axolotl/core/trainers/grpo/fast_async_trainer.py index 76820d5f1..3367d05e8 100644 --- a/src/axolotl/core/trainers/grpo/fast_async_trainer.py +++ b/src/axolotl/core/trainers/grpo/fast_async_trainer.py @@ -30,6 +30,7 @@ from dataclasses import dataclass, field import torch from torch import nn +from trl import GRPOTrainer from axolotl.core.trainers.grpo.async_trainer import ( AsyncGRPOConfig, @@ -278,6 +279,7 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer): shuffle_dataset=self.shuffle_dataset, seed=args.seed, ) + data_producer.set_trainer(self) if args.async_prefetch: data_producer = AsyncDataProducer( data_producer,