Compare commits

..

5 Commits

Author SHA1 Message Date
Wing Lian
f8bb4185bc skip s2 attention test due to timeout 2024-04-08 18:33:33 -04:00
Wing Lian
2fa65b9599 ignore issues with calculating # params when printing (#1493) 2024-04-08 11:04:22 -04:00
xzuyn
9430b6e868 Remove validate_quantized_dora (#1485)
DoRA with quantized layers is supported with PEFT 0.10.0
2024-04-08 01:25:23 -04:00
Wing Lian
934fc851da drop empty token from beginning if tokenizer has no bos_token (in the case of qwen) (#1490) 2024-04-06 19:55:19 -07:00
NanoCode012
bda48f0150 fix: reduce sample_packing warning (#1484) 2024-04-06 21:04:07 +09:00
4 changed files with 22 additions and 15 deletions

View File

@@ -23,6 +23,7 @@ from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
EarlyStoppingCallback,
PreTrainedModel,
Trainer,
TrainerCallback,
TrainingArguments,
@@ -802,6 +803,15 @@ class AxolotlDPOTrainer(DPOTrainer):
return super().push_to_hub(*args, **kwargs)
def tokenize_row(
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
) -> Dict:
res = super().tokenize_row(feature, model=model)
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
return res
class TrainerBuilderBase(abc.ABC):
"""

View File

@@ -242,17 +242,6 @@ class LoraConfig(BaseModel):
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self
@model_validator(mode="before")
@classmethod
def validate_quantized_dora(cls, data):
if data.get("peft_use_dora") and (
data.get("load_in_8bit") or data.get("load_in_4bit")
):
raise ValueError(
"`peft_use_dora` is not currently compatible with quantized weights."
)
return data
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""
@@ -664,8 +653,8 @@ class AxolotlInputConfig(
and not data.get("flash_attention")
and not data.get("sdp_attention")
):
raise ValueError(
"sample_packing requires flash_attention or sdp_attention to be set to true"
LOG.warning(
"sample_packing without flash_attention or sdp_attention does not handle cross-attention."
)
return data

View File

@@ -459,7 +459,7 @@ def load_model(
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16,
}
if not cfg.deepspeed and cfg.model_config_type in ("jamba", "qwen2_moe"):
if not cfg.deepspeed:
# for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32
@@ -902,7 +902,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
model = get_peft_model(model, lora_config)
if rank == 0:
model.print_trainable_parameters()
try:
model.print_trainable_parameters()
except AttributeError as exc:
LOG.warning(
"Exception caught during model.print_trainable_parameters(): %s", exc
)
elif cfg.fsdp and cfg.adapter == "qlora":
setup_quantized_peft_meta_for_training(model)

View File

@@ -7,6 +7,8 @@ import os
import unittest
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
@@ -19,6 +21,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip("Skipping test due to timeout.")
class TestLlamaShiftedSparseAttention(unittest.TestCase):
"""
Test case for Llama models using S2 Attn