Fix: Gradient Accumulation issue (#1980)
* feat: support new arg num_items_in_batch
* use kwargs to manage extra unknown kwargs for now
* upgrade against upstream transformers main
* make sure trl is on latest too
* fix for upgraded trl
* fix: handle trl and transformer signature change
* feat: update trl to handle transformer signature
* RewardDataCollatorWithPadding no longer has max_length
* handle updated signature for tokenizer vs processor class
* invert logic for tokenizer vs processor class
* processing_class, not processor class
* also handle processing class in dpo
* handle model name w model card creation
* upgrade transformers and add a loss check test
* fix install of tbparse requirements
* make sure to add tbparse to req
* feat: revert kwarg to positional kwarg to be explicit

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
2
.github/workflows/pypi.yml
vendored
2
.github/workflows/pypi.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pip3 install wheel packaging
|
pip3 install wheel packaging
|
||||||
pip3 install -e .
|
pip3 install -e .
|
||||||
pip3 install -r requirements-tests.txt
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Extract tag name
|
- name: Extract tag name
|
||||||
id: tag
|
id: tag
|
||||||
|
|||||||
3
.github/workflows/tests-nightly.yml
vendored
3
.github/workflows/tests-nightly.yml
vendored
@@ -47,13 +47,14 @@ jobs:
|
|||||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
|
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
|
||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
||||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
||||||
|
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
pip3 install --upgrade packaging
|
pip3 install --upgrade packaging
|
||||||
pip3 install -U -e .
|
pip3 install -U -e .
|
||||||
pip3 install -r requirements-tests.txt
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -62,7 +62,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pip3 show torch
|
pip3 show torch
|
||||||
pip3 install -U -e .
|
pip3 install -U -e .
|
||||||
pip3 install -r requirements-tests.txt
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
||||||
|
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
@@ -36,7 +37,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
RUN pip install -r requirements-tests.txt
|
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works
|
# fix so that git fetch/pull from remote works
|
||||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||||
|
|||||||
@@ -2,3 +2,4 @@ pre-commit
|
|||||||
black
|
black
|
||||||
mypy
|
mypy
|
||||||
types-requests
|
types-requests
|
||||||
|
tbparse
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.13.2
|
peft==0.13.2
|
||||||
transformers==4.45.2
|
transformers==4.46.0
|
||||||
tokenizers>=0.20.1
|
tokenizers>=0.20.1
|
||||||
bitsandbytes==0.44.1
|
bitsandbytes==0.44.1
|
||||||
accelerate==1.0.1
|
accelerate==1.0.1
|
||||||
@@ -43,7 +43,7 @@ s3fs>=2024.5.0
|
|||||||
gcsfs>=2024.5.0
|
gcsfs>=2024.5.0
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl==0.9.6
|
trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import abc
|
|||||||
import gc
|
import gc
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -27,7 +28,6 @@ from torch.optim.lr_scheduler import OneCycleLR
|
|||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import (
|
from transformers import (
|
||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
PreTrainedModel,
|
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
@@ -666,7 +666,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
return DataLoader(bench_dataset, **dataloader_params)
|
return DataLoader(bench_dataset, **dataloader_params)
|
||||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||||
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False):
|
def compute_loss(
|
||||||
|
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||||
|
):
|
||||||
# use one's weighted cross entropy loss calc
|
# use one's weighted cross entropy loss calc
|
||||||
# if self.args.sample_packing:
|
# if self.args.sample_packing:
|
||||||
# labels = inputs.pop("labels")
|
# labels = inputs.pop("labels")
|
||||||
@@ -674,8 +676,18 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||||
# return (loss, outputs) if return_outputs else loss
|
# return (loss, outputs) if return_outputs else loss
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
|
return self.orpo_compute_loss(
|
||||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
model,
|
||||||
|
inputs,
|
||||||
|
return_outputs=return_outputs,
|
||||||
|
num_items_in_batch=num_items_in_batch,
|
||||||
|
)
|
||||||
|
return super().compute_loss(
|
||||||
|
model,
|
||||||
|
inputs,
|
||||||
|
return_outputs=return_outputs,
|
||||||
|
num_items_in_batch=num_items_in_batch,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
||||||
@@ -771,7 +783,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
).squeeze(2)
|
).squeeze(2)
|
||||||
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
|
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
|
||||||
|
|
||||||
def orpo_compute_loss(self, model, inputs, return_outputs=False):
|
def orpo_compute_loss(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
inputs,
|
||||||
|
return_outputs=False,
|
||||||
|
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
|
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
|
||||||
inputs,
|
inputs,
|
||||||
label_pad_token=-100,
|
label_pad_token=-100,
|
||||||
@@ -898,6 +916,7 @@ class AxolotlMambaTrainer(AxolotlTrainer):
|
|||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
return_outputs=False, # pylint: disable=unused-argument
|
return_outputs=False, # pylint: disable=unused-argument
|
||||||
|
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
input_ids = inputs.pop("input_ids")
|
input_ids = inputs.pop("input_ids")
|
||||||
lm_logits = model(input_ids).logits
|
lm_logits = model(input_ids).logits
|
||||||
@@ -1005,18 +1024,32 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
def tokenize_row(
|
def tokenize_row(
|
||||||
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
|
self,
|
||||||
|
features,
|
||||||
|
processing_class,
|
||||||
|
max_prompt_length,
|
||||||
|
max_completion_length,
|
||||||
|
add_special_tokens,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
res = super().tokenize_row(feature, model=model)
|
res = super().tokenize_row(
|
||||||
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
features,
|
||||||
|
processing_class,
|
||||||
|
max_prompt_length,
|
||||||
|
max_completion_length,
|
||||||
|
add_special_tokens,
|
||||||
|
)
|
||||||
|
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
||||||
for key in res.keys():
|
for key in res.keys():
|
||||||
res[key] = res[key][1:]
|
res[key] = res[key][1:]
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def training_step(
|
def training_step(
|
||||||
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||||
|
num_items_in_batch=None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
loss: torch.Tensor = super().training_step(model, inputs)
|
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return loss
|
return loss
|
||||||
@@ -1667,12 +1700,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
**data_collator_kwargs,
|
**data_collator_kwargs,
|
||||||
)
|
)
|
||||||
|
sig = inspect.signature(trainer_cls)
|
||||||
|
if "processing_class" in sig.parameters.keys():
|
||||||
|
trainer_kwargs["processing_class"] = self.tokenizer
|
||||||
|
else:
|
||||||
|
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||||
|
|
||||||
trainer = trainer_cls(
|
trainer = trainer_cls(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
eval_dataset=self.eval_dataset,
|
eval_dataset=self.eval_dataset,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
||||||
callbacks=self.get_callbacks(),
|
callbacks=self.get_callbacks(),
|
||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
@@ -1713,6 +1751,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
]
|
]
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
collator = RewardDataCollatorWithPadding
|
collator = RewardDataCollatorWithPadding
|
||||||
|
if "max_length" in kwargs:
|
||||||
|
kwargs.pop("max_length")
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
@@ -1915,7 +1955,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
dpo_trainer_kwargs["max_target_length"] = None
|
dpo_trainer_kwargs["max_target_length"] = None
|
||||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||||
dpo_trainer_kwargs["generate_during_eval"] = True
|
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
@@ -1927,11 +1967,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
|
|
||||||
|
sig = inspect.signature(trainer_cls)
|
||||||
|
if "processing_class" in sig.parameters.keys():
|
||||||
|
dpo_trainer_kwargs["processing_class"] = self.tokenizer
|
||||||
|
else:
|
||||||
|
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||||
|
|
||||||
dpo_trainer = trainer_cls(
|
dpo_trainer = trainer_cls(
|
||||||
*trainer_cls_args,
|
*trainer_cls_args,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
callbacks=self.get_callbacks(),
|
callbacks=self.get_callbacks(),
|
||||||
**dpo_trainer_kwargs,
|
**dpo_trainer_kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,26 +16,6 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
|
|
||||||
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
||||||
|
|
||||||
ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
# Flatten the tokens
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
|
||||||
shift_labels = shift_labels.view(-1)
|
|
||||||
# Enable model parallelism
|
|
||||||
shift_labels = shift_labels.to(shift_logits.device)
|
|
||||||
loss = loss_fct(shift_logits, shift_labels)
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
loss = fast_cross_entropy_loss(
|
|
||||||
logits = shift_logits,
|
|
||||||
labels = shift_labels,
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
ORIGINAL_QKV_CODE = """
|
ORIGINAL_QKV_CODE = """
|
||||||
query_states = self.q_proj(hidden_states)
|
query_states = self.q_proj(hidden_states)
|
||||||
key_states = self.k_proj(hidden_states)
|
key_states = self.k_proj(hidden_states)
|
||||||
@@ -80,12 +60,6 @@ def get_forward_code() -> str:
|
|||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
|
||||||
def check_cel_is_patchable() -> bool:
|
|
||||||
forward = get_forward_code()
|
|
||||||
forward, _ = detab_code(forward)
|
|
||||||
return ORIGINAL_CEL_CODE in forward
|
|
||||||
|
|
||||||
|
|
||||||
def get_self_attn_code() -> str:
|
def get_self_attn_code() -> str:
|
||||||
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
||||||
return forward
|
return forward
|
||||||
@@ -98,48 +72,31 @@ def check_self_attn_is_patchable() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
||||||
|
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
|
||||||
|
|
||||||
|
def UnslothForCausalLMLoss( # pylint: disable=invalid-name
|
||||||
|
logits,
|
||||||
|
labels,
|
||||||
|
vocab_size: int, # pylint: disable=unused-argument
|
||||||
|
num_items_in_batch: int = None,
|
||||||
|
ignore_index: int = -100, # pylint: disable=unused-argument
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||||
|
logits = logits.float()
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
|
||||||
|
loss = fast_cross_entropy_loss(
|
||||||
|
logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
|
||||||
|
)
|
||||||
|
return loss
|
||||||
|
|
||||||
if model_type == "llama":
|
if model_type == "llama":
|
||||||
forward = get_forward_code()
|
from transformers.loss import loss_utils
|
||||||
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
|
||||||
forward, _ = detab_code(forward)
|
|
||||||
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
|
|
||||||
|
|
||||||
forward = forward.replace(
|
loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
|
||||||
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
|
|
||||||
)
|
|
||||||
forward = forward.replace(
|
|
||||||
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
|
|
||||||
forward = forward.replace(
|
|
||||||
"def forward(",
|
|
||||||
"def fast_cross_entropy_loss_forward(",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# load imports necessary
|
|
||||||
import transformers.models.llama.modeling_llama
|
|
||||||
|
|
||||||
items_to_import = []
|
|
||||||
for item in dir(transformers.models.llama.modeling_llama):
|
|
||||||
if item in forward:
|
|
||||||
items_to_import.append(item)
|
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
"from transformers.models.llama.modeling_llama import ("
|
|
||||||
+ ", ".join(x for x in items_to_import)
|
|
||||||
+ ")",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
|
||||||
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
|
|
||||||
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported model type")
|
raise ValueError("Unsupported model type")
|
||||||
|
|
||||||
|
|||||||
@@ -260,8 +260,10 @@ def train(
|
|||||||
|
|
||||||
if not cfg.hub_model_id:
|
if not cfg.hub_model_id:
|
||||||
try:
|
try:
|
||||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
trainer.create_model_card(
|
||||||
except AttributeError:
|
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
|
||||||
|
)
|
||||||
|
except (AttributeError, UnicodeDecodeError):
|
||||||
pass
|
pass
|
||||||
elif cfg.hub_model_id:
|
elif cfg.hub_model_id:
|
||||||
# defensively push to the hub to ensure the model card is updated
|
# defensively push to the hub to ensure the model card is updated
|
||||||
|
|||||||
@@ -1,22 +1,12 @@
|
|||||||
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.monkeypatch.unsloth_ import (
|
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
|
||||||
check_cel_is_patchable,
|
|
||||||
check_self_attn_is_patchable,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestUnslothIntegration(unittest.TestCase):
|
class TestUnslothIntegration(unittest.TestCase):
|
||||||
"""Unsloth monkeypatch integration tests."""
|
"""Unsloth monkeypatch integration tests."""
|
||||||
|
|
||||||
def test_is_cel_patchable(self):
|
|
||||||
# ensures the current version of transformers has loss code that matches our patching code
|
|
||||||
self.assertTrue(
|
|
||||||
check_cel_is_patchable(),
|
|
||||||
"HF transformers loss code has changed and isn't patchable",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_is_self_attn_patchable(self):
|
def test_is_self_attn_patchable(self):
|
||||||
# ensures the current version of transformers has loss code that matches our patching code
|
# ensures the current version of transformers has loss code that matches our patching code
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
|||||||
74
tests/e2e/test_packing_loss.py
Normal file
74
tests/e2e/test_packing_loss.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for packed training
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from tbparse import SummaryReader
|
||||||
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import most_recent_subdir, with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPackedLlama(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Packed training of llama models
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_loss_packed(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"sample_packing": True,
|
||||||
|
"flash_attention": True,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "vicgalle/alpaca-gpt4",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"use_tensorboard": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if is_torch_bf16_gpu_available():
|
||||||
|
cfg.bf16 = True
|
||||||
|
else:
|
||||||
|
cfg.fp16 = True
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||||
|
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||||
|
reader = SummaryReader(event_file)
|
||||||
|
df = reader.scalars # pylint: disable=invalid-name
|
||||||
|
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
|
||||||
|
assert df.value.values[-1] < 2.0, "Loss is too high"
|
||||||
Reference in New Issue
Block a user