Compare commits
1 Commits
update-vll
...
fix/rl-tra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d47093fcdd |
@@ -20,7 +20,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
|||||||
> See the [example config](#example-config) file in addition to reading these instructions.
|
> See the [example config](#example-config) file in addition to reading these instructions.
|
||||||
|
|
||||||
1. Set `adapter: qlora` in your axolotl config file.
|
1. Set `adapter: qlora` in your axolotl config file.
|
||||||
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
|
2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
|
||||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
||||||
|
|
||||||
## Example Config
|
## Example Config
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ liger-kernel==0.5.10
|
|||||||
|
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
huggingface_hub[hf_xet]==0.33.0
|
huggingface_hub==0.32.2
|
||||||
peft==0.15.2
|
peft==0.15.2
|
||||||
transformers==4.53.1
|
transformers==4.53.1
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
|
|||||||
6
setup.py
6
setup.py
@@ -65,13 +65,15 @@ def parse_requirements(extras_require_map):
|
|||||||
raise ValueError("Invalid version format")
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
if (major, minor) >= (2, 7):
|
if (major, minor) >= (2, 7):
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
|
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
|
||||||
extras_require_map["vllm"] = ["vllm==0.9.2"]
|
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||||
elif (major, minor) >= (2, 6):
|
elif (major, minor) >= (2, 6):
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append(
|
_install_requires.append(
|
||||||
"xformers==0.0.29.post2"
|
"xformers==0.0.29.post2"
|
||||||
) # vllm needs post2 w torch 2.6
|
) # vllm needs post2 w torch 2.6
|
||||||
extras_require_map["vllm"] = ["vllm==0.9.2"]
|
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||||
elif (major, minor) >= (2, 5):
|
elif (major, minor) >= (2, 5):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from pathlib import Path
|
|||||||
from axolotl.core.builders.base import TrainerBuilderBase
|
from axolotl.core.builders.base import TrainerBuilderBase
|
||||||
from axolotl.core.trainers import (
|
from axolotl.core.trainers import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
|
AxolotlDPOTrainer,
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
AxolotlORPOTrainer,
|
AxolotlORPOTrainer,
|
||||||
)
|
)
|
||||||
@@ -36,33 +37,23 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def _get_trainer_cls(self, trainer_kwargs: dict):
|
def _get_trainer_cls(self):
|
||||||
"""
|
"""Returns trainer_cls"""
|
||||||
Returns trainer_cls and trainer_cls_args
|
|
||||||
"""
|
|
||||||
if self.cfg.plugins:
|
if self.cfg.plugins:
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||||
trainer_cls_args = [] # type: ignore
|
|
||||||
|
|
||||||
if trainer_cls is not None:
|
if trainer_cls is not None:
|
||||||
return trainer_cls, trainer_cls_args
|
return trainer_cls
|
||||||
|
|
||||||
trainer_cls = None
|
trainer_cls = None
|
||||||
trainer_cls_args = [self.model]
|
|
||||||
|
|
||||||
if self.cfg.rl is RLType.GRPO:
|
if self.cfg.rl is RLType.GRPO:
|
||||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||||
)
|
)
|
||||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
|
||||||
|
|
||||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
|
||||||
|
|
||||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||||
trainer_cls = DPOStrategy.get_trainer_class()
|
trainer_cls = AxolotlDPOTrainer
|
||||||
trainer_cls_args.append(self.model_ref)
|
|
||||||
|
|
||||||
elif self.cfg.rl is RLType.ORPO:
|
elif self.cfg.rl is RLType.ORPO:
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
elif self.cfg.rl is RLType.KTO:
|
elif self.cfg.rl is RLType.KTO:
|
||||||
@@ -72,7 +63,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
|
|
||||||
return trainer_cls, trainer_cls_args
|
return trainer_cls
|
||||||
|
|
||||||
def _build_training_arguments(self, total_num_steps):
|
def _build_training_arguments(self, total_num_steps):
|
||||||
"""
|
"""
|
||||||
@@ -182,7 +173,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.precompute_ref_log_probs
|
self.cfg.precompute_ref_log_probs
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
|
trainer_cls = self._get_trainer_cls()
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
|
|
||||||
|
if self.cfg.rl is RLType.GRPO:
|
||||||
|
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||||
|
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||||
|
|
||||||
|
if self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||||
|
trainer_cls_args.append(self.model_ref)
|
||||||
|
|
||||||
sig = inspect.signature(trainer_cls)
|
sig = inspect.signature(trainer_cls)
|
||||||
if "tokenizer" in sig.parameters:
|
if "tokenizer" in sig.parameters:
|
||||||
@@ -190,9 +189,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
trainer_kwargs["processing_class"] = self.tokenizer
|
trainer_kwargs["processing_class"] = self.tokenizer
|
||||||
|
|
||||||
if self.cfg.datasets is not None and (
|
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
||||||
trainer_cls is DPOStrategy.get_trainer_class()
|
|
||||||
):
|
|
||||||
trainer_kwargs["dataset_tags"] = [
|
trainer_kwargs["dataset_tags"] = [
|
||||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user