diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 73d9e0e65..7f16ec378 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -996,6 +996,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             os.makedirs(output_dir, exist_ok=True)
         return super()._save_checkpoint(model, trial, **kwargs)
 
+    def _evaluate(self, *args, **kwargs):
+        metrics = super()._evaluate(*args, **kwargs)
+
+        # clean up memory after evals
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        return metrics
+
 
 class AxolotlMambaTrainer(AxolotlTrainer):
     """
diff --git a/src/axolotl/monkeypatch/models/__init__.py b/src/axolotl/monkeypatch/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/monkeypatch/models/llama/modeling_llama.py b/src/axolotl/monkeypatch/models/llama/modeling_llama.py
new file mode 100644
index 000000000..c0a20fbd3
--- /dev/null
+++ b/src/axolotl/monkeypatch/models/llama/modeling_llama.py
@@ -0,0 +1,186 @@
+import contextlib
+import inspect
+import textwrap
+import types
+
+from torchtune.training import OffloadActivations
+from transformers import LlamaConfig, LlamaForCausalLM
+
+# NOTE: the match strings below must stay byte-identical to the upstream
+# transformers / liger-kernel sources, otherwise str.replace() silently no-ops
+HF_MODEL_OUTPUTS = """
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            **kwargs,
+        )
+""".lstrip()
+
+PATCHED_HF_MODEL_OUTPUTS = """
+        with self.act_offloading_ctx_manager:
+            outputs = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                cache_position=cache_position,
+                **kwargs,
+            )
+""".lstrip()
+
+LCE_MODEL_OUTPUTS = """
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+""".lstrip()
+
+PATCHED_LCE_OUTPUTS = """
+    with self.act_offloading_ctx_manager:
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+""".lstrip()
+
+HF_GA_FORWARD_1 = """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+""".lstrip()
+
+PATCHED_HF_GA_FORWARD_1 = """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # remove num_items_in_batch, otherwise self.model attempts to pass it to flash_attention
+        num_items_in_batch = kwargs.pop("num_items_in_batch", None)
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+""".lstrip()
+
+HF_GA_FORWARD_2 = """
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+""".lstrip()
+
+PATCHED_HF_GA_FORWARD_2 = """
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
+""".lstrip()
+
+
+def _compile_forward(source, fn_globals, fn_name="forward"):
+    # compile() alone returns a code object, which cannot be bound as a
+    # method; exec the dedented source in the original function's globals
+    # and return the resulting function instead
+    env = dict(fn_globals)
+    exec(compile(textwrap.dedent(source), "<patched_forward>", "exec"), env)
+    return env[fn_name]
+
+
+class AxolotlLlamaForCausalLM(LlamaForCausalLM):
+    def __init__(self, config: LlamaConfig):
+        super().__init__(config)
+        self.act_offloading_ctx_manager = contextlib.nullcontext()
+
+        # keep the forward source around: once forward is swapped for a
+        # compiled version, inspect.getsource() can no longer recover it
+        self._forward_source = inspect.getsource(LlamaForCausalLM.forward)
+        self.forward = types.MethodType(
+            _compile_forward(
+                self._forward_source, LlamaForCausalLM.forward.__globals__
+            ),
+            self,
+        )
+
+    def enable_act_offloading(self):
+        self.act_offloading_ctx_manager = OffloadActivations()
+
+        self._forward_source = self._forward_source.replace(
+            HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS
+        )
+        # replace forward method with patched version
+        self.forward = types.MethodType(
+            _compile_forward(
+                self._forward_source, LlamaForCausalLM.forward.__globals__
+            ),
+            self,
+        )
+
+    def enable_liger_fce(self, enable_act_offloading=True):
+        from liger_kernel.transformers.model.llama import (
+            lce_forward as llama_lce_forward,
+        )
+
+        if enable_act_offloading:
+            self._forward_source = inspect.getsource(llama_lce_forward).replace(
+                LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS
+            )
+            # replace forward method with patched version
+            self.forward = types.MethodType(
+                _compile_forward(
+                    self._forward_source,
+                    llama_lce_forward.__globals__,
+                    fn_name="lce_forward",
+                ),
+                self,
+            )
+        else:
+            self.forward = types.MethodType(llama_lce_forward, self)
+
+    def patch_hf_ga(self):
+        # bugfix patch for gradient accumulation
+        self._forward_source = self._forward_source.replace(
+            HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1
+        )
+        self._forward_source = self._forward_source.replace(
+            HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
+        )
+        # replace forward method with patched version
+        self.forward = types.MethodType(
+            _compile_forward(
+                self._forward_source, LlamaForCausalLM.forward.__globals__
+            ),
+            self,
+        )
+
+
+def replace_auto_model():
+    from transformers import LlamaConfig
+    from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
+
+    # _LazyAutoMapping does not support plain item assignment; use register()
+    MODEL_FOR_CAUSAL_LM_MAPPING.register(
+        LlamaConfig, AxolotlLlamaForCausalLM, exist_ok=True
+    )
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 3671e1bb9..ca200ca69 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -679,6 +679,7 @@ class AxolotlInputConfig(
         default=False
     )
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
+    activation_offloading: Optional[bool] = None
 
     unfrozen_parameters: Optional[List[str]] = None
 
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 11f4c6d0f..4db9107c7 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1183,6 +1183,8 @@ class ModelLoader:
 
         self.apply_lora_patch()
 
+        # self.apply_patches_to_model()
+
         for _ in range(3):
             gc.collect()
             torch.cuda.empty_cache()
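
Note: the diff leaves its own call site commented out (`# self.apply_patches_to_model()` in ModelLoader), so nothing wires these patches into model loading yet. The sketch below shows one way the pieces could be hooked together; the `apply_patches_to_model` helper, its signature, and the `liger_fused_linear_cross_entropy` flag lookup are assumptions for illustration, not part of this diff:

    from axolotl.monkeypatch.models.llama.modeling_llama import replace_auto_model

    # hypothetical helper; only the patched Llama subclass exposes these hooks
    def apply_patches_to_model(model, cfg):
        if not hasattr(model, "enable_act_offloading"):
            return
        if cfg.activation_offloading:
            if getattr(cfg, "liger_fused_linear_cross_entropy", False):
                # offload activations inside liger's fused linear cross entropy forward
                model.enable_liger_fce(enable_act_offloading=True)
            else:
                model.enable_act_offloading()
        # re-apply the gradient accumulation loss fix to the patched forward
        model.patch_hf_ga()

    # replace_auto_model() must run before AutoModelForCausalLM.from_pretrained()
    # so the auto mapping resolves LlamaConfig to AxolotlLlamaForCausalLM
    replace_auto_model()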