Compare commits
1 Commits
pytest-ski
...
fsdp-fix
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
744f7082f5 |
@@ -23,7 +23,6 @@ from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import (
|
||||
EarlyStoppingCallback,
|
||||
PreTrainedModel,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
@@ -803,15 +802,6 @@ class AxolotlDPOTrainer(DPOTrainer):
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
def tokenize_row(
|
||||
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
|
||||
) -> Dict:
|
||||
res = super().tokenize_row(feature, model=model)
|
||||
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
||||
for key in res.keys():
|
||||
res[key] = res[key][1:]
|
||||
return res
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
|
||||
@@ -242,6 +242,17 @@ class LoraConfig(BaseModel):
|
||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||
return self
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_quantized_dora(cls, data):
|
||||
if data.get("peft_use_dora") and (
|
||||
data.get("load_in_8bit") or data.get("load_in_4bit")
|
||||
):
|
||||
raise ValueError(
|
||||
"`peft_use_dora` is not currently compatible with quantized weights."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class ReLoRAConfig(BaseModel):
|
||||
"""ReLoRA configuration subset"""
|
||||
@@ -653,8 +664,8 @@ class AxolotlInputConfig(
|
||||
and not data.get("flash_attention")
|
||||
and not data.get("sdp_attention")
|
||||
):
|
||||
LOG.warning(
|
||||
"sample_packing without flash_attention or sdp_attention does not handle cross-attention."
|
||||
raise ValueError(
|
||||
"sample_packing requires flash_attention or sdp_attention to be set to true"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@@ -459,7 +459,7 @@ def load_model(
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_quant_storage": torch.bfloat16,
|
||||
}
|
||||
if not cfg.deepspeed:
|
||||
if not cfg.deepspeed and cfg.model_config_type in ("jamba", "qwen2_moe"):
|
||||
# for some reason, this causes the loss to be off by an order of magnitude
|
||||
# but deepspeed needs this still in bfloat16
|
||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||
@@ -902,12 +902,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
if rank == 0:
|
||||
try:
|
||||
model.print_trainable_parameters()
|
||||
except AttributeError as exc:
|
||||
LOG.warning(
|
||||
"Exception caught during model.print_trainable_parameters(): %s", exc
|
||||
)
|
||||
model.print_trainable_parameters()
|
||||
elif cfg.fsdp and cfg.adapter == "qlora":
|
||||
setup_quantized_peft_meta_for_training(model)
|
||||
|
||||
|
||||
@@ -7,8 +7,6 @@ import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
@@ -21,7 +19,6 @@ LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
@pytest.mark.skip("Skipping test due to timeout.")
|
||||
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using S2 Attn
|
||||
|
||||
Reference in New Issue
Block a user