fix liger kernel setup
This commit is contained in:
@@ -137,7 +137,7 @@ class GRPOStrategy:
|
|||||||
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
|
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
|
||||||
|
|
||||||
if trl.use_liger_loss is not None:
|
if trl.use_liger_loss is not None:
|
||||||
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
|
grpo_args_kwargs["use_liger_kernel"] = trl.use_liger_loss
|
||||||
|
|
||||||
if trl.multi_objective_aggregation is not None:
|
if trl.multi_objective_aggregation is not None:
|
||||||
grpo_args_kwargs["multi_objective_aggregation"] = (
|
grpo_args_kwargs["multi_objective_aggregation"] = (
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from dataclasses import dataclass, field
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from trl import GRPOTrainer
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.async_trainer import (
|
from axolotl.core.trainers.grpo.async_trainer import (
|
||||||
AsyncGRPOConfig,
|
AsyncGRPOConfig,
|
||||||
@@ -278,6 +279,7 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
|||||||
shuffle_dataset=self.shuffle_dataset,
|
shuffle_dataset=self.shuffle_dataset,
|
||||||
seed=args.seed,
|
seed=args.seed,
|
||||||
)
|
)
|
||||||
|
data_producer.set_trainer(self)
|
||||||
if args.async_prefetch:
|
if args.async_prefetch:
|
||||||
data_producer = AsyncDataProducer(
|
data_producer = AsyncDataProducer(
|
||||||
data_producer,
|
data_producer,
|
||||||
|
|||||||
Reference in New Issue
Block a user