Compare commits
4 Commits
fix/rl-tra
...
update-vll
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fe81d52882 | ||
|
|
1eaa4ed89d | ||
|
|
fe47392ed6 | ||
|
|
1032e22650 |
@@ -20,7 +20,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
||||
> See the [example config](#example-config) file in addition to reading these instructions.
|
||||
|
||||
1. Set `adapter: qlora` in your axolotl config file.
|
||||
2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
|
||||
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
|
||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
||||
|
||||
## Example Config
|
||||
|
||||
@@ -11,7 +11,7 @@ liger-kernel==0.5.10
|
||||
|
||||
packaging==23.2
|
||||
|
||||
huggingface_hub==0.32.2
|
||||
huggingface_hub[hf_xet]==0.33.0
|
||||
peft==0.15.2
|
||||
transformers==4.53.1
|
||||
tokenizers>=0.21.1
|
||||
|
||||
6
setup.py
6
setup.py
@@ -65,15 +65,13 @@ def parse_requirements(extras_require_map):
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
if (major, minor) >= (2, 7):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||
extras_require_map["vllm"] = ["vllm==0.9.2"]
|
||||
elif (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append(
|
||||
"xformers==0.0.29.post2"
|
||||
) # vllm needs post2 w torch 2.6
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||
extras_require_map["vllm"] = ["vllm==0.9.2"]
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
from axolotl.core.builders.base import TrainerBuilderBase
|
||||
from axolotl.core.trainers import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlDPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlORPOTrainer,
|
||||
)
|
||||
@@ -37,23 +36,33 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self):
|
||||
"""Returns trainer_cls"""
|
||||
def _get_trainer_cls(self, trainer_kwargs: dict):
|
||||
"""
|
||||
Returns trainer_cls and trainer_cls_args
|
||||
"""
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
trainer_cls_args = [] # type: ignore
|
||||
|
||||
if trainer_cls is not None:
|
||||
return trainer_cls
|
||||
return trainer_cls, trainer_cls_args
|
||||
|
||||
trainer_cls = None
|
||||
trainer_cls_args = [self.model]
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
trainer_cls = AxolotlDPOTrainer
|
||||
trainer_cls = DPOStrategy.get_trainer_class()
|
||||
trainer_cls_args.append(self.model_ref)
|
||||
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
@@ -63,7 +72,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
return trainer_cls
|
||||
return trainer_cls, trainer_cls_args
|
||||
|
||||
def _build_training_arguments(self, total_num_steps):
|
||||
"""
|
||||
@@ -173,15 +182,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.precompute_ref_log_probs
|
||||
)
|
||||
|
||||
trainer_cls = self._get_trainer_cls()
|
||||
trainer_cls_args = [self.model]
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
|
||||
if self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
trainer_cls_args.append(self.model_ref)
|
||||
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
|
||||
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "tokenizer" in sig.parameters:
|
||||
@@ -189,7 +190,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
trainer_kwargs["processing_class"] = self.tokenizer
|
||||
|
||||
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
||||
if self.cfg.datasets is not None and (
|
||||
trainer_cls is DPOStrategy.get_trainer_class()
|
||||
):
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user