Compare commits


4 Commits

Author        SHA1        Message                                           Date
Dan Saunders  fe81d52882  remove xformers changes                           2025-07-08 14:43:23 +00:00
Dan Saunders  1eaa4ed89d  update other deps                                 2025-07-08 14:29:36 +00:00
Dan Saunders  fe47392ed6  updating vllm to latest                           2025-07-08 14:24:22 +00:00
float-trip    1032e22650  Fix link in FSDP + QLoRA docs. (#2879) [skip ci]  2025-07-08 09:19:09 -04:00
4 changed files with 23 additions and 22 deletions

View File

@@ -20,7 +20,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
 > See the [example config](#example-config) file in addition to reading these instructions.
 
 1. Set `adapter: qlora` in your axolotl config file.
-2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
+2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
 3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
 
 ## Example Config
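
The three documented steps map directly onto config keys. As a minimal sketch (in Python, to keep one language across the examples on this page), the snippet below assembles and sanity-checks such a config; the `fsdp` block and its values are assumptions about typical axolotl configs rather than something this diff specifies, so treat the linked docs as authoritative.

```python
import yaml  # third-party: pip install pyyaml

# Hypothetical config following the three steps above; the exact fsdp
# key names/values are assumptions, not taken from this diff.
CONFIG_TEXT = """
base_model: meta-llama/Llama-2-7b-hf
adapter: qlora            # step 1: enable QLoRA
fsdp:                     # step 2: enable FSDP (see the linked docs for options)
  - full_shard
  - auto_wrap
"""

config = yaml.safe_load(CONFIG_TEXT)

# Step 3: FSDP + QLoRA is documented for llama, mistral, or mixtral.
SUPPORTED = ("llama", "mistral", "mixtral")
assert config["adapter"] == "qlora"
assert any(name in config["base_model"].lower() for name in SUPPORTED)
print("basic FSDP + QLoRA config checks passed")
```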

View File

@@ -11,7 +11,7 @@ liger-kernel==0.5.10
 packaging==23.2
-huggingface_hub==0.32.2
+huggingface_hub[hf_xet]==0.33.0
 peft==0.15.2
 transformers==4.53.1
 tokenizers>=0.21.1
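
Since the second commit bumps these pins, it can be handy to check an environment against them. Below is a small, hypothetical helper using only the standard library's importlib.metadata; the pins are copied from the new side of the diff, and the `hf_xet` extra is omitted because extras are not part of an installed distribution's version.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins taken from the updated requirements above.
PINS = {
    "huggingface_hub": "0.33.0",
    "peft": "0.15.2",
    "transformers": "4.53.1",
}

for name, expected in PINS.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (expected {expected})")
        continue
    status = "ok" if installed == expected else f"mismatch (expected {expected})"
    print(f"{name}: {installed} {status}")
```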

View File

@@ -65,15 +65,13 @@ def parse_requirements(extras_require_map):
             raise ValueError("Invalid version format")
         if (major, minor) >= (2, 7):
-            _install_requires.pop(_install_requires.index(xformers_version))
             # _install_requires.append("xformers==0.0.29.post3")  # xformers seems to be hard pinned to 2.6.0
-            extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
+            extras_require_map["vllm"] = ["vllm==0.9.2"]
         elif (major, minor) >= (2, 6):
-            _install_requires.pop(_install_requires.index(xformers_version))
             _install_requires.append(
                 "xformers==0.0.29.post2"
             )  # vllm needs post2 w torch 2.6
-            extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
+            extras_require_map["vllm"] = ["vllm==0.9.2"]
         elif (major, minor) >= (2, 5):
             _install_requires.pop(_install_requires.index(xformers_version))
             if patch == 0:
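
The gating above compares a parsed torch version tuple against (2, 7), (2, 6), and (2, 5) to decide the vllm and xformers pins. Here is a self-contained sketch of just the vllm selection, with the regex parsing reconstructed from the surrounding context (`major`, `minor`, and the "Invalid version format" error) rather than copied from the real setup.py, which handles more branches such as the 2.5 xformers cases.

```python
import re


def select_vllm_pin(torch_version: str) -> str:
    """Mirror the (major, minor) gating above to pick a vllm pin."""
    match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
    if not match:
        raise ValueError("Invalid version format")
    major, minor = int(match.group(1)), int(match.group(2))
    if (major, minor) >= (2, 6):
        # after this change, the >= 2.7 and >= 2.6 branches agree on 0.9.2
        return "vllm==0.9.2"
    raise ValueError(f"no vllm pin modeled here for torch {torch_version}")


print(select_vllm_pin("2.7.1"))  # -> vllm==0.9.2
```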

View File

@@ -6,7 +6,6 @@ from pathlib import Path
 from axolotl.core.builders.base import TrainerBuilderBase
 from axolotl.core.trainers import (
     AxolotlCPOTrainer,
-    AxolotlDPOTrainer,
     AxolotlKTOTrainer,
     AxolotlORPOTrainer,
 )
@@ -37,23 +36,33 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
         return callbacks
 
-    def _get_trainer_cls(self):
-        """Returns trainer_cls"""
+    def _get_trainer_cls(self, trainer_kwargs: dict):
+        """
+        Returns trainer_cls and trainer_cls_args
+        """
         if self.cfg.plugins:
             plugin_manager = PluginManager.get_instance()
             trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
+            trainer_cls_args = []  # type: ignore
             if trainer_cls is not None:
-                return trainer_cls
+                return trainer_cls, trainer_cls_args
 
         trainer_cls = None
+        trainer_cls_args = [self.model]
         if self.cfg.rl is RLType.GRPO:
             trainer_cls = GRPOStrategy.get_trainer_class(
                 sequence_parallel=self.cfg.sequence_parallel_degree > 1
             )
+            trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
+            trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
         elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
-            trainer_cls = AxolotlDPOTrainer
+            trainer_cls = DPOStrategy.get_trainer_class()
+            trainer_cls_args.append(self.model_ref)
         elif self.cfg.rl is RLType.ORPO:
             trainer_cls = AxolotlORPOTrainer
         elif self.cfg.rl is RLType.KTO:
@@ -63,7 +72,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")
 
-        return trainer_cls
+        return trainer_cls, trainer_cls_args
 
     def _build_training_arguments(self, total_num_steps):
         """
@@ -173,15 +182,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
                 self.cfg.precompute_ref_log_probs
             )
 
-        trainer_cls = self._get_trainer_cls()
-        trainer_cls_args = [self.model]
-
-        if self.cfg.rl is RLType.GRPO:
-            trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
-            trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
-        if self.cfg.rl in [RLType.DPO, RLType.IPO]:
-            trainer_cls_args.append(self.model_ref)
+        trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
 
         sig = inspect.signature(trainer_cls)
         if "tokenizer" in sig.parameters:
@@ -189,7 +190,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         else:
             trainer_kwargs["processing_class"] = self.tokenizer
 
-        if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
+        if self.cfg.datasets is not None and (
+            trainer_cls is DPOStrategy.get_trainer_class()
+        ):
             trainer_kwargs["dataset_tags"] = [
                 d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
             ]
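
The net effect of the last three hunks is a refactor: trainer-class selection and the positional args that go with it now live together in `_get_trainer_cls`, which returns a `(trainer_cls, trainer_cls_args)` pair and updates `trainer_kwargs` in place for GRPO, instead of the caller re-checking `self.cfg.rl`. A stripped-down sketch of that pattern with hypothetical stand-in classes (not the real axolotl types; the GRPO branch is omitted):

```python
from typing import Any


class DPOTrainer:  # hypothetical stand-in, not the axolotl class
    pass


class ORPOTrainer:  # hypothetical stand-in
    pass


def get_trainer_cls(cfg: dict, model: Any, model_ref: Any, trainer_kwargs: dict):
    """Return (trainer_cls, trainer_cls_args): the class and its positional
    args are chosen together, as in the refactored _get_trainer_cls.
    trainer_kwargs would be updated here in the (omitted) GRPO branch."""
    trainer_cls_args: list[Any] = [model]
    rl = cfg["rl"]
    if rl in ("dpo", "ipo"):
        trainer_cls = DPOTrainer
        trainer_cls_args.append(model_ref)  # DPO-style trainers take a ref model
    elif rl == "orpo":
        trainer_cls = ORPOTrainer
    else:
        raise ValueError(f"Unsupported RL: {rl}")
    return trainer_cls, trainer_cls_args


cls, args = get_trainer_cls({"rl": "dpo"}, "model", "model_ref", trainer_kwargs={})
print(cls.__name__, args)  # DPOTrainer ['model', 'model_ref']
```

Keeping the argument assembly next to the class choice means a new RL type only touches one method, which is presumably why the duplicated `self.cfg.rl` checks in the build method were dropped.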