Compare commits: textui...rl-trainer

1 commit: c25990fd4f
```diff
@@ -156,6 +156,9 @@ class AxolotlTrainer(
         Helper method to get the sampler for evaluation. Handles sequence parallelism
         and sample packing cases.
 
+        Args:
+            eval_dataset: Evaluation dataset.
+
         Returns:
             If the dataset is non-empty, a sampler is returned, the type of which
             depends on the passed training args.
```
```diff
@@ -237,9 +240,6 @@ class AxolotlTrainer(
         self.accelerator.even_batches = False
 
         # Return unprepared dataloader if using sequence parallelism
-        # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
-        # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
-        # slice each batch along the sequence dimension).
         if self.args.sequence_parallel_degree > 1:
             return dataloader
 
```
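Both this hunk and the matching `AxolotlGRPOSequenceParallelTrainer` hunk below drop the speculative TODO but keep the early return that hands the dataloader back unprepared. A minimal sketch of that dispatch, assuming a hypothetical standalone wrapper (`maybe_prepare_dataloader` is not in the diff; `accelerator.prepare` is `accelerate`'s usual preparation entry point):

```python
from torch.utils.data import DataLoader


def maybe_prepare_dataloader(
    dataloader: DataLoader, accelerator, sequence_parallel_degree: int
) -> DataLoader:
    """Hypothetical wrapper mirroring the trainer's early-return logic."""
    if sequence_parallel_degree > 1:
        # Under sequence parallelism each rank needs the full batch so it can
        # be sliced along the sequence dimension later; skip accelerate's
        # sharding/dispatch and return the dataloader as-is.
        return dataloader
    return accelerator.prepare(dataloader)
```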
```diff
@@ -1,33 +1,25 @@
-"""
-DPO trainer for axolotl
-"""
+"""DPO trainer for Axolotl"""
 
 import gc
-import random
 from functools import wraps
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Union
 
-import pandas as pd
 import torch
-import wandb
-from accelerate import PartialState
-from datasets import Dataset, IterableDataset
+from datasets import Dataset
 from peft.optimizers import create_loraplus_optimizer
 from torch import nn
-from torch.utils.data import DataLoader
+from torch.utils.data import Sampler
 from transformers import (
-    BaseImageProcessor,
-    FeatureExtractionMixin,
-    PreTrainedTokenizerBase,
-    ProcessorMixin,
     Trainer,
 )
-from transformers.trainer_utils import EvalLoopOutput
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt
-from trl.trainer.utils import log_table_to_comet_experiment
+from trl import DPOTrainer
 
-from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
+from axolotl.core.trainers.mixins import (
+    RngLoaderMixin,
+    SchedulerMixin,
+    SequenceParallelMixin,
+)
 from axolotl.core.trainers.utils import (
     sanitize_kwargs_for_ds_tagging,
     sanitize_kwargs_for_tagging,
```
```diff
@@ -37,10 +29,10 @@ if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
 
 
-class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
-    """
-    Extend the base DPOTrainer for axolotl helpers
-    """
+class AxolotlDPOTrainer(
+    RngLoaderMixin, SchedulerMixin, SequenceParallelMixin, DPOTrainer
+):
+    """Extend the base DPOTrainer for axolotl helpers"""
 
     tag_names = ["axolotl", "dpo"]
 
```
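`SequenceParallelMixin` now sits ahead of `DPOTrainer` in the base-class list, so mixin-provided helpers take precedence in method resolution. A toy illustration with stub classes (not the real implementations):

```python
class DPOTrainer:  # stub standing in for trl's DPOTrainer
    def _get_train_sampler(self):
        return "default sampler"


class SequenceParallelMixin:  # stub standing in for axolotl's mixin
    def _sp_get_train_sampler(self, dataset):
        return f"sequence-parallel sampler for {dataset!r}"


class AxolotlDPOTrainer(SequenceParallelMixin, DPOTrainer):
    pass


# Mixins listed before the base trainer come first in the MRO:
print([cls.__name__ for cls in AxolotlDPOTrainer.__mro__])
# ['AxolotlDPOTrainer', 'SequenceParallelMixin', 'DPOTrainer', 'object']
```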
```diff
@@ -95,64 +87,6 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
 
         return super().push_to_hub(*args, **kwargs)
 
-    # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
-    def _prepare_dataset(
-        self,
-        dataset: Union[Dataset, IterableDataset],
-        processing_class: Union[
-            PreTrainedTokenizerBase,
-            BaseImageProcessor,
-            FeatureExtractionMixin,
-            ProcessorMixin,
-        ],
-        args: DPOConfig,
-        dataset_name: str,
-    ) -> Union[Dataset, IterableDataset]:
-        # Build the kwargs for the `map` function
-        map_kwargs: Dict[str, Any] = {"writer_batch_size": 10}
-        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
-            map_kwargs["num_proc"] = args.dataset_num_proc
-
-        with PartialState().main_process_first():
-            # Extract prompt if needed
-            if isinstance(
-                dataset, Dataset
-            ):  # `IterableDataset.map` does not support `desc`
-                map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
-            dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
-
-            # Apply the chat template if needed
-            if isinstance(
-                dataset, Dataset
-            ):  # `IterableDataset.map` does not support `desc`
-                map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
-            dataset = dataset.map(
-                maybe_apply_chat_template,
-                fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
-                **map_kwargs,
-            )
-
-            # Tokenize the dataset
-            if isinstance(
-                dataset, Dataset
-            ):  # `IterableDataset.map` does not support `desc`
-                map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
-
-            dataset = dataset.map(
-                self.tokenize_row if not self.is_vision_model else self.process_row,
-                remove_columns=["chosen", "rejected"],
-                fn_kwargs={
-                    "processing_class": processing_class,
-                    "max_prompt_length": args.max_prompt_length,
-                    "max_completion_length": args.max_completion_length,
-                    # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
-                    "add_special_tokens": False,
-                },
-                **map_kwargs,
-            )
-
-        return dataset
-
     @staticmethod
     def tokenize_row(
         features,
```
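The deleted override was only a stopgap until https://github.com/huggingface/trl/pull/3377 ships in a TRL release. Its one recurring trick is gating `num_proc` and `desc` on the dataset type, since `IterableDataset.map` supports neither keyword. A compact standalone sketch of that gating (hypothetical `map_stage` helper and `stage_fn`):

```python
from typing import Any, Callable, Dict, Union

from datasets import Dataset, IterableDataset


def map_stage(
    dataset: Union[Dataset, IterableDataset],
    stage_fn: Callable,
    desc: str,
    num_proc: int = 4,
) -> Union[Dataset, IterableDataset]:
    """Hypothetical helper: only pass `num_proc`/`desc` where supported."""
    map_kwargs: Dict[str, Any] = {}
    if isinstance(dataset, Dataset):  # IterableDataset.map accepts neither kwarg
        map_kwargs["num_proc"] = num_proc
        map_kwargs["desc"] = desc
    return dataset.map(stage_fn, **map_kwargs)
```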
```diff
@@ -193,68 +127,48 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
         torch.cuda.empty_cache()
         return loss
 
-    # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
-    def evaluation_loop(
-        self,
-        dataloader: DataLoader,
-        description: str,
-        prediction_loss_only: Optional[bool] = None,
-        ignore_keys: Optional[list[str]] = None,
-        metric_key_prefix: str = "eval",
-    ) -> EvalLoopOutput:
+    def _get_train_sampler(self) -> Sampler | None:
         """
-        Overriding built-in evaluation loop to store metrics for each batch.
-        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+        Helper method to get the sampler for training. Handles cases for sequence
+        parallelism, sample packing, and curriculum sampling (sequential).
 
-        Works both with or without labels.
+        Returns:
+            If the dataset is non-empty, a sampler is returned, the type of which
+            depends on the passed training args.
         """
+        import torch.distributed as dist
 
-        # Sample and save to game log if requested (for one batch to save time)
-        if self.generate_during_eval:
-            # Generate random indices within the range of the total number of samples
-            num_samples = len(dataloader.dataset)
-            random_indices = random.sample(
-                range(num_samples), k=self.args.eval_batch_size
-            )
+        if dist.get_rank() == 0:
+            import ipdb
 
-            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
-            random_batch_dataset = dataloader.dataset.select(random_indices)
-            random_batch = self.data_collator(random_batch_dataset)
-            random_batch = self._prepare_inputs(random_batch)
+            ipdb.set_trace()
+        dist.barrier()
+        if dist.get_rank() == 1:
+            import ipdb
 
-            policy_output_decoded, ref_output_decoded = (
-                self.generate_from_model_and_ref(self.model, random_batch)
-            )
+            ipdb.set_trace()
+        dist.barrier()
 
-            table = pd.DataFrame(
-                columns=["Prompt", "Policy", "Ref Model"],
-                data=[
-                    [prompt, pol[len(prompt) :], ref[len(prompt) :]]
-                    for prompt, pol, ref in zip(
-                        random_batch_dataset["prompt"],
-                        policy_output_decoded,
-                        ref_output_decoded,
-                    )
-                ],
-            )
-            if "wandb" in self.args.report_to and self.accelerator.is_main_process:
-                wandb.log({"game_log": wandb.Table(data=table)})
+        if self.args.sequence_parallel_degree > 1:
+            return self._sp_get_train_sampler(self.train_dataset)
 
-            if "comet_ml" in self.args.report_to:
-                log_table_to_comet_experiment(
-                    name="game_log.csv",
-                    table=table,
-                )
+        return super()._get_train_sampler()
 
-        # Base evaluation
-        initial_output = super(  # pylint: disable=bad-super-call
-            DPOTrainer, self
-        ).evaluation_loop(
-            dataloader,
-            description,
-            prediction_loss_only,
-            ignore_keys,
-            metric_key_prefix,
-        )
+    def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
+        """
+        Helper method to get the sampler for evaluation. Handles sequence parallelism
+        and sample packing cases.
 
-        return initial_output
+        Args:
+            eval_dataset: Evaluation dataset.
+
+        Returns:
+            If the dataset is non-empty, a sampler is returned, the type of which
+            depends on the passed training args.
+        """
+        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+
+        if self.args.sequence_parallel_degree > 1:
+            return self._sp_get_eval_sampler(eval_dataset)
+
+        return super()._get_eval_sampler(eval_dataset)
```
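Note that the committed `_get_train_sampler` still contains rank-guarded `ipdb` breakpoints alongside the actual sampler dispatch. The guard-plus-barrier shape is a common way to debug a single rank of a distributed job while the remaining ranks wait at a collective; a generic standalone sketch of the pattern (assuming an already-initialized process group):

```python
import torch.distributed as dist


def breakpoint_on_rank(rank_to_debug: int = 0) -> None:
    """Drop into a debugger on one rank of an initialized process group."""
    if dist.get_rank() == rank_to_debug:
        import ipdb  # third-party interactive debugger

        ipdb.set_trace()
    # Every rank must reach the barrier (the debugged one once its session
    # resumes) to keep the process group in sync.
    dist.barrier()
```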
```diff
@@ -266,9 +266,6 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
         self.accelerator.even_batches = False
 
         # Return unprepared dataloader if using sequence parallelism
-        # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
-        # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
-        # slice each batch along the sequence dimension).
         if self.args.sequence_parallel_degree > 1:
             return dataloader
 
```
```diff
@@ -1,6 +1,7 @@
 """Module for Axolotl trainer sequence parallelism manager and utilities"""
 
 import functools
+import inspect
 
 import torch
 import torch.distributed as dist
```
```diff
@@ -32,7 +33,7 @@ def apply_sequence_parallelism(
     to only keep the last N tokens in the sequence during generation.
 
     Args:
-        batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
+        batch: Dictionary of model arguments (e.g., input_ids, attention_mask, etc.).
         local_rank: Local rank in the sequence parallel group.
         local_world_size: World size of the sequence parallel group.
         gradient_accumulation_steps: Number of steps to accumulate gradients over.
```
```diff
@@ -206,12 +207,26 @@ class SequenceParallelContextManager:
     def __enter__(self):
         # Forward pre-hook to apply sequence parallelism
         def sequence_parallel_pre_hook(_, args, kwargs):
-            # Apply sequence parallelism to kwargs and get original sequence length and padding info
-            kwargs, self.original_seq_len, self.pad_len = (
-                self.apply_sequence_parallelism(batch=kwargs)
+            # Convert all args to kwargs using the model's forward function signature
+            updated_kwargs = kwargs.copy()
+
+            # Get parameter names from the model's forward function
+            forward_params = list(
+                inspect.signature(self.models[0].forward).parameters.keys()
             )
 
-            return args, kwargs
+            # Map args to their parameter names
+            for i, arg in enumerate(args):
+                if i < len(forward_params):
+                    param_name = forward_params[i]
+                    updated_kwargs[param_name] = arg
+
+            # Apply sequence parallelism to empty args and updated kwargs
+            updated_kwargs, self.original_seq_len, self.pad_len = (
+                self.apply_sequence_parallelism(updated_kwargs)
+            )
+
+            return (), updated_kwargs
 
         # Forward post-hook to gather outputs
         def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
```
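The rewritten pre-hook normalizes positional arguments into keyword arguments by reading the model's forward signature, so the sequence-parallel slicing only ever has to handle one kwargs dict. The binding can be reproduced standalone (hypothetical `forward`, not the diff's model):

```python
import inspect


def forward(input_ids, attention_mask=None, labels=None):
    return input_ids, attention_mask, labels


args = ([1, 2, 3], [1, 1, 1])
kwargs = {"labels": [0, 0, 1]}

# Map positional args onto their parameter names, as the pre-hook does.
forward_params = list(inspect.signature(forward).parameters.keys())
updated_kwargs = kwargs.copy()
for i, arg in enumerate(args):
    if i < len(forward_params):
        updated_kwargs[forward_params[i]] = arg

print(updated_kwargs)
# {'labels': [0, 0, 1], 'input_ids': [1, 2, 3], 'attention_mask': [1, 1, 1]}
```

`inspect.Signature.bind` would perform the same mapping in one call, though it raises on arguments the signature does not accept.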