Compare commits
6 Commits
docker-bas
...
e2e-fsdp-t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39ab9626f1 | ||
|
|
26bd81cec0 | ||
|
|
1302e31049 | ||
|
|
be5f554a62 | ||
|
|
22319182ab | ||
|
|
440aab8a6f |
@@ -2,6 +2,6 @@
|
|||||||
set -e
|
set -e
|
||||||
|
|
||||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/
|
pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
||||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
||||||
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
|
|||||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||||
|
|
||||||
RUN apt-get update \
|
RUN apt-get update \
|
||||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
|
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||||
&& wget \
|
&& wget \
|
||||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||||
&& mkdir /root/.conda \
|
&& mkdir /root/.conda \
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.14.0
|
peft==0.14.0
|
||||||
transformers==4.47.0
|
transformers>=4.46.3
|
||||||
tokenizers>=0.20.1
|
tokenizers>=0.20.1
|
||||||
bitsandbytes==0.45.0
|
bitsandbytes==0.45.0
|
||||||
accelerate==1.2.0
|
accelerate==1.2.0
|
||||||
@@ -31,7 +31,7 @@ art
|
|||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
tensorboard
|
tensorboard
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
autoawq==0.2.7.post2
|
autoawq==0.2.7.post3
|
||||||
triton>=2.3.0
|
triton>=2.3.0
|
||||||
liger-kernel==0.4.2
|
liger-kernel==0.4.2
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
|
import axolotl
|
||||||
from axolotl.cli.utils import (
|
from axolotl.cli.utils import (
|
||||||
add_options_from_config,
|
add_options_from_config,
|
||||||
add_options_from_dataclass,
|
add_options_from_dataclass,
|
||||||
@@ -16,6 +17,7 @@ from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
|||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
|
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||||
def cli():
|
def cli():
|
||||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from packaging import version
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
@@ -973,7 +974,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
for key, metrics in self._stored_metrics[train_eval].items():
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
logs[key] = torch.tensor(metrics).mean().item()
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
return super().log(logs, start_time)
|
|
||||||
|
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||||
|
try:
|
||||||
|
return super().log(logs, start_time)
|
||||||
|
except TypeError:
|
||||||
|
return super().log(logs) # transformers<=4.46
|
||||||
|
return super().log(logs) # transformers<=4.46
|
||||||
|
|
||||||
def store_metrics(
|
def store_metrics(
|
||||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||||
@@ -1165,9 +1172,13 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
for key, metrics in self._stored_metrics[train_eval].items():
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
logs[key] = torch.tensor(metrics).mean().item()
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
|
|
||||||
logs, start_time
|
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||||
)
|
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
|
logs, start_time
|
||||||
|
)
|
||||||
|
# transformers<=4.46
|
||||||
|
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
@@ -1185,9 +1196,13 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|||||||
for key, metrics in self._stored_metrics[train_eval].items():
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
logs[key] = torch.tensor(metrics).mean().item()
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
|
|
||||||
logs, start_time
|
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||||
)
|
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
|
logs, start_time
|
||||||
|
)
|
||||||
|
# transformers<=4.46
|
||||||
|
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||||
@@ -1232,9 +1247,13 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
|||||||
for key, metrics in self._stored_metrics[train_eval].items():
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
|
|
||||||
logs, start_time
|
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||||
)
|
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
|
logs, start_time
|
||||||
|
)
|
||||||
|
# transformers<=4.46
|
||||||
|
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||||
@@ -1252,9 +1271,13 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
|||||||
for key, metrics in self._stored_metrics[train_eval].items():
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
logs[key] = torch.tensor(metrics).mean().item()
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
|
|
||||||
logs, start_time
|
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||||
)
|
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
|
logs, start_time
|
||||||
|
)
|
||||||
|
# transformers<=4.46
|
||||||
|
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||||
@@ -1266,9 +1289,12 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
|||||||
|
|
||||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
# TODO remove once trl supports the updated to the Trainer.log method
|
# TODO remove once trl supports the updated to the Trainer.log method
|
||||||
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
|
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||||
logs, start_time
|
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
)
|
logs, start_time
|
||||||
|
)
|
||||||
|
# transformers<=4.46
|
||||||
|
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
|
|||||||
80
src/axolotl/monkeypatch/trainer_fsdp_optim.py
Normal file
80
src/axolotl/monkeypatch/trainer_fsdp_optim.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
"""
|
||||||
|
fix for FSDP optimizer save in trainer w 4.47.0
|
||||||
|
"""
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from transformers import Trainer
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
|
||||||
|
|
||||||
|
ORIGINAL_TRAINER_CODE = """
|
||||||
|
|
||||||
|
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_TRAINER_CODE = """
|
||||||
|
|
||||||
|
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_training_loop_code() -> str:
|
||||||
|
training_loop = inspect.getsource(
|
||||||
|
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
return training_loop
|
||||||
|
|
||||||
|
|
||||||
|
def check_training_loop_is_patchable() -> bool:
|
||||||
|
training_loop = get_training_loop_code()
|
||||||
|
training_loop, _ = detab_code(training_loop)
|
||||||
|
return ORIGINAL_TRAINER_CODE in training_loop
|
||||||
|
|
||||||
|
|
||||||
|
def patch_training_loop_for_fsdp():
|
||||||
|
"""
|
||||||
|
monkeypatch for fixing the training loop for fsdp with optimizer save
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
training_loop = get_training_loop_code()
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
|
||||||
|
training_loop
|
||||||
|
)
|
||||||
|
training_loop, _ = detab_code(training_loop)
|
||||||
|
if ORIGINAL_TRAINER_CODE not in training_loop:
|
||||||
|
return
|
||||||
|
|
||||||
|
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
|
||||||
|
training_loop = training_loop.replace(
|
||||||
|
"def _inner_training_loop(",
|
||||||
|
"def _fixed_inner_training_loop(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load imports necessary
|
||||||
|
import transformers.trainer
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(transformers.trainer):
|
||||||
|
if item in training_loop:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from transformers.trainer import ("
|
||||||
|
+ ", ".join(x for x in items_to_import)
|
||||||
|
+ ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
LOG.info("patching _inner_training_loop for fsdp optimizer save")
|
||||||
|
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||||
|
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
)
|
||||||
@@ -5,8 +5,7 @@ see https://github.com/huggingface/transformers/pull/35128
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from transformers import LlamaForCausalLM
|
from transformers import LlamaForCausalLM, Trainer
|
||||||
from transformers.trainer import Trainer
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||||
|
|
||||||
|
|||||||
@@ -380,6 +380,13 @@ class ModelLoader:
|
|||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
plugin_manager.pre_model_load(self.cfg)
|
plugin_manager.pre_model_load(self.cfg)
|
||||||
|
|
||||||
|
if self.cfg.fsdp:
|
||||||
|
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
||||||
|
patch_training_loop_for_fsdp,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_training_loop_for_fsdp()
|
||||||
|
|
||||||
if self.cfg.gradient_checkpointing == "unsloth":
|
if self.cfg.gradient_checkpointing == "unsloth":
|
||||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
||||||
|
|
||||||
@@ -406,10 +413,14 @@ class ModelLoader:
|
|||||||
and self.cfg.flash_attention
|
and self.cfg.flash_attention
|
||||||
and self.cfg.sample_packing
|
and self.cfg.sample_packing
|
||||||
):
|
):
|
||||||
has_remote_code = (
|
if "auto_map" in self.model_config:
|
||||||
"auto_map" in self.model_config
|
try:
|
||||||
and "AutoModelForCausalLM" in self.model_config["auto_map"]
|
auto_map_config = self.model_config["auto_map"]
|
||||||
)
|
except TypeError:
|
||||||
|
auto_map_config = self.model_config.auto_map
|
||||||
|
has_remote_code = "AutoModelForCausalLM" in auto_map_config
|
||||||
|
else:
|
||||||
|
has_remote_code = False
|
||||||
if has_remote_code and self.cfg.trust_remote_code is False:
|
if has_remote_code and self.cfg.trust_remote_code is False:
|
||||||
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
|
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
|
||||||
has_remote_code = self.cfg.trust_remote_code
|
has_remote_code = self.cfg.trust_remote_code
|
||||||
|
|||||||
@@ -119,18 +119,28 @@ def temp_dir():
|
|||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=True)
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
def cleanup_monkeypatches():
|
def cleanup_monkeypatches():
|
||||||
|
from transformers import Trainer
|
||||||
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
||||||
|
|
||||||
original_fa2_forward = LlamaFlashAttention2.forward
|
original_fa2_forward = LlamaFlashAttention2.forward
|
||||||
|
original_trainer_inner_training_loop = (
|
||||||
|
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
original_trainer_training_step = Trainer.training_step
|
||||||
# monkey patches can happen inside the tests
|
# monkey patches can happen inside the tests
|
||||||
yield
|
yield
|
||||||
# Reset LlamaFlashAttention2 forward
|
# Reset LlamaFlashAttention2 forward
|
||||||
LlamaFlashAttention2.forward = original_fa2_forward
|
LlamaFlashAttention2.forward = original_fa2_forward
|
||||||
|
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||||
|
original_trainer_inner_training_loop
|
||||||
|
)
|
||||||
|
Trainer.training_step = original_trainer_training_step
|
||||||
|
|
||||||
# Reset other known monkeypatches
|
# Reset other known monkeypatches
|
||||||
modules_to_reset: list[tuple[str, list[str]]] = [
|
modules_to_reset: list[tuple[str, list[str]]] = [
|
||||||
|
("transformers",),
|
||||||
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
|
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
|
||||||
("transformers.trainer",),
|
("transformers.trainer", ["Trainer"]),
|
||||||
("transformers.loss.loss_utils",),
|
("transformers.loss.loss_utils",),
|
||||||
]
|
]
|
||||||
for module_name_tuple in modules_to_reset:
|
for module_name_tuple in modules_to_reset:
|
||||||
|
|||||||
Reference in New Issue
Block a user