fix liger kernel setup
This commit is contained in:
@@ -137,7 +137,7 @@ class GRPOStrategy:
|
||||
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
|
||||
|
||||
if trl.use_liger_loss is not None:
|
||||
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
|
||||
grpo_args_kwargs["use_liger_kernel"] = trl.use_liger_loss
|
||||
|
||||
if trl.multi_objective_aggregation is not None:
|
||||
grpo_args_kwargs["multi_objective_aggregation"] = (
|
||||
|
||||
@@ -30,6 +30,7 @@ from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from trl import GRPOTrainer
|
||||
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOConfig,
|
||||
@@ -278,6 +279,7 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
||||
shuffle_dataset=self.shuffle_dataset,
|
||||
seed=args.seed,
|
||||
)
|
||||
data_producer.set_trainer(self)
|
||||
if args.async_prefetch:
|
||||
data_producer = AsyncDataProducer(
|
||||
data_producer,
|
||||
|
||||
Reference in New Issue
Block a user