support for latest transformers release 4.48.1 (#2256)
This commit is contained in:
@@ -6,5 +6,6 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
|||||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||||
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
|
||||||
|
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||||
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||||
|
|||||||
@@ -13,9 +13,9 @@ liger-kernel==0.5.2
|
|||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
peft==0.14.0
|
peft==0.14.0
|
||||||
transformers==4.47.1
|
transformers==4.48.1
|
||||||
tokenizers>=0.21.0
|
tokenizers>=0.21.0
|
||||||
accelerate==1.2.1
|
accelerate==1.3.0
|
||||||
datasets==3.2.0
|
datasets==3.2.0
|
||||||
deepspeed==0.16.1
|
deepspeed==0.16.1
|
||||||
trl==0.13.0
|
trl==0.13.0
|
||||||
|
|||||||
@@ -1079,6 +1079,7 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.dataset_tags = dataset_tags
|
self.dataset_tags = dataset_tags
|
||||||
self.optimizer = None
|
self.optimizer = None
|
||||||
|
self.model_accepts_loss_kwargs = False
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
if self.args.loraplus_lr_ratio is None:
|
if self.args.loraplus_lr_ratio is None:
|
||||||
|
|||||||
@@ -1,308 +0,0 @@
|
|||||||
"""
|
|
||||||
fix for FSDP gradient accumulation
|
|
||||||
see https://github.com/huggingface/transformers/pull/35128
|
|
||||||
"""
|
|
||||||
import inspect
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from transformers import LlamaForCausalLM, Trainer
|
|
||||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
|
|
||||||
|
|
||||||
ORIGINAL_CONTEXT_CODE = """
|
|
||||||
with self.compute_loss_context_manager():
|
|
||||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATCHED_CONTEXT_CODE = """
|
|
||||||
with self.compute_loss_context_manager():
|
|
||||||
if self.model_accepts_loss_kwargs:
|
|
||||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
|
||||||
else:
|
|
||||||
loss = self.compute_loss(model, inputs)
|
|
||||||
"""
|
|
||||||
|
|
||||||
ORIGINAL_LLAMA_FCLM_CODE = """
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
||||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
|
||||||
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATCHED_LLAMA_FCLM_CODE = """
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
|
|
||||||
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
||||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
|
||||||
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_training_step_code() -> str:
    """Return the source text of `Trainer.training_step` as given by `inspect.getsource`."""
    return inspect.getsource(Trainer.training_step)
|
|
||||||
|
|
||||||
|
|
||||||
def check_training_step_is_patchable() -> bool:
    """
    Report whether `ORIGINAL_CONTEXT_CODE` still occurs in the de-indented source of
    `Trainer.training_step`, i.e. whether the literal substitution can be applied.
    """
    # detab_code returns a pair; only its first element (the processed source) is needed here
    detabbed_source, _ = detab_code(get_training_step_code())
    return ORIGINAL_CONTEXT_CODE in detabbed_source
|
|
||||||
|
|
||||||
|
|
||||||
def patch_training_step_for_ga():
    """
    Monkeypatch `Trainer.training_step` for gradient accumulation.

    The source of `Trainer.training_step` is fetched, the `ORIGINAL_CONTEXT_CODE` snippet is
    replaced by `PATCHED_CONTEXT_CODE` (which only forwards `num_items_in_batch` to
    `compute_loss` when `self.model_accepts_loss_kwargs` is set), and the result is exec'd
    into this module's globals as `_fixed_training_step`, which is then installed as
    `Trainer.training_step`.

    The unmodified source text is kept on `Trainer._original_training_step`. When the source
    cannot be retrieved, or no longer contains the expected snippet, the Trainer method is
    left untouched and a warning is logged so the skipped patch is visible.
    """

    try:
        training_step = get_training_step_code()
    except OSError:
        LOG.warning("unable to read source of Trainer.training_step, skipping patch")
        return
    Trainer._original_training_step = training_step  # pylint: disable=protected-access
    training_step, _ = detab_code(training_step)
    if ORIGINAL_CONTEXT_CODE not in training_step:
        LOG.warning(
            "Trainer.training_step does not contain the expected code, skipping patch"
        )
        return

    training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
    # rename only the definition itself so the patched copy does not shadow the original
    training_step = training_step.replace(
        "def training_step(",
        "def _fixed_training_step(",
        1,
    )

    # load imports necessary
    import transformers.trainer

    # every public name of transformers.trainer that textually occurs in the source is
    # imported into this module's globals so the exec'd function can resolve it
    items_to_import = [
        item for item in dir(transformers.trainer) if item in training_step
    ]

    if items_to_import:
        # an empty list would produce `from ... import ()`, which is a SyntaxError
        exec(  # pylint: disable=exec-used  # nosec B102
            "from transformers.trainer import (" + ", ".join(items_to_import) + ")",
            globals(),
        )
    exec(training_step, globals())  # pylint: disable=exec-used  # nosec B102
    LOG.info("patching training_step")
    Trainer.training_step = (  # pylint: disable=protected-access
        _fixed_training_step  # pylint: disable=undefined-variable  # noqa: F821
    )
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_forward_code() -> str:
    """Return the source text of `LlamaForCausalLM.forward` as given by `inspect.getsource`."""
    return inspect.getsource(LlamaForCausalLM.forward)
|
|
||||||
|
|
||||||
|
|
||||||
def check_forward_is_patchable() -> bool:
    """
    Report whether `ORIGINAL_LLAMA_FCLM_CODE` still occurs in the de-indented source of
    `LlamaForCausalLM.forward`, i.e. whether the literal substitution can be applied.
    """
    # detab_code returns a pair; only its first element (the processed source) is needed here
    detabbed_source, _ = detab_code(get_model_forward_code())
    return ORIGINAL_LLAMA_FCLM_CODE in detabbed_source
|
|
||||||
|
|
||||||
|
|
||||||
def patch_forward_for_ga():
    """
    Monkeypatch `LlamaForCausalLM.forward` for gradient accumulation.

    The source of `LlamaForCausalLM.forward` is fetched, the `ORIGINAL_LLAMA_FCLM_CODE`
    snippet is replaced by `PATCHED_LLAMA_FCLM_CODE` (which pops `num_items_in_batch` out of
    `kwargs` before calling `self.model` and passes it explicitly to `self.loss_function`),
    and the result is exec'd into this module's globals as `_fixed_forward`, which is then
    installed as `LlamaForCausalLM.forward`.

    The unmodified source text is kept on `LlamaForCausalLM._original_forward`. When the
    source cannot be retrieved, or no longer contains the expected snippet, the model class
    is left untouched and a warning is logged so the skipped patch is visible.
    """

    try:
        forward = get_model_forward_code()
    except OSError:
        LOG.warning("unable to read source of LlamaForCausalLM.forward, skipping patch")
        return
    LlamaForCausalLM._original_forward = forward  # pylint: disable=protected-access
    forward, _ = detab_code(forward)
    if ORIGINAL_LLAMA_FCLM_CODE not in forward:
        LOG.warning(
            "LlamaForCausalLM.forward does not contain the expected code, skipping patch"
        )
        return

    forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
    # rename only the definition itself so the patched copy does not shadow the original
    forward = forward.replace(
        "def forward(",
        "def _fixed_forward(",
        1,
    )

    # load imports necessary
    import transformers.models.llama.modeling_llama

    # every name of modeling_llama that textually occurs in the source is imported into
    # this module's globals so the exec'd function can resolve it
    items_to_import = [
        item
        for item in dir(transformers.models.llama.modeling_llama)
        if item in forward
    ]

    if items_to_import:
        # an empty list would produce `from ... import ()`, which is a SyntaxError
        exec(  # pylint: disable=exec-used  # nosec B102
            "from transformers.models.llama.modeling_llama import ("
            + ", ".join(items_to_import)
            + ")",
            globals(),
        )
    exec(forward, globals())  # pylint: disable=exec-used  # nosec B102
    LOG.info("patching forward")
    LlamaForCausalLM.forward = (  # pylint: disable=protected-access
        _fixed_forward  # pylint: disable=undefined-variable  # noqa: F821
    )
|
|
||||||
|
|
||||||
|
|
||||||
ORIGINAL_TRAINER_CODE = """
|
|
||||||
context = (
|
|
||||||
functools.partial(self.accelerator.no_sync, model=model)
|
|
||||||
if i != len(batch_samples) - 1
|
|
||||||
else contextlib.nullcontext
|
|
||||||
)
|
|
||||||
with context():
|
|
||||||
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATCHED_TRAINER_CODE = """
|
|
||||||
disable_deepspeed_no_sync = (
|
|
||||||
self.accelerator.distributed_type == DistributedType.DEEPSPEED
|
|
||||||
# and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
|
|
||||||
)
|
|
||||||
context = (
|
|
||||||
functools.partial(self.accelerator.no_sync, model=model)
|
|
||||||
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
|
|
||||||
else contextlib.nullcontext
|
|
||||||
)
|
|
||||||
with context():
|
|
||||||
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_training_loop_code() -> str:
    """Return the source text of `Trainer._inner_training_loop` as given by `inspect.getsource`."""
    inner_loop = Trainer._inner_training_loop  # pylint: disable=protected-access
    return inspect.getsource(inner_loop)
|
|
||||||
|
|
||||||
|
|
||||||
def check_training_loop_is_patchable() -> bool:
    """
    Report whether `ORIGINAL_TRAINER_CODE` still occurs in the de-indented source of
    `Trainer._inner_training_loop`, i.e. whether the literal substitution can be applied.
    """
    # detab_code returns a pair; only its first element (the processed source) is needed here
    detabbed_source, _ = detab_code(get_training_loop_code())
    return ORIGINAL_TRAINER_CODE in detabbed_source
|
|
||||||
|
|
||||||
|
|
||||||
def patch_training_loop_for_deepspeed_0_16_x():
    """
    Monkeypatch `Trainer._inner_training_loop` for DeepSpeed gradient accumulation.

    The source of `Trainer._inner_training_loop` is fetched, the `ORIGINAL_TRAINER_CODE`
    snippet is replaced by `PATCHED_TRAINER_CODE` (which skips the `accelerator.no_sync`
    context whenever `accelerator.distributed_type` is `DistributedType.DEEPSPEED`), and the
    result is exec'd into this module's globals as `_fixed_inner_training_loop`, which is
    then installed as `Trainer._inner_training_loop`.

    The unmodified source text is kept on `Trainer._original_inner_training_loop`. When the
    source cannot be retrieved, or no longer contains the expected snippet, the Trainer is
    left untouched and a warning is logged so the skipped patch is visible.

    see https://github.com/huggingface/transformers/pull/35157
    """

    try:
        training_loop = get_training_loop_code()
    except OSError:
        LOG.warning(
            "unable to read source of Trainer._inner_training_loop, skipping patch"
        )
        return
    Trainer._original_inner_training_loop = (  # pylint: disable=protected-access
        training_loop
    )
    training_loop, _ = detab_code(training_loop)
    if ORIGINAL_TRAINER_CODE not in training_loop:
        LOG.warning(
            "Trainer._inner_training_loop does not contain the expected code, skipping patch"
        )
        return

    training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
    # rename only the definition itself so the patched copy does not shadow the original
    training_loop = training_loop.replace(
        "def _inner_training_loop(",
        "def _fixed_inner_training_loop(",
        1,
    )

    # load imports necessary
    import transformers.trainer

    # every name of transformers.trainer that textually occurs in the source is imported
    # into this module's globals so the exec'd function can resolve it
    items_to_import = [
        item for item in dir(transformers.trainer) if item in training_loop
    ]

    if items_to_import:
        # an empty list would produce `from ... import ()`, which is a SyntaxError
        exec(  # pylint: disable=exec-used  # nosec B102
            "from transformers.trainer import (" + ", ".join(items_to_import) + ")",
            globals(),
        )
    exec(training_loop, globals())  # pylint: disable=exec-used  # nosec B102
    # the previous message ("for fsdp optimizer save") described a different patch
    LOG.info("patching _inner_training_loop for deepspeed gradient accumulation")
    Trainer._inner_training_loop = (  # pylint: disable=protected-access
        _fixed_inner_training_loop  # pylint: disable=undefined-variable  # noqa: F821
    )
|
|
||||||
|
|
||||||
|
|
||||||
def patch_flash_attention_forward():
    """
    Install a wrapper around `_flash_attention_forward` that discards the
    `num_items_in_batch` keyword argument before delegating to the original function.

    The wrapper is assigned to `_flash_attention_forward` on both
    `transformers.modeling_flash_attention_utils` and
    `transformers.models.llama.modeling_llama`.
    """

    import transformers.modeling_flash_attention_utils

    def proxy_flash_attention_forward(*args, **kwargs):
        # forward everything except num_items_in_batch to the function imported at module load
        forwarded_kwargs = {
            name: val for name, val in kwargs.items() if name != "num_items_in_batch"
        }
        return _flash_attention_forward(*args, **forwarded_kwargs)

    setattr(
        transformers.modeling_flash_attention_utils,
        "_flash_attention_forward",
        proxy_flash_attention_forward,
    )
    setattr(
        transformers.models.llama.modeling_llama,
        "_flash_attention_forward",
        proxy_flash_attention_forward,
    )
|
|
||||||
67
src/axolotl/monkeypatch/transformers_fa_utils.py
Normal file
67
src/axolotl/monkeypatch/transformers_fa_utils.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""
|
||||||
|
see https://github.com/huggingface/transformers/pull/35834
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from functools import partial
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def fixed_fa_peft_integration_check(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    target_dtype: Optional[torch.dtype] = None,
    preferred_dtype: Optional[torch.dtype] = None,
):
    """
    PEFT usually casts the layer norms in float32 for training stability reasons
    therefore the input hidden states gets silently casted in float32. Hence, we need
    cast them back in float16 / bfloat16 just to be sure everything works as expected.
    This might slowdown training & inference so it is recommended to not cast the LayerNorms!

    Args:
        query (`torch.Tensor`):
            Input query states to be passed to Flash Attention API
        key (`torch.Tensor`):
            Input key states to be passed to Flash Attention API
        value (`torch.Tensor`):
            Input value states to be passed to Flash Attention API
        target_dtype (`torch.dtype`, *optional*):
            The dtype to convert the attention tensors to. Conversion can be ignored by
            not providing the target dtype.
        preferred_dtype (`torch.dtype`, *optional*):
            The preferred dtype to convert the attention tensors to regardless of the
            target dtype.

    Returns:
        The `(query, key, value)` tuple. The tensors are returned unchanged when neither
        dtype is given or when none of them is float32; otherwise all three are cast to
        the resolved dtype.
    """
    if target_dtype is None and preferred_dtype is None:
        return query, key, value

    # preferred_dtype wins over target_dtype whenever it is provided
    if preferred_dtype is not None and target_dtype != preferred_dtype:
        target_dtype = preferred_dtype

    # check if any of query, key, or value are in float32. If so, cast them back to target dtype.
    if any(tensor.dtype == torch.float32 for tensor in (query, key, value)):
        _warn_fa_float32_cast_once(target_dtype)

        query = query.to(target_dtype)
        key = key.to(target_dtype)
        value = value.to(target_dtype)

    return query, key, value


# dtypes for which the float32 cast warning has already been emitted
_FA_FLOAT32_CAST_WARNED = set()


def _warn_fa_float32_cast_once(target_dtype):
    """Log the float32 cast-back warning at most once per target dtype."""
    # The stdlib logging.Logger has no `warning_once` method (that is a transformers logging
    # extension), so the once-only behavior is implemented here explicitly.
    if target_dtype in _FA_FLOAT32_CAST_WARNED:
        return
    _FA_FLOAT32_CAST_WARNED.add(target_dtype)
    logging.getLogger(__name__).warning(
        "The input hidden states seems to be silently casted in float32, this might be related to"
        " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        f" {target_dtype}."
    )
|
||||||
|
|
||||||
|
|
||||||
|
def patch_fa_peft_integration():
    """
    Replace `transformers.modeling_flash_attention_utils.fa_peft_integration_check` with
    `fixed_fa_peft_integration_check`, pre-bound with `preferred_dtype=None`.
    """
    import transformers.modeling_flash_attention_utils

    replacement_check = partial(fixed_fa_peft_integration_check, preferred_dtype=None)
    transformers.modeling_flash_attention_utils.fa_peft_integration_check = (
        replacement_check
    )
|
||||||
@@ -380,23 +380,19 @@ class ModelLoader:
|
|||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
plugin_manager.pre_model_load(self.cfg)
|
plugin_manager.pre_model_load(self.cfg)
|
||||||
|
|
||||||
|
if self.cfg.adapter:
|
||||||
|
from axolotl.monkeypatch.transformers_fa_utils import (
|
||||||
|
patch_fa_peft_integration,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_fa_peft_integration()
|
||||||
|
|
||||||
if self.cfg.gradient_checkpointing == "unsloth":
|
if self.cfg.gradient_checkpointing == "unsloth":
|
||||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
||||||
|
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
self.patch_attention()
|
self.patch_attention()
|
||||||
|
|
||||||
if self.cfg.model_config_type == "llama":
|
|
||||||
from axolotl.monkeypatch.trainer_grad_accum import (
|
|
||||||
patch_flash_attention_forward,
|
|
||||||
patch_forward_for_ga,
|
|
||||||
patch_training_step_for_ga,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_flash_attention_forward()
|
|
||||||
patch_forward_for_ga()
|
|
||||||
patch_training_step_for_ga()
|
|
||||||
|
|
||||||
if self.cfg.sample_packing and self.cfg.s2_attention:
|
if self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Received `sample_packing=true` and `s2_attention=true`; however, \
|
"Received `sample_packing=true` and `s2_attention=true`; however, \
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ class TestMultiGPULlama:
|
|||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
|
"bf16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -127,6 +128,7 @@ class TestMultiGPULlama:
|
|||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
|
"bf16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -201,6 +203,7 @@ class TestMultiGPULlama:
|
|||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
|
"bf16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -223,8 +226,12 @@ class TestMultiGPULlama:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
loss_threshold = 2.3
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs",
|
||||||
|
"train/train_loss",
|
||||||
|
loss_threshold,
|
||||||
|
"Train Loss is too high",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_dpo_qlora_ddp(self, temp_dir):
|
def test_dpo_qlora_ddp(self, temp_dir):
|
||||||
@@ -275,6 +282,7 @@ class TestMultiGPULlama:
|
|||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
|
"bf16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -297,8 +305,12 @@ class TestMultiGPULlama:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
loss_threshold = 2.3
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs",
|
||||||
|
"train/train_loss",
|
||||||
|
loss_threshold,
|
||||||
|
"Train Loss is too high",
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
@@ -102,9 +102,5 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
|
||||||
"MixtralFlashAttention2"
|
|
||||||
in model.model.layers[0].self_attn.__class__.__name__
|
|
||||||
)
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -49,12 +49,7 @@ class TestModelPatches(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
model, _ = load_model(cfg, tokenizer, inference=False)
|
load_model(cfg, tokenizer, inference=False)
|
||||||
|
|
||||||
assert (
|
|
||||||
"MixtralFlashAttention2"
|
|
||||||
in model.model.layers[0].self_attn.__class__.__name__
|
|
||||||
)
|
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_mistral_multipack(self, temp_dir):
|
def test_mistral_multipack(self, temp_dir):
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ import unittest
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
reason="Unsloth integration will be broken going into latest transformers"
|
reason="Unsloth integration will be broken going into latest transformers"
|
||||||
@@ -13,6 +11,8 @@ class TestUnslothIntegration(unittest.TestCase):
|
|||||||
"""Unsloth monkeypatch integration tests."""
|
"""Unsloth monkeypatch integration tests."""
|
||||||
|
|
||||||
def test_is_self_attn_patchable(self):
|
def test_is_self_attn_patchable(self):
|
||||||
|
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
|
||||||
|
|
||||||
# ensures the current version of transformers has loss code that matches our patching code
|
# ensures the current version of transformers has loss code that matches our patching code
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
check_self_attn_is_patchable(),
|
check_self_attn_is_patchable(),
|
||||||
|
|||||||
0
tests/e2e/solo/__init__.py
Normal file
0
tests/e2e/solo/__init__.py
Normal file
@@ -13,7 +13,7 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
"""Test module for checking whether Hugging Face Transformers is working as expected."""
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.trainer_grad_accum import (
|
|
||||||
check_forward_is_patchable,
|
|
||||||
check_training_step_is_patchable,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestTrainerGAIntegration(unittest.TestCase):
    """Checks that the installed transformers source still matches the grad-accum patches."""

    def test_train_step_patchable(self):
        # the patch is a literal source substitution, so the expected snippet must be present
        patchable = check_training_step_is_patchable()
        self.assertTrue(
            patchable,
            msg="HF transformers Trainer.training_step has changed and isn't patchable",
        )

    def test_model_forward_patchable(self):
        # the patch is a literal source substitution, so the expected snippet must be present
        patchable = check_forward_is_patchable()
        self.assertTrue(
            patchable,
            msg="HF transformers LlamaForCausalLM.forward has changed and isn't patchable",
        )
|
|
||||||
Reference in New Issue
Block a user