Compare commits
5 Commits
fsdp-fix
...
pytest-ski
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8bb4185bc | ||
|
|
2fa65b9599 | ||
|
|
9430b6e868 | ||
|
|
934fc851da | ||
|
|
bda48f0150 |
@@ -23,6 +23,7 @@ from torch.optim.lr_scheduler import OneCycleLR
|
|||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import (
|
from transformers import (
|
||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
|
PreTrainedModel,
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
@@ -802,6 +803,15 @@ class AxolotlDPOTrainer(DPOTrainer):
|
|||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
def tokenize_row(
|
||||||
|
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
|
||||||
|
) -> Dict:
|
||||||
|
res = super().tokenize_row(feature, model=model)
|
||||||
|
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
||||||
|
for key in res.keys():
|
||||||
|
res[key] = res[key][1:]
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -242,17 +242,6 @@ class LoraConfig(BaseModel):
|
|||||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def validate_quantized_dora(cls, data):
|
|
||||||
if data.get("peft_use_dora") and (
|
|
||||||
data.get("load_in_8bit") or data.get("load_in_4bit")
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"`peft_use_dora` is not currently compatible with quantized weights."
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class ReLoRAConfig(BaseModel):
|
class ReLoRAConfig(BaseModel):
|
||||||
"""ReLoRA configuration subset"""
|
"""ReLoRA configuration subset"""
|
||||||
@@ -664,8 +653,8 @@ class AxolotlInputConfig(
|
|||||||
and not data.get("flash_attention")
|
and not data.get("flash_attention")
|
||||||
and not data.get("sdp_attention")
|
and not data.get("sdp_attention")
|
||||||
):
|
):
|
||||||
raise ValueError(
|
LOG.warning(
|
||||||
"sample_packing requires flash_attention or sdp_attention to be set to true"
|
"sample_packing without flash_attention or sdp_attention does not handle cross-attention."
|
||||||
)
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|||||||
@@ -902,7 +902,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|||||||
model = get_peft_model(model, lora_config)
|
model = get_peft_model(model, lora_config)
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
model.print_trainable_parameters()
|
try:
|
||||||
|
model.print_trainable_parameters()
|
||||||
|
except AttributeError as exc:
|
||||||
|
LOG.warning(
|
||||||
|
"Exception caught during model.print_trainable_parameters(): %s", exc
|
||||||
|
)
|
||||||
elif cfg.fsdp and cfg.adapter == "qlora":
|
elif cfg.fsdp and cfg.adapter == "qlora":
|
||||||
setup_quantized_peft_meta_for_training(model)
|
setup_quantized_peft_meta_for_training(model)
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
from axolotl.cli import load_datasets
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
@@ -19,6 +21,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
|
|||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("Skipping test due to timeout.")
|
||||||
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Test case for Llama models using S2 Attn
|
Test case for Llama models using S2 Attn
|
||||||
|
|||||||
Reference in New Issue
Block a user