Compare commits
5 Commits
sequence-p
...
activation
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ac9cbebb9 | ||
|
|
15f2fa4c8e | ||
|
|
43a2f9a155 | ||
|
|
8b79f1cbf6 | ||
|
|
3872d5eaed |
@@ -996,6 +996,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
|
def _evaluate(self, *args, **kwargs):
|
||||||
|
metrics = super()._evaluate(*args, **kwargs)
|
||||||
|
|
||||||
|
# cleanup memory after evals
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
|
|||||||
0
src/axolotl/monkeypatch/models/__init__.py
Normal file
0
src/axolotl/monkeypatch/models/__init__.py
Normal file
170
src/axolotl/monkeypatch/models/llama/modeling_llama.py
Normal file
170
src/axolotl/monkeypatch/models/llama/modeling_llama.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
import contextlib
|
||||||
|
import inspect
|
||||||
|
import types
|
||||||
|
|
||||||
|
from torchtune.training import OffloadActivations
|
||||||
|
from transformers import LlamaConfig, LlamaForCausalLM
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||||
|
|
||||||
|
HF_MODEL_OUTPUTS = """
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
PATCHED_HF_MODEL_OUTPUTS = """
|
||||||
|
with self.act_offloading_ctx_manager:
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
LCE_MODEL_OUTPUTS = """
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
PATCHED_LCE_OUTPUTS = """
|
||||||
|
with self.act_offloading_ctx_manager:
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
HF_GA_FORWARD_1 = """
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
PATCHED_HF_GA_FORWARD_1 = """
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
|
||||||
|
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
HF_GA_FORWARD_2 = """
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
PATCHED_HF_GA_FORWARD_2 = """
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
||||||
|
act_offloading_ctx_manager = contextlib.nullcontext()
|
||||||
|
|
||||||
|
def __init__(self, config: LlamaConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_forward(cls):
|
||||||
|
forward_source = inspect.getsource(LlamaForCausalLM.forward)
|
||||||
|
forward_source, _ = detab_code(forward_source)
|
||||||
|
cls.forward = types.MethodType(
|
||||||
|
compile(forward_source, "<forward>", "exec"), cls
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enable_act_offloading(cls):
|
||||||
|
forward_source = inspect.getsource(cls.forward)
|
||||||
|
forward_source = forward_source.replace(
|
||||||
|
HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS
|
||||||
|
)
|
||||||
|
forward_source, _ = detab_code(forward_source)
|
||||||
|
# replace forward method with patched version
|
||||||
|
cls.forward = types.MethodType(
|
||||||
|
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), cls
|
||||||
|
)
|
||||||
|
cls.act_offloading_ctx_manager = OffloadActivations()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enable_liger_fce(cls, enable_act_offloading=True):
|
||||||
|
from liger_kernel.transformers.model.llama import (
|
||||||
|
lce_forward as llama_lce_forward,
|
||||||
|
)
|
||||||
|
|
||||||
|
if enable_act_offloading:
|
||||||
|
lce_source = inspect.getsource(llama_lce_forward)
|
||||||
|
lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS)
|
||||||
|
# replace forward method with patched version
|
||||||
|
cls.forward = types.MethodType(
|
||||||
|
compile(lce_source, "<llama_lce_forward_w_act_offloading>", "exec"),
|
||||||
|
cls,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cls.forward = types.methodType(llama_lce_forward, cls)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def patch_hf_ga(cls):
|
||||||
|
# bugfix patch for gradient accumulation
|
||||||
|
forward_source = inspect.getsource(cls.forward)
|
||||||
|
forward_source = forward_source.replace(
|
||||||
|
HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1
|
||||||
|
)
|
||||||
|
forward_source = forward_source.replace(
|
||||||
|
HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
|
||||||
|
)
|
||||||
|
forward_source, _ = detab_code(forward_source)
|
||||||
|
# replace forward method with patched version
|
||||||
|
cls.forward = types.MethodType(
|
||||||
|
compile(forward_source, "<llama_forward_ga_fix>", "exec"), cls
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_auto_model():
|
||||||
|
from transformers import LlamaConfig
|
||||||
|
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
|
||||||
|
|
||||||
|
MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
|
||||||
|
AxolotlLlamaForCausalLM.set_forward()
|
||||||
|
|
||||||
|
return AxolotlLlamaForCausalLM
|
||||||
@@ -679,6 +679,7 @@ class AxolotlInputConfig(
|
|||||||
default=False
|
default=False
|
||||||
)
|
)
|
||||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
activation_offloading: Optional[bool] = None
|
||||||
|
|
||||||
unfrozen_parameters: Optional[List[str]] = None
|
unfrozen_parameters: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|||||||
@@ -380,6 +380,15 @@ class ModelLoader:
|
|||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
plugin_manager.pre_model_load(self.cfg)
|
plugin_manager.pre_model_load(self.cfg)
|
||||||
|
|
||||||
|
if self.cfg.model_config_type == "llama":
|
||||||
|
from axolotl.monkeypatch.models.llama.modeling_llama import replace_auto_model
|
||||||
|
|
||||||
|
AxolotlLlamaForCausalLM = replace_auto_model()
|
||||||
|
|
||||||
|
AxolotlLlamaForCausalLM.patch_hf_ga()
|
||||||
|
if self.cfg.activation_offloading:
|
||||||
|
AxolotlLlamaForCausalLM.enable_act_offloading()
|
||||||
|
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
||||||
patch_training_loop_for_fsdp,
|
patch_training_loop_for_fsdp,
|
||||||
@@ -1183,6 +1192,8 @@ class ModelLoader:
|
|||||||
|
|
||||||
self.apply_lora_patch()
|
self.apply_lora_patch()
|
||||||
|
|
||||||
|
# self.apply_patches_to_model()
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|||||||
Reference in New Issue
Block a user