fix liger kernel setup

This commit is contained in:
Wing Lian
2026-03-09 21:22:35 -04:00
parent e380f6944d
commit 9394d17f28
2 changed files with 3 additions and 1 deletions

View File

@@ -137,7 +137,7 @@ class GRPOStrategy:
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
grpo_args_kwargs["use_liger_kernel"] = trl.use_liger_loss
if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = (

View File

@@ -30,6 +30,7 @@ from dataclasses import dataclass, field
import torch
from torch import nn
from trl import GRPOTrainer
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOConfig,
@@ -278,6 +279,7 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
shuffle_dataset=self.shuffle_dataset,
seed=args.seed,
)
data_producer.set_trainer(self)
if args.async_prefetch:
data_producer = AsyncDataProducer(
data_producer,