Compare commits

4 commits: q-galore...dpo-spawn-

| Author | SHA1 | Date |
|---|---|---|
| | e86dd76154 | |
| | 5f58555bd0 | |
| | cfc533a7f7 | |
| | e1725aef2b | |
.github/workflows/tests.yml (vendored)
```diff
@@ -57,6 +57,10 @@ jobs:
       run: |
         pytest --ignore=tests/e2e/ tests/

+    - name: cleanup pip cache
+      run: |
+        find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
+
   docker-e2e-tests:
     if: github.repository_owner == 'axolotl-ai-cloud'
     # this job needs to be run on self-hosted GPU runners...
```
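The new cleanup step deletes cached pip downloads that were last modified more than 14 days ago (that is what `-mtime +14` matches). A rough Python equivalent of that `find` invocation, assuming pip's default Linux cache location (run `pip cache dir` to find the real one):

```python
import time
from pathlib import Path

# Python equivalent of:
#   find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
# The cache path below is the common Linux default, used here for illustration.
cache_dir = Path.home() / ".cache" / "pip" / "http-v2"
cutoff = time.time() - 14 * 24 * 3600  # 14 days ago, in seconds

for entry in cache_dir.rglob("*"):
    if entry.is_file() and entry.stat().st_mtime < cutoff:
        entry.unlink()  # remove stale cached download
```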
```diff
@@ -99,7 +103,7 @@ jobs:
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal jinja2
+          pip install modal==0.63.64 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
```
```diff
@@ -2,5 +2,5 @@
 set -e

 pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
-pytest /workspace/axolotl/tests/e2e/patched/
+pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/
 pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
```
```diff
@@ -13,6 +13,7 @@ from abc import abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import wraps
+from multiprocessing import set_start_method
 from pathlib import Path
 from typing import Dict, List, Literal, Optional, Type, Union

```
```diff
@@ -290,6 +291,18 @@ class AxolotlTrainer(Trainer):
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

+    def _wrap_model(self, model, training=True, dataloader=None):
+        if self.args.torch_compile:
+            torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
+                256
+            )
+            model = torch.compile(
+                model,
+                backend=self.args.torch_compile_backend,
+                mode=self.args.torch_compile_mode,
+            )
+        return super()._wrap_model(model, training=training, dataloader=dataloader)
+
     def create_optimizer(self):
         if (
             self.args.loraplus_lr_ratio is None
```
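The `_wrap_model` override routes the model through `torch.compile` before the usual HF Trainer wrapping; raising dynamo's `accumulated_cache_size_limit` allows more recompiled graph variants (for example, from varying sequence lengths) before dynamo falls back to eager mode. A minimal standalone sketch of the same pattern, using a toy module rather than axolotl's model:

```python
import torch

# Illustrative module; axolotl wraps the full HF model instead.
model = torch.nn.Linear(16, 16)

# Allow more cached graph recompilations before falling back to eager
# (useful when variable input shapes trigger frequent recompiles).
torch._dynamo.config.accumulated_cache_size_limit = 256  # pylint: disable=protected-access

compiled = torch.compile(model, backend="inductor", mode="default")
out = compiled(torch.randn(4, 16))  # first call triggers compilation
```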
```diff
@@ -1758,6 +1771,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
             dpo_trainer.add_callback(callback)

+        # prevents multiprocessing issues for datasets on multiple GPUs
+        set_start_method("spawn")
+
         return dpo_trainer


```
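`set_start_method("spawn")` makes child processes start from a fresh interpreter instead of forking a parent that may already hold CUDA state, sidestepping the "Cannot re-initialize CUDA in forked subprocess" failure mode. A self-contained sketch with a hypothetical worker function:

```python
import multiprocessing as mp

def worker(idx: int) -> None:
    # Each spawned child gets a fresh interpreter, so CUDA (or any other
    # fork-unsafe library) can be initialized safely here.
    print(f"worker {idx} started")

if __name__ == "__main__":
    # force=True avoids a RuntimeError if a start method was already set.
    mp.set_start_method("spawn", force=True)
    procs = [mp.Process(target=worker, args=(i,)) for i in range(2)]
    for proc in procs:
        proc.start()
    for proc in procs:
        proc.join()
```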
```diff
@@ -78,6 +78,33 @@ def replace_llama_qkv_with_fused(model):
         set_module_name(model, name, qkv)


+def patch_llama_cross_entropy():
+    from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+    LOG.info("patching with flash_attn.losses.cross_entropy")
+    transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+        CrossEntropyLoss, inplace_backward=True
+    )
+
+
+def patch_llama_rms_norm():
+    try:
+        from flash_attn.ops.rms_norm import RMSNorm
+
+        class LlamaRMSNorm(RMSNorm):
+            """Patched LLamaRMSNorm"""
+
+            def __init__(self, hidden_size, eps=1e-6):
+                super().__init__(hidden_size, eps=eps)
+
+        LOG.info("patching with flash_attn.ops.rms_norm")
+        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+    except ImportError:
+        LOG.warning(
+            "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
+        )
+
+
 def replace_llama_attn_with_flash_attn(
     packed: Optional[bool] = False,
     cross_entropy: Optional[bool] = False,
```
```diff
@@ -104,30 +131,11 @@ def replace_llama_attn_with_flash_attn(

     # skip only if explicitly disabled
     if cross_entropy:
-        from flash_attn.losses.cross_entropy import CrossEntropyLoss
-
-        LOG.info("patching with flash_attn.losses.cross_entropy")
-        transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-            CrossEntropyLoss, inplace_backward=True
-        )
+        patch_llama_cross_entropy()

     # skip only if explicitly disabled
     if rms_norm:
-        try:
-            from flash_attn.ops.rms_norm import RMSNorm
-
-            class LlamaRMSNorm(RMSNorm):
-                """Patched LLamaRMSNorm"""
-
-                def __init__(self, hidden_size, eps=1e-6):
-                    super().__init__(hidden_size, eps=eps)
-
-            LOG.info("patching with flash_attn.ops.rms_norm")
-            transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
-        except ImportError:
-            LOG.warning(
-                "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
-            )
+        patch_llama_rms_norm()


 class FusedAttention(LlamaAttention):
```
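Both extracted helpers use the same monkeypatching idiom: rebind a name on the `transformers.models.llama.modeling_llama` module so any code that resolves the symbol through the module at call time picks up the replacement. A generic, runnable sketch of the idiom with a toy module (the names here are illustrative, not axolotl's):

```python
import types
from functools import partial

# Toy stand-in for a third-party module (e.g. modeling_llama).
target = types.ModuleType("target")
target.Loss = lambda logits: sum(logits)  # original implementation

def fast_loss(logits, inplace_backward=False):
    # Optimized drop-in replacement for the original loss.
    return sum(logits)

# Rebind the module-level name; callers that look up target.Loss at
# call time (rather than holding an import-time reference) now get
# the patched version with the extra keyword pre-applied.
target.Loss = partial(fast_loss, inplace_backward=True)

print(target.Loss([1.0, 2.0]))  # 3.0, via the patched function
```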
```diff
@@ -10,6 +10,7 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
 from axolotl.monkeypatch.utils import get_unpad_data

 SUPPORTED_MULTIPACK_MODEL_TYPES = [
+    "llama",
     "mixtral",
     "qwen2",
     "qwen2_moe",
```
```diff
@@ -30,6 +31,10 @@ def patch_for_multipack(model_type, model_name=None):
         )
         if is_deepspeed_zero3_enabled():
             patch_mixtral_moe_forward_zero3()
+    elif model_type == "llama":
+        transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
     elif model_type == "qwen2":
         transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
```
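Patching `_get_unpad_data` is what makes sample packing work: the replacement derives per-sequence lengths from the packed attention mask so flash-attention's unpadding keeps each packed sequence separate. A rough sketch of the shape of such a helper, assuming the convention of tagging each packed sequence with a distinct integer id in the mask (a simplification for illustration, not the actual implementation):

```python
import torch

def get_unpad_data_sketch(attention_mask: torch.Tensor):
    # bincount over sequence ids gives per-sequence lengths; index 0 is padding.
    seqlens = torch.bincount(attention_mask.flatten())[1:]
    # Flat indices of all non-padding tokens.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # Cumulative sequence boundaries, prefixed with 0, as flash-attn expects.
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
    return indices, cu_seqlens, int(seqlens.max())

mask = torch.tensor([[1, 1, 1, 2, 2, 0]])  # two packed sequences plus padding
print(get_unpad_data_sketch(mask))  # indices, cu_seqlens=[0, 3, 5], max_len=3
```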
```diff
@@ -52,6 +52,13 @@ class TrainDatasetMeta:
 def train(
     *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
 ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
+    # enable expandable segments for cuda allocation to improve VRAM usage
+    # torch_version = torch.__version__.split(".")
+    # torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
+    # if torch_major == 2 and torch_minor >= 2:
+    #     if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
+    #         os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
     # load the tokenizer first
     LOG.debug(
         f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
```
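The commented-out block would opt into expandable segments, which let the CUDA caching allocator grow existing memory segments in place instead of carving out fixed-size blocks, reducing fragmentation with variable-length batches. The setting only takes effect if it is in place before CUDA is first initialized, which is presumably why it is gated here; a standalone sketch:

```python
import os

# Must be set before torch initializes CUDA, otherwise it has no effect.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch  # noqa: E402  (imported after the env var on purpose)

if torch.cuda.is_available():
    # This allocation goes through the expandable-segments allocator.
    x = torch.empty(1024, 1024, device="cuda")
```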
```diff
@@ -1112,6 +1112,31 @@ class AxolotlInputConfig(
             raise ValueError("either datasets or pretraining_dataset is required")
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_xentropy_patch_conflicts(cls, data):
+        if data.get("flash_attn_cross_entropy") and data.get(
+            "unsloth_cross_entropy_loss"
+        ):
+            raise ValueError(
+                "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
+            )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_qlora_unsloth(cls, data):
+        if (
+            data.get("unsloth_lora_mlp")
+            or data.get("unsloth_lora_qkv")
+            or data.get("unsloth_lora_o")
+        ):
+            if data.get("adapter") == "lora" or data.get("load_in_8bit"):
+                raise ValueError(
+                    "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
+                )
+        return data
+

 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""
```
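The new checks are standard pydantic v2 `model_validator(mode="before")` hooks: they receive the raw config dict before field parsing and raise to reject incompatible flag combinations. A self-contained sketch of the pattern with a toy config class (not axolotl's):

```python
from pydantic import BaseModel, ValidationError, model_validator

class ToyConfig(BaseModel):
    flash_attn_cross_entropy: bool = False
    unsloth_cross_entropy_loss: bool = False

    @model_validator(mode="before")
    @classmethod
    def check_xentropy_conflicts(cls, data):
        # data is the raw input dict, seen before field validation.
        if data.get("flash_attn_cross_entropy") and data.get(
            "unsloth_cross_entropy_loss"
        ):
            raise ValueError("only one cross-entropy patch can be enabled")
        return data

try:
    ToyConfig(flash_attn_cross_entropy=True, unsloth_cross_entropy_loss=True)
except ValidationError as err:
    print(err)  # pydantic wraps the ValueError raised by the validator
```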
```diff
@@ -1163,3 +1188,18 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         if data.get("deepspeed") and data.get("fsdp"):
             raise ValueError("deepspeed and fsdp cannot be used together.")
         return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_multigpu_unsloth(cls, data):
+        if (
+            data.get("unsloth_lora_mlp")
+            or data.get("unsloth_lora_qkv")
+            or data.get("unsloth_lora_o")
+        ):
+            capabilities = data.get("capabilities")
+            if capabilities and capabilities.get("num_gpus") > 1:
+                raise ValueError(
+                    "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
+                )
+        return data
```
```diff
@@ -347,6 +347,27 @@ def load_model(
         and cfg.sample_packing
     ):
         patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
+
+        if cfg.is_llama_derived_model:
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                patch_llama_cross_entropy,
+                patch_llama_rms_norm,
+            )
+
+            if cfg.flash_attn_cross_entropy:
+                patch_llama_cross_entropy()
+            if cfg.flash_attn_rms_norm:
+                patch_llama_rms_norm()
+            if cfg.unsloth_cross_entropy_loss:
+                from axolotl.monkeypatch.unsloth_ import (
+                    integrate_cross_entropy_loss_patch,
+                )
+
+                integrate_cross_entropy_loss_patch()
+            if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
+                from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
+
+                patch_self_attn_lora()
     elif cfg.is_llama_derived_model:
         # Modify all llama derived models in one block
```
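Each patch here is opt-in and imported lazily only when its flag is set, so optional dependencies (flash-attn, unsloth kernels) are never required unless actually used. A condensed sketch of the flag-to-patch dispatch, with a plain dict standing in for `cfg` and strings standing in for the real patch calls:

```python
# Simplified view of the dispatch: each config flag independently
# triggers one patch, and nothing is imported for disabled flags.
def apply_llama_patches(cfg: dict) -> list:
    applied = []
    if cfg.get("flash_attn_cross_entropy"):
        applied.append("flash_attn cross-entropy")   # patch_llama_cross_entropy()
    if cfg.get("flash_attn_rms_norm"):
        applied.append("flash_attn RMSNorm")          # patch_llama_rms_norm()
    if cfg.get("unsloth_cross_entropy_loss"):
        applied.append("unsloth cross-entropy")       # integrate_cross_entropy_loss_patch()
    if cfg.get("unsloth_lora_qkv") or cfg.get("unsloth_lora_o"):
        applied.append("unsloth self-attn LoRA")      # patch_self_attn_lora()
    return applied

print(apply_llama_patches({"flash_attn_rms_norm": True}))
```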