diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 2b467f2b7..411496ac5 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -36,7 +36,7 @@ from transformers import ( from transformers.training_args import OptimizerNames from trl.trainer.utils import RewardDataCollatorWithPadding -from axolotl.core.trainers.base import ( +from axolotl.core.trainers import ( AxolotlCPOTrainer, AxolotlKTOTrainer, AxolotlMambaTrainer, diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py index e69de29bb..1080c5f6c 100644 --- a/src/axolotl/core/trainers/__init__.py +++ b/src/axolotl/core/trainers/__init__.py @@ -0,0 +1,17 @@ +"""Init for axolotl.core.trainers""" +# pylint: disable=unused-import +# flake8: noqa + +from .base import AxolotlTrainer +from .dpo.trainer import AxolotlDPOTrainer +from .grpo.trainer import AxolotlGRPOTrainer +from .mamba import AxolotlMambaTrainer +from .relora import ReLoRATrainer +from .trl import ( + AxolotlCPOTrainer, + AxolotlKTOTrainer, + AxolotlORPOTrainer, + AxolotlPRMTrainer, + AxolotlRewardTrainer, + TRLPPOTrainer, +) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index cab8a0634..0aae48300 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -11,366 +11,35 @@ from typing import Any, Literal import datasets import torch -import torch.distributed as dist -import torch.nn.functional as F from datasets import Dataset -from peft.optimizers import create_loraplus_optimizer from torch import nn -from torch.optim.lr_scheduler import OneCycleLR -from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler +from torch.utils.data import ( + BatchSampler, + DataLoader, + RandomSampler, + Sampler, + SequentialSampler, +) from transformers import Trainer from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker -from transformers.utils import is_datasets_available, is_sagemaker_mp_enabled -from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer from trl.trainer.utils import pad_to_length from typing_extensions import override -from axolotl.integrations.base import BaseOptimizerFactory -from axolotl.monkeypatch.relora import ReLoRAScheduler -from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths -from axolotl.utils.schedulers import ( - RexLR, - get_cosine_schedule_with_min_lr, - get_cosine_schedule_with_quadratic_warmup, - get_cosine_schedule_with_warmup_decay_constant, +from axolotl.core.trainers.mixins import ( + OptimizerMixin, + SchedulerMixin, + SequenceParallelMixin, ) - -if is_sagemaker_mp_enabled(): - import smdistributed.modelparallel.torch as smp - -try: - from ring_flash_attn import update_ring_flash_attn_params -except ImportError: - # pylint: disable=unused-argument - def update_ring_flash_attn_params(*args, **kwargs): - raise ImportError( - "ring_flash_attn is not installed. " - "Please install it with `pip install axolotl[ring-flash-attn] " - "or `pip install ring-flash-attn>=0.1.4`." - ) - +from axolotl.core.trainers.utils import ( + sanitize_kwargs_for_ds_tagging, + sanitize_kwargs_for_tagging, +) +from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths LOG = logging.getLogger(__name__) -def _sanitize_kwargs_for_tagging(tag_names, kwargs=None): - if isinstance(tag_names, str): - tag_names = [tag_names] - - if kwargs is not None: - if "tags" not in kwargs: - kwargs["tags"] = tag_names - elif "tags" in kwargs and isinstance(kwargs["tags"], list): - kwargs["tags"].extend(tag_names) - elif "tags" in kwargs and isinstance(kwargs["tags"], str): - tag_names.append(kwargs["tags"]) - kwargs["tags"] = tag_names - - return kwargs - - -def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None): - if isinstance(dataset_tags, str): - dataset_tags = [dataset_tags] - - if (dataset_tags is not None) and (kwargs is not None): - if "dataset_tags" not in kwargs: - kwargs["dataset_tags"] = dataset_tags - elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list): - kwargs["dataset_tags"].extend(dataset_tags) - elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str): - dataset_tags.append(kwargs["dataset_tags"]) - kwargs["dataset_tags"] = dataset_tags - - return kwargs - - -class SchedulerMixin(Trainer): - """ - Mixin class for scheduler setup in CausalTrainer. - """ - - args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] - - def create_scheduler( - self, num_training_steps: int, optimizer: torch.optim.Optimizer = None - ): - """ - Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or - passed as an argument. - - Args: - num_training_steps (int): The number of training steps to do. - optimizer (torch.optim.Optimizer): The training optimizer - """ - use_cosine_quadratic = ( - self.args.lr_scheduler_type == "cosine" - and self.args.lr_quadratic_warmup is True - ) - - use_cosine_min_lr = ( - self.args.lr_scheduler_type == "cosine" - and self.args.cosine_min_lr_ratio is not None - ) - - # fmt: off - if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition - # fmt: on - if self.args.alternate_lr_scheduler_type == "one_cycle": - num_warmup_steps = self.args.get_warmup_steps(num_training_steps) - pct_start = num_warmup_steps / num_training_steps - extra_lr_kwargs = {} - if "pct_start" not in self.args.lr_scheduler_kwargs: - extra_lr_kwargs["pct_start"] = pct_start - if "anneal_strategy" not in self.args.lr_scheduler_kwargs: - extra_lr_kwargs["anneal_strategy"] = "cos" - - self.lr_scheduler = OneCycleLR( - optimizer, - max_lr=self.args.learning_rate, - total_steps=num_training_steps, - **extra_lr_kwargs, - **self.args.lr_scheduler_kwargs, - ) - elif self.args.alternate_lr_scheduler_type == "rex": - if use_cosine_min_lr: - assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - - self.lr_scheduler = RexLR( - optimizer=optimizer, - max_lr=self.args.learning_rate, - min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio), - total_steps=num_training_steps, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - ) - elif use_cosine_quadratic: - if use_cosine_min_lr: - LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.") - - self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - ) - elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr: - assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0" - self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - min_lr_ratio=self.args.cosine_min_lr_ratio, - constant_lr_ratio=self.args.cosine_constant_lr_ratio, - ) - elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: - assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - min_lr_ratio=self.args.cosine_min_lr_ratio, - ) - else: - return super().create_scheduler(num_training_steps, optimizer=optimizer) - else: - if use_cosine_quadratic: - LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") - - if use_cosine_min_lr: - LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") - - return self.lr_scheduler - - -class OptimizerMixin(Trainer): - """ - Mixin class for shared handling of building custom optimizers - """ - - args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] - - def create_optimizer_grouped_parameters( - self, opt_model, optimizer_kwargs - ) -> list[dict]: - decay_parameters = self.get_decay_parameter_names(opt_model) - params: dict = { - "to_weight_decay": {}, # LayerNorm and bias - "embeddings": {}, # lm_head, embed_tokens, - "no_weight_decay": {}, - } - lr_groups_lookup = {} - lr_groups_learning_rates = {} - if self.args.lr_groups: - for lr_group in self.args.lr_groups: - group_name = lr_group["name"] - group_modules = lr_group["modules"] - for module in group_modules: - lr_groups_lookup[module] = group_name - lr_groups_learning_rates[group_name] = lr_group["lr"] - params[f"to_weight_decay_{group_name}"] = {} - - for name, param in opt_model.named_parameters(): - if not param.requires_grad: - continue - if name.endswith("modules_to_save.default.weight") or any( - embed_name in name for embed_name in ["embed_tokens", "lm_head"] - ): - params["embeddings"][name] = param - elif name in decay_parameters: - lr_group_modules = [ - group_modules - for group_modules in lr_groups_lookup - if group_modules in name - ] - if lr_groups_lookup and any(lr_group_modules): - lr_group_module = lr_group_modules[0] - group_name = lr_groups_lookup[lr_group_module] - params[f"to_weight_decay_{group_name}"][name] = param - else: - params["to_weight_decay"][name] = param - else: - params["no_weight_decay"][name] = param - optimizer_grouped_parameters = [] - if params["to_weight_decay"]: - optimizer_grouped_parameters.append( - { - "params": list(params["to_weight_decay"].values()), - "weight_decay": self.args.weight_decay, - "lr": optimizer_kwargs["lr"], - } - ) - if params["embeddings"]: - lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name - if self.args.embedding_lr_scale: - lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name - elif self.args.embedding_lr: - lr = self.args.embedding_lr # pylint: disable=invalid-name - optimizer_grouped_parameters.append( - { - "params": list(params["embeddings"].values()), - "weight_decay": 0.0, - "lr": lr, - } - ) - if params["no_weight_decay"]: - optimizer_grouped_parameters.append( - { - "params": list(params["no_weight_decay"].values()), - "weight_decay": 0.0, - "lr": optimizer_kwargs["lr"], - } - ) - for group_name, group_lr in lr_groups_learning_rates.items(): - if params[f"to_weight_decay_{group_name}"]: - optimizer_grouped_parameters.append( - { - "params": list( - params[f"to_weight_decay_{group_name}"].values() - ), - "weight_decay": self.args.weight_decay, - "lr": group_lr, - } - ) - - return optimizer_grouped_parameters - - def create_optimizer(self): - if ( - self.args.loraplus_lr_ratio is None - and self.args.embedding_lr_scale is None - and self.args.embedding_lr is None - and self.args.lr_groups is None - and self.optimizer_cls_and_kwargs is None - ): - return super().create_optimizer() - - opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model - - if ( - not self.optimizer - and self.optimizer_cls_and_kwargs is not None - and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory) - ): - optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs - self.optimizer = optimizer_factory_cls()( - opt_model, self.args, **optimizer_kwargs - ) - - if not self.optimizer: - if self.optimizer_cls_and_kwargs is not None: - optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs - else: - optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs( - self.args, opt_model - ) - - optimizer_grouped_parameters = self.create_optimizer_grouped_parameters( - opt_model, optimizer_kwargs - ) - - if self.args.loraplus_lr_ratio is not None: - loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) - loraplus_lr_embedding = getattr( - self.args, "loraplus_lr_embedding", 1e-6 - ) - self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init - opt_model, - optimizer_cls, - loraplus_lr_ratio=loraplus_lr_ratio, - loraplus_lr_embedding=loraplus_lr_embedding, - **optimizer_kwargs, - ) - else: - # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` - # e.g. for GaLore optimizer. - if "params" in optimizer_kwargs: - optimizer_grouped_parameters = optimizer_kwargs.pop("params") - - # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs` - # e.g. for LOMO optimizer. - if "model" in optimizer_kwargs: - optimizer_grouped_parameters = optimizer_kwargs.pop("model") - - # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict` - # to avoid arguments conflicts. - if "optimizer_dict" in optimizer_kwargs: - optimizer_grouped_parameters = optimizer_kwargs.pop( - "optimizer_dict" - ) - - self.optimizer = optimizer_cls( - optimizer_grouped_parameters, **optimizer_kwargs - ) - - if optimizer_cls.__name__ == "Adam8bit": - import bitsandbytes - - manager = bitsandbytes.optim.GlobalOptimManager.get_instance() - - skipped = 0 - for module in opt_model.modules(): - if isinstance(module, nn.Embedding): - skipped += sum( - { - p.data_ptr(): p.numel() for p in module.parameters() - }.values() - ) - LOG.info(f"skipped {module}: {skipped/2**20}M params") - manager.register_module_override( - module, "weight", {"optim_bits": 32} - ) - LOG.debug(f"bitsandbytes: will optimize {module} in fp32") - LOG.info(f"skipped: {skipped/2**20}M params") - - if is_sagemaker_mp_enabled(): - self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init - self.optimizer - ) - - return self.optimizer - - -class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): +class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer): """Extend the base Trainer for axolotl helpers""" args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] @@ -396,10 +65,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): if self.args.orpo_alpha: self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + # Initialize sequence parallelism if enabled if self.args.sequence_parallel_degree > 1: - from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group - - self.ring_attn_group = get_ring_attn_group() + self._setup_sequence_parallel() def _wrap_model(self, model, training=True, dataloader=None): if self.args.torch_compile: @@ -413,8 +81,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): ) return super()._wrap_model(model, training=training, dataloader=dataloader) - def _create_multipack_sampler(self, base_sampler, dataset, group_size): - """Helper method to create a MultipackBatchSampler""" + def _create_multipack_sampler( + self, base_sampler: Sampler, dataset: Dataset + ) -> MultipackBatchSampler: + """ + Helper method to create a `MultipackBatchSampler` for multipacking sequences + for training. + + Args: + base_sampler: Sampler to wrap with `MultipackBatchSampler`. + dataset: Dataset to sample from. + + Returns: + Multipack (sample packing) batch sampler. + """ if self.args.multipack_real_batches: batch_size = self.args.per_device_train_batch_size batch_max_len = self.args.max_seq_length @@ -431,97 +111,74 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): packing_efficiency_estimate=self.args.sample_packing_efficiency, batch_max_len=batch_max_len, batch_size=batch_size, - group_size=group_size, - bin_size=self.args.sample_packing_bin_size, drop_last=True, ) - def _create_sp_sampler(self, dataset, shuffle=True, is_eval=False): - """Create a sampler for sequence parallelism""" - num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree - sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree + @override + def _get_train_sampler(self) -> Sampler | None: + """ + Helper method to get the sampler for training. Handles cases for sequence + parallelism, sample packing, and curriculum sampling (sequential). - return torch.utils.data.distributed.DistributedSampler( - dataset, - num_replicas=num_sp_groups, - rank=sp_group_id, - seed=self.args.seed if shuffle else None, - shuffle=shuffle, - drop_last=not is_eval, - ) + Returns: + If the dataset is non-empty, a sampler is returned, the type of which + depends on the passed training args. + """ + use_sample_packing = self.args.sample_packing and not self.args.pretraining - def _get_train_sampler(self) -> torch.utils.data.Sampler | None: - # Handle sequence parallelism + # Determine the base sampler first if self.args.sequence_parallel_degree > 1: - base_sampler = self._create_sp_sampler( - self.train_dataset, shuffle=not self.args.curriculum_sampling - ) + base_sampler = self._sp_get_train_sampler(self.train_dataset) + elif self.args.curriculum_sampling: + base_sampler = SequentialSampler(self.train_dataset) + elif use_sample_packing: + base_sampler = RandomSampler(self.train_dataset) + else: + # Default to parent class implementation for standard random sampling + return super()._get_train_sampler() - # Apply multipack wrapper if needed - if self.args.sample_packing and not self.args.pretraining: - return self._create_multipack_sampler( - base_sampler=base_sampler, - dataset=self.train_dataset, - group_size=self.args.sample_packing_group_size, - ) - return base_sampler - - # Regular training sampler logic - if self.args.sample_packing and not self.args.pretraining: - base_sampler = ( - SequentialSampler(self.train_dataset) - if self.args.curriculum_sampling - else RandomSampler(self.train_dataset) - ) + # Apply multipack wrapper if needed + if use_sample_packing: return self._create_multipack_sampler( base_sampler=base_sampler, dataset=self.train_dataset, - group_size=self.args.sample_packing_group_size, ) - if self.args.curriculum_sampling: - return SequentialSampler(self.train_dataset) + return base_sampler - return super()._get_train_sampler() + @override + def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None: + """ + Helper method to get the sampler for evaluation. Handles sequence parallelism + and sample packing cases. - def _get_eval_sampler( - self, eval_dataset: Dataset | None = None - ) -> torch.utils.data.Sampler | None: - """Get evaluation sampler""" + Returns: + If the dataset is non-empty, a sampler is returned, the type of which + depends on the passed training args. + """ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - # Get the appropriate group size for sample packing - def get_pack_group_size(): - return ( - self.args.eval_packing_group_size - if hasattr(self.args, "eval_packing_group_size") - else self.args.sample_packing_group_size - ) + # Multipacking enabled if training is enabled and eval is not explicitly disabled + use_multipack = ( + self.args.sample_packing and self.args.eval_sample_packing is not False + ) - # Handle sequence parallelism + # Determine the base sampler if self.args.sequence_parallel_degree > 1: - base_sampler = self._create_sp_sampler( - eval_dataset, shuffle=False, is_eval=True - ) - - if self.args.sample_packing and self.args.eval_sample_packing is not False: - return self._create_multipack_sampler( - base_sampler=base_sampler, - dataset=eval_dataset, - group_size=get_pack_group_size(), - ) - return base_sampler - - # Regular evaluation sampler logic - if self.args.sample_packing and self.args.eval_sample_packing is not False: + base_sampler = self._sp_get_eval_sampler(eval_dataset) + elif use_multipack: base_sampler = SequentialSampler(eval_dataset) + else: + return super()._get_eval_sampler(eval_dataset) + + # Apply multipack wrapper if needed + if use_multipack: return self._create_multipack_sampler( base_sampler=base_sampler, dataset=eval_dataset, - group_size=get_pack_group_size(), ) - return super()._get_eval_sampler(eval_dataset) + return base_sampler def _create_dataloader_params(self, is_eval=False, custom_batch_size=None): """Create common dataloader parameters for train or eval.""" @@ -588,7 +245,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): data_collator = self.data_collator # type: ignore # Handle dataset preprocessing - if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + if isinstance(train_dataset, datasets.Dataset): if self.args.sample_packing and not self.args.pretraining: train_dataset = train_dataset.remove_columns(["length"]) if not self.args.sample_packing or self.args.pretraining: @@ -640,9 +297,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): # Handle dataset preprocessing for SP if self.args.sequence_parallel_degree > 1: - if is_datasets_available() and isinstance( - eval_dataset, datasets.Dataset - ): + if isinstance(eval_dataset, datasets.Dataset): eval_dataset = self._remove_unused_columns( eval_dataset, description="evaluation" ) @@ -885,10 +540,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): Overwrite the `push_to_hub` method in order to force-add the tags when pushing the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. """ - kwargs = _sanitize_kwargs_for_ds_tagging( + kwargs = sanitize_kwargs_for_ds_tagging( dataset_tags=self.dataset_tags, kwargs=kwargs ) - kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) + kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) return super().push_to_hub(*args, **kwargs) @@ -944,148 +599,18 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): num_items_in_batch: int | None = None, ) -> torch.Tensor: """ - Perform a training step on a batch of inputs. + Perform a training step on a batch of inputs. Overrides the + `transformers.trainer.Trainer` method to handle sequence parallelism if + enabled. + + Args: + model: Model to perform training step for. + inputs: Dictionary mapping. """ - if self.args.sequence_parallel_degree > 1: - # At this point, inputs should already be partitioned by the sequence - # parallel data collator - batch_size = inputs["input_ids"].shape[0] - seq_len = inputs["input_ids"].shape[1] - - # Calculate the full sequence length across all GPUs in this SP group - total_seq_len = seq_len * self.args.sequence_parallel_degree - - # Pass the partitioned sequence information to ring flash attention - self._update_ring_flash_attn_params( - packed_seq_lens=[seq_len] * batch_size, total_seq_len=total_seq_len - ) + # Set up sequence parallelism for this step if enabled + self._sp_training_step_setup(inputs) + # Proceed with normal training step loss = super().training_step(model, inputs, num_items_in_batch) return loss - - def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len): - """ - Calculate the cu_seqlens for the current forward pass and pass the value to - the substituted ring_flash_attn. - """ - cu_seqlens = torch.cumsum( - torch.tensor( - packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32 - ), - dim=-1, - dtype=torch.int32, - ) - cu_seqlens = F.pad( - F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len - ) - - update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group) - - -class AxolotlMambaTrainer(AxolotlTrainer): - """ - Mamba specific trainer to handle loss calculation - """ - - tag_names = ["axolotl", "mamba"] - - def compute_loss( - self, - model, - inputs, - return_outputs=False, # pylint: disable=unused-argument - num_items_in_batch=None, # pylint: disable=unused-argument - ): - input_ids = inputs.pop("input_ids") - lm_logits = model(input_ids).logits - - labels = input_ids.to(lm_logits.device) - shift_logits = lm_logits[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - - loss_fct = torch.nn.CrossEntropyLoss() - lm_loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) - ) - - return lm_loss - - -class ReLoRATrainer(AxolotlTrainer): - """ - Trainer subclass that uses the OneCycleLR scheduler - """ - - tag_names = ["axolotl", "relora"] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.lr_scheduler = None - - def create_scheduler( - self, - num_training_steps: int, - optimizer: torch.optim.Optimizer | None = None, - ): - optimizer = self.optimizer if optimizer is None else optimizer - lr_scheduler = super().create_scheduler(num_training_steps, optimizer) - - if self.args.relora_steps: - warmup_steps = ( - self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10 - ) - anneal_steps = ( - self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1 - ) - self.lr_scheduler = ReLoRAScheduler( - optimizer, - lr_scheduler, - self.args.relora_steps, - anneal_steps, - warmup_steps, - ) - else: - self.lr_scheduler = lr_scheduler - - return self.lr_scheduler - - -class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): - """ - Extend the base ORPOTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "orpo"] - - -class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): - """ - Extend the base KTOTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "kto"] - - -class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): - """ - Extend the base CPOTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "cpo"] - - -class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): - """ - Extend the base RewardTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "reward"] - - -class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer): - """ - Extend the base trl.PRMTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "prm"] diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py index 38b657260..9eb870a3a 100644 --- a/src/axolotl/core/trainers/dpo/trainer.py +++ b/src/axolotl/core/trainers/dpo/trainer.py @@ -13,10 +13,10 @@ from transformers import Trainer from transformers.utils import is_sagemaker_mp_enabled from trl import DPOTrainer -from axolotl.core.trainers.base import ( - SchedulerMixin, - _sanitize_kwargs_for_ds_tagging, - _sanitize_kwargs_for_tagging, +from axolotl.core.trainers.mixins import SchedulerMixin +from axolotl.core.trainers.utils import ( + sanitize_kwargs_for_ds_tagging, + sanitize_kwargs_for_tagging, ) if is_sagemaker_mp_enabled(): @@ -74,10 +74,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): Overwrite the `push_to_hub` method in order to force-add the tags when pushing the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. """ - kwargs = _sanitize_kwargs_for_ds_tagging( + kwargs = sanitize_kwargs_for_ds_tagging( dataset_tags=self.dataset_tags, kwargs=kwargs ) - kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) + kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) return super().push_to_hub(*args, **kwargs) diff --git a/src/axolotl/core/trainers/mamba.py b/src/axolotl/core/trainers/mamba.py new file mode 100644 index 000000000..38792e389 --- /dev/null +++ b/src/axolotl/core/trainers/mamba.py @@ -0,0 +1,32 @@ +"""Module for mamba trainer""" + +import torch + +from axolotl.core.trainers.base import AxolotlTrainer + + +class AxolotlMambaTrainer(AxolotlTrainer): + """Mamba specific trainer to handle loss calculation""" + + tag_names = ["axolotl", "mamba"] + + def compute_loss( + self, + model, + inputs, + return_outputs=False, # pylint: disable=unused-argument + num_items_in_batch=None, # pylint: disable=unused-argument + ): + input_ids = inputs.pop("input_ids") + lm_logits = model(input_ids).logits + + labels = input_ids.to(lm_logits.device) + shift_logits = lm_logits[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + + loss_fct = torch.nn.CrossEntropyLoss() + lm_loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) + ) + + return lm_loss diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py new file mode 100644 index 000000000..a4b8fb1e2 --- /dev/null +++ b/src/axolotl/core/trainers/mixins/__init__.py @@ -0,0 +1,7 @@ +"""Init for axolotl.core.trainers.mixins""" +# pylint: disable=unused-import +# flake8: noqa + +from .optimizer import OptimizerMixin +from .scheduler import SchedulerMixin +from .sequence_parallel import SequenceParallelMixin diff --git a/src/axolotl/core/trainers/mixins/optimizer.py b/src/axolotl/core/trainers/mixins/optimizer.py new file mode 100644 index 000000000..bde58aa1d --- /dev/null +++ b/src/axolotl/core/trainers/mixins/optimizer.py @@ -0,0 +1,201 @@ +"""Module for Axolotl trainer optimizer mixin""" + +import logging + +from peft.optimizers import create_loraplus_optimizer +from torch import nn +from transformers.trainer import Trainer +from transformers.utils import is_sagemaker_mp_enabled + +from axolotl.integrations.base import BaseOptimizerFactory + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + +LOG = logging.getLogger(__name__) + + +class OptimizerMixin(Trainer): + """Mixin class for shared handling of building custom optimizers""" + + args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] + + def create_optimizer_grouped_parameters( + self, opt_model, optimizer_kwargs + ) -> list[dict]: + decay_parameters = self.get_decay_parameter_names(opt_model) + params: dict = { + "to_weight_decay": {}, # LayerNorm and bias + "embeddings": {}, # lm_head, embed_tokens, + "no_weight_decay": {}, + } + lr_groups_lookup = {} + lr_groups_learning_rates = {} + if self.args.lr_groups: + for lr_group in self.args.lr_groups: + group_name = lr_group["name"] + group_modules = lr_group["modules"] + for module in group_modules: + lr_groups_lookup[module] = group_name + lr_groups_learning_rates[group_name] = lr_group["lr"] + params[f"to_weight_decay_{group_name}"] = {} + + for name, param in opt_model.named_parameters(): + if not param.requires_grad: + continue + if name.endswith("modules_to_save.default.weight") or any( + embed_name in name for embed_name in ["embed_tokens", "lm_head"] + ): + params["embeddings"][name] = param + elif name in decay_parameters: + lr_group_modules = [ + group_modules + for group_modules in lr_groups_lookup + if group_modules in name + ] + if lr_groups_lookup and any(lr_group_modules): + lr_group_module = lr_group_modules[0] + group_name = lr_groups_lookup[lr_group_module] + params[f"to_weight_decay_{group_name}"][name] = param + else: + params["to_weight_decay"][name] = param + else: + params["no_weight_decay"][name] = param + optimizer_grouped_parameters = [] + if params["to_weight_decay"]: + optimizer_grouped_parameters.append( + { + "params": list(params["to_weight_decay"].values()), + "weight_decay": self.args.weight_decay, + "lr": optimizer_kwargs["lr"], + } + ) + if params["embeddings"]: + lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name + if self.args.embedding_lr_scale: + lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name + elif self.args.embedding_lr: + lr = self.args.embedding_lr # pylint: disable=invalid-name + optimizer_grouped_parameters.append( + { + "params": list(params["embeddings"].values()), + "weight_decay": 0.0, + "lr": lr, + } + ) + if params["no_weight_decay"]: + optimizer_grouped_parameters.append( + { + "params": list(params["no_weight_decay"].values()), + "weight_decay": 0.0, + "lr": optimizer_kwargs["lr"], + } + ) + for group_name, group_lr in lr_groups_learning_rates.items(): + if params[f"to_weight_decay_{group_name}"]: + optimizer_grouped_parameters.append( + { + "params": list( + params[f"to_weight_decay_{group_name}"].values() + ), + "weight_decay": self.args.weight_decay, + "lr": group_lr, + } + ) + + return optimizer_grouped_parameters + + def create_optimizer(self): + if ( + self.args.loraplus_lr_ratio is None + and self.args.embedding_lr_scale is None + and self.args.embedding_lr is None + and self.args.lr_groups is None + and self.optimizer_cls_and_kwargs is None + ): + return super().create_optimizer() + + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + + if ( + not self.optimizer + and self.optimizer_cls_and_kwargs is not None + and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory) + ): + optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs + self.optimizer = optimizer_factory_cls()( + opt_model, self.args, **optimizer_kwargs + ) + + if not self.optimizer: + if self.optimizer_cls_and_kwargs is not None: + optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs + else: + optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs( + self.args, opt_model + ) + + optimizer_grouped_parameters = self.create_optimizer_grouped_parameters( + opt_model, optimizer_kwargs + ) + + if self.args.loraplus_lr_ratio is not None: + loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) + loraplus_lr_embedding = getattr( + self.args, "loraplus_lr_embedding", 1e-6 + ) + self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init + opt_model, + optimizer_cls, + loraplus_lr_ratio=loraplus_lr_ratio, + loraplus_lr_embedding=loraplus_lr_embedding, + **optimizer_kwargs, + ) + else: + # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for GaLore optimizer. + if "params" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("params") + + # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for LOMO optimizer. + if "model" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("model") + + # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict` + # to avoid arguments conflicts. + if "optimizer_dict" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop( + "optimizer_dict" + ) + + self.optimizer = optimizer_cls( + optimizer_grouped_parameters, **optimizer_kwargs + ) + + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum( + { + p.data_ptr(): p.numel() for p in module.parameters() + }.values() + ) + LOG.info(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override( + module, "weight", {"optim_bits": 32} + ) + LOG.debug(f"bitsandbytes: will optimize {module} in fp32") + LOG.info(f"skipped: {skipped/2**20}M params") + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init + self.optimizer + ) + + return self.optimizer diff --git a/src/axolotl/core/trainers/mixins/scheduler.py b/src/axolotl/core/trainers/mixins/scheduler.py new file mode 100644 index 000000000..b0a5ee895 --- /dev/null +++ b/src/axolotl/core/trainers/mixins/scheduler.py @@ -0,0 +1,113 @@ +"""Module for Axolotl trainer scheduler mixin""" + +import logging + +import torch +from torch.optim.lr_scheduler import OneCycleLR +from transformers.trainer import Trainer + +from axolotl.utils.schedulers import ( + RexLR, + get_cosine_schedule_with_min_lr, + get_cosine_schedule_with_quadratic_warmup, + get_cosine_schedule_with_warmup_decay_constant, +) + +LOG = logging.getLogger(__name__) + + +class SchedulerMixin(Trainer): + """ + Mixin class for scheduler setup in CausalTrainer. + """ + + args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] + + def create_scheduler( + self, num_training_steps: int, optimizer: torch.optim.Optimizer = None + ): + """ + Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or + passed as an argument. + + Args: + num_training_steps (int): The number of training steps to do. + optimizer (torch.optim.Optimizer): The training optimizer + """ + use_cosine_quadratic = ( + self.args.lr_scheduler_type == "cosine" + and self.args.lr_quadratic_warmup is True + ) + + use_cosine_min_lr = ( + self.args.lr_scheduler_type == "cosine" + and self.args.cosine_min_lr_ratio is not None + ) + + # fmt: off + if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition + # fmt: on + if self.args.alternate_lr_scheduler_type == "one_cycle": + num_warmup_steps = self.args.get_warmup_steps(num_training_steps) + pct_start = num_warmup_steps / num_training_steps + extra_lr_kwargs = {} + if "pct_start" not in self.args.lr_scheduler_kwargs: + extra_lr_kwargs["pct_start"] = pct_start + if "anneal_strategy" not in self.args.lr_scheduler_kwargs: + extra_lr_kwargs["anneal_strategy"] = "cos" + + self.lr_scheduler = OneCycleLR( + optimizer, + max_lr=self.args.learning_rate, + total_steps=num_training_steps, + **extra_lr_kwargs, + **self.args.lr_scheduler_kwargs, + ) + elif self.args.alternate_lr_scheduler_type == "rex": + if use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + + self.lr_scheduler = RexLR( + optimizer=optimizer, + max_lr=self.args.learning_rate, + min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio), + total_steps=num_training_steps, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + ) + elif use_cosine_quadratic: + if use_cosine_min_lr: + LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.") + + self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + ) + elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0" + self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + min_lr_ratio=self.args.cosine_min_lr_ratio, + constant_lr_ratio=self.args.cosine_constant_lr_ratio, + ) + elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + min_lr_ratio=self.args.cosine_min_lr_ratio, + ) + else: + return super().create_scheduler(num_training_steps, optimizer=optimizer) + else: + if use_cosine_quadratic: + LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") + + if use_cosine_min_lr: + LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") + + return self.lr_scheduler diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py new file mode 100644 index 000000000..3f511a7de --- /dev/null +++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py @@ -0,0 +1,134 @@ +"""Module for Axolotl trainer sequence parallelism mixin""" + +import logging +from typing import Any + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from datasets import Dataset +from torch.utils.data import DistributedSampler, Sampler + +from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group + +LOG = logging.getLogger(__name__) + +try: + from ring_flash_attn import update_ring_flash_attn_params +except ImportError: + # We pass silently here, but raise an ImportError in our Axolotl config validation + # if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed. + pass + + +class SequenceParallelMixin: + """ + Mixin class for sequence parallelism support in trainers. + + This mixin provides functionality for handling sequence parallelism, + including creating appropriate samplers, managing data partitioning, + and updating ring flash attention parameters during training. + """ + + args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] + + def _setup_sequence_parallel(self): + """Set up sequence parallelism environment.""" + self.ring_attn_group = get_ring_attn_group() + + def _create_sequence_parallel_sampler( + self, + dataset: Dataset, + shuffle: bool = True, + is_eval: bool = False, + ) -> DistributedSampler: + """ + Helper method to create sampler for sequence parallelism (SP). + + We create a distributed sampler with rank equal to the SP group ID, which + means that all ranks in the SP group receive the same sample / set of samples + per training step. We also set the number of replicas equal to the number of + SP groups, which is a bit of a hack / unintended use, but works! + + Args: + dataset: Dataset to sample from. + shuffle: Whether to shuffle the dataset. + is_eval: Whether we are creating a sampler for evaluation or training. + + Returns: + Distributed sampler. + """ + num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree + sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree + + return DistributedSampler( + dataset, + num_replicas=num_sp_groups, + rank=sp_group_id, + seed=self.args.seed if shuffle else None, + shuffle=shuffle, + drop_last=not is_eval, + ) + + def _sp_get_train_sampler(self, dataset) -> Sampler | None: + """ + Get a training sampler configured for sequence parallelism. + + Args: + dataset: The training dataset + + Returns: + Configured sequence parallel sampler. + """ + return self._create_sequence_parallel_sampler( + dataset, + shuffle=not getattr(self.args, "curriculum_sampling", False), + ) + + def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None: + """ + Get an evaluation sampler configured for sequence parallelism. + + Args: + eval_dataset: The evaluation dataset. + + Returns: + Configured sequence parallel sampler. + """ + return self._create_sequence_parallel_sampler( + eval_dataset, shuffle=False, is_eval=True + ) + + def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]): + """ + Calculate the cu_seqlens for the current forward pass and pass the value to + the substituted ring_flash_attn. This is accomplished by using the passed + `input_ids`. + + Args: + inputs: Current batch of inputs. + """ + if not self.args.sequence_parallel_degree > 1: + return + + # At this point, inputs should already be partitioned by the sequence + # parallel data collator + batch_size = inputs["input_ids"].shape[0] + seq_len = inputs["input_ids"].shape[1] + packed_seq_lens = [seq_len] * batch_size + + # Calculate the full sequence length across all GPUs in this SP group + total_seq_len = seq_len * self.args.sequence_parallel_degree + + cu_seqlens = torch.cumsum( + torch.tensor( + packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32 + ), + dim=-1, + dtype=torch.int32, + ) + cu_seqlens = F.pad( + F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len + ) + + update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group) diff --git a/src/axolotl/core/trainers/relora.py b/src/axolotl/core/trainers/relora.py new file mode 100644 index 000000000..3bcd4a9b8 --- /dev/null +++ b/src/axolotl/core/trainers/relora.py @@ -0,0 +1,43 @@ +"""Module for ReLoRA trainer""" + +import torch + +from axolotl.core.trainers.base import AxolotlTrainer +from axolotl.monkeypatch.relora import ReLoRAScheduler + + +class ReLoRATrainer(AxolotlTrainer): + """Trainer subclass that uses the `OneCycleLR` scheduler""" + + tag_names = ["axolotl", "relora"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.lr_scheduler = None + + def create_scheduler( + self, + num_training_steps: int, + optimizer: torch.optim.Optimizer | None = None, + ): + optimizer = self.optimizer if optimizer is None else optimizer + lr_scheduler = super().create_scheduler(num_training_steps, optimizer) + + if self.args.relora_steps: + warmup_steps = ( + self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10 + ) + anneal_steps = ( + self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1 + ) + self.lr_scheduler = ReLoRAScheduler( + optimizer, + lr_scheduler, + self.args.relora_steps, + anneal_steps, + warmup_steps, + ) + else: + self.lr_scheduler = lr_scheduler + + return self.lr_scheduler diff --git a/src/axolotl/core/trainers/trl.py b/src/axolotl/core/trainers/trl.py index 7237e792e..1199313e8 100644 --- a/src/axolotl/core/trainers/trl.py +++ b/src/axolotl/core/trainers/trl.py @@ -1,16 +1,23 @@ -""" -module for TRL PPO training -""" +"""Module for TRL PPO trainer""" import torch from tqdm import tqdm -from trl import PPOTrainer +from trl import ( + CPOTrainer, + KTOTrainer, + ORPOTrainer, + PPOTrainer, + PRMTrainer, + RewardTrainer, +) + +from axolotl.core.trainers.mixins.scheduler import SchedulerMixin class TRLPPOTrainer(PPOTrainer): - """ - wrapper for ppo trainer to handle customizations - """ + """Wrapper for TRL PPO trainer to handle customizations""" + + tag_names = ["axolotl", "ppo"] def train( self, @@ -31,9 +38,7 @@ class TRLPPOTrainer(PPOTrainer): "batch_size": 16, } - for epoch, batch in tqdm( # pylint: disable=unused-variable - enumerate(self.dataloader) - ): + for _, batch in tqdm(enumerate(self.dataloader)): query_tensors = batch["input_ids"] # generate model response @@ -65,3 +70,43 @@ class TRLPPOTrainer(PPOTrainer): rewards, columns_to_log=["query", "response", "ref_response", "ref_rewards"], ) + + +class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): + """ + Extend the base ORPOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "orpo"] + + +class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): + """ + Extend the base KTOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "kto"] + + +class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): + """ + Extend the base CPOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "cpo"] + + +class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): + """ + Extend the base RewardTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "reward"] + + +class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer): + """ + Extend the base trl.PRMTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "prm"] diff --git a/src/axolotl/core/trainers/utils.py b/src/axolotl/core/trainers/utils.py new file mode 100644 index 000000000..c6d40cb61 --- /dev/null +++ b/src/axolotl/core/trainers/utils.py @@ -0,0 +1,33 @@ +"""Utils for Axolotl trainers""" + + +def sanitize_kwargs_for_tagging(tag_names, kwargs=None): + if isinstance(tag_names, str): + tag_names = [tag_names] + + if kwargs is not None: + if "tags" not in kwargs: + kwargs["tags"] = tag_names + elif "tags" in kwargs and isinstance(kwargs["tags"], list): + kwargs["tags"].extend(tag_names) + elif "tags" in kwargs and isinstance(kwargs["tags"], str): + tag_names.append(kwargs["tags"]) + kwargs["tags"] = tag_names + + return kwargs + + +def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None): + if isinstance(dataset_tags, str): + dataset_tags = [dataset_tags] + + if (dataset_tags is not None) and (kwargs is not None): + if "dataset_tags" not in kwargs: + kwargs["dataset_tags"] = dataset_tags + elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list): + kwargs["dataset_tags"].extend(dataset_tags) + elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str): + dataset_tags.append(kwargs["dataset_tags"]) + kwargs["dataset_tags"] = dataset_tags + + return kwargs diff --git a/src/axolotl/monkeypatch/attention/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn.py index 2cde5b98d..9ed332dfa 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn.py +++ b/src/axolotl/monkeypatch/attention/ring_attn.py @@ -8,10 +8,16 @@ their sequence parallel version of Flash Attention 2. import torch.distributed as dist from accelerate.logging import get_logger -from ring_flash_attn import substitute_hf_flash_attn from axolotl.logging_config import configure_logging +try: + from ring_flash_attn import substitute_hf_flash_attn +except ImportError: + # We pass silently here, but raise an ImportError in our Axolotl config validation + # if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed. + pass + configure_logging() LOG = get_logger(__name__) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 7f0c3c58c..cc130c2c6 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1116,6 +1116,15 @@ class AxolotlInputConfig( "flash_attention: true must be set with sequence_parallel_degree > 1" ) + try: + import ring_flash_attn # noqa: F401 # pylint:disable=unused-import + except ImportError as exception: + raise ImportError( + "ring_flash_attn is not installed. " + "Please install it with `pip install axolotl[ring-flash-attn] " + "or `pip install ring-flash-attn>=0.1.4`." + ) from exception + return value