Compare commits
11 Commits
activation
...
enable_tp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60c98a4353 | ||
|
|
c760d2b815 | ||
|
|
2014f58181 | ||
|
|
b5f9dd44f2 | ||
|
|
b17b1aada7 | ||
|
|
85381b6b15 | ||
|
|
acde081321 | ||
|
|
e4c68a0cbc | ||
|
|
3855f5c3d3 | ||
|
|
5dd566dc63 | ||
|
|
42389c1f78 |
5
.github/workflows/tests-nightly.yml
vendored
5
.github/workflows/tests-nightly.yml
vendored
@@ -44,11 +44,6 @@ jobs:
|
|||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: 'pip' # caching pip dependencies
|
||||||
|
|
||||||
- name: upgrade pip
|
|
||||||
run: |
|
|
||||||
pip3 install --upgrade pip
|
|
||||||
pip3 install --upgrade packaging setuptools wheel
|
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
|||||||
58
examples/llama-3/fft-8b-tp.yml
Normal file
58
examples/llama-3/fft-8b-tp.yml
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
base_model: NousResearch/Meta-Llama-3.1-8B
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 2e-5
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
tensor_parallel: 'auto'
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 100
|
||||||
|
evals_per_epoch: 2
|
||||||
|
eval_table_size:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|end_of_text|>
|
||||||
73
examples/llama-3/lora-8b-tp.yml
Normal file
73
examples/llama-3/lora-8b-tp.yml
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
base_model: NousResearch/Meta-Llama-3.1-8B
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
lora_modules_to_save:
|
||||||
|
- embed_tokens
|
||||||
|
- lm_head
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
tensor_parallel: 'auto'
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
s2_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|end_of_text|>
|
||||||
@@ -996,15 +996,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
def _evaluate(self, *args, **kwargs):
|
|
||||||
metrics = super()._evaluate(*args, **kwargs)
|
|
||||||
|
|
||||||
# cleanup memory after evals
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
return metrics
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1328,6 +1319,10 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if hasattr(model, "add_model_tags"):
|
if hasattr(model, "add_model_tags"):
|
||||||
model.add_model_tags(["axolotl"])
|
model.add_model_tags(["axolotl"])
|
||||||
|
|
||||||
|
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
|
||||||
|
os.environ["ACCELERATE_USE_TP"] = "true"
|
||||||
|
# self.model =
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_ref(self):
|
def model_ref(self):
|
||||||
return self._model_ref
|
return self._model_ref
|
||||||
|
|||||||
@@ -1,170 +0,0 @@
|
|||||||
import contextlib
|
|
||||||
import inspect
|
|
||||||
import types
|
|
||||||
|
|
||||||
from torchtune.training import OffloadActivations
|
|
||||||
from transformers import LlamaConfig, LlamaForCausalLM
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
|
||||||
|
|
||||||
HF_MODEL_OUTPUTS = """
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
""".lstrip()
|
|
||||||
|
|
||||||
PATCHED_HF_MODEL_OUTPUTS = """
|
|
||||||
with self.act_offloading_ctx_manager:
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
""".lstrip()
|
|
||||||
|
|
||||||
LCE_MODEL_OUTPUTS = """
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
)
|
|
||||||
""".lstrip()
|
|
||||||
|
|
||||||
PATCHED_LCE_OUTPUTS = """
|
|
||||||
with self.act_offloading_ctx_manager:
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
)
|
|
||||||
""".lstrip()
|
|
||||||
|
|
||||||
HF_GA_FORWARD_1 = """
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
""".lstrip()
|
|
||||||
|
|
||||||
PATCHED_HF_GA_FORWARD_1 = """
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
|
|
||||||
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
""".lstrip()
|
|
||||||
|
|
||||||
HF_GA_FORWARD_2 = """
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
|
||||||
""".lstrip()
|
|
||||||
|
|
||||||
PATCHED_HF_GA_FORWARD_2 = """
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
|
|
||||||
""".lstrip()
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
|
||||||
act_offloading_ctx_manager = contextlib.nullcontext()
|
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def set_forward(cls):
|
|
||||||
forward_source = inspect.getsource(LlamaForCausalLM.forward)
|
|
||||||
forward_source, _ = detab_code(forward_source)
|
|
||||||
cls.forward = types.MethodType(
|
|
||||||
compile(forward_source, "<forward>", "exec"), cls
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def enable_act_offloading(cls):
|
|
||||||
forward_source = inspect.getsource(cls.forward)
|
|
||||||
forward_source = forward_source.replace(
|
|
||||||
HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS
|
|
||||||
)
|
|
||||||
forward_source, _ = detab_code(forward_source)
|
|
||||||
# replace forward method with patched version
|
|
||||||
cls.forward = types.MethodType(
|
|
||||||
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), cls
|
|
||||||
)
|
|
||||||
cls.act_offloading_ctx_manager = OffloadActivations()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def enable_liger_fce(cls, enable_act_offloading=True):
|
|
||||||
from liger_kernel.transformers.model.llama import (
|
|
||||||
lce_forward as llama_lce_forward,
|
|
||||||
)
|
|
||||||
|
|
||||||
if enable_act_offloading:
|
|
||||||
lce_source = inspect.getsource(llama_lce_forward)
|
|
||||||
lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS)
|
|
||||||
# replace forward method with patched version
|
|
||||||
cls.forward = types.MethodType(
|
|
||||||
compile(lce_source, "<llama_lce_forward_w_act_offloading>", "exec"),
|
|
||||||
cls,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cls.forward = types.methodType(llama_lce_forward, cls)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def patch_hf_ga(cls):
|
|
||||||
# bugfix patch for gradient accumulation
|
|
||||||
forward_source = inspect.getsource(cls.forward)
|
|
||||||
forward_source = forward_source.replace(
|
|
||||||
HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1
|
|
||||||
)
|
|
||||||
forward_source = forward_source.replace(
|
|
||||||
HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
|
|
||||||
)
|
|
||||||
forward_source, _ = detab_code(forward_source)
|
|
||||||
# replace forward method with patched version
|
|
||||||
cls.forward = types.MethodType(
|
|
||||||
compile(forward_source, "<llama_forward_ga_fix>", "exec"), cls
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def replace_auto_model():
|
|
||||||
from transformers import LlamaConfig
|
|
||||||
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
|
|
||||||
|
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
|
|
||||||
AxolotlLlamaForCausalLM.set_forward()
|
|
||||||
|
|
||||||
return AxolotlLlamaForCausalLM
|
|
||||||
@@ -66,7 +66,10 @@ class EvalFirstStepCallback(
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
if (
|
||||||
|
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||||
|
and state.global_step == 1
|
||||||
|
):
|
||||||
control.should_evaluate = True
|
control.should_evaluate = True
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|||||||
@@ -393,7 +393,7 @@ class ModelInputConfig(BaseModel):
|
|||||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||||
)
|
)
|
||||||
trust_remote_code: Optional[bool] = None
|
trust_remote_code: Optional[bool] = None
|
||||||
|
tensor_parallel: Optional[Union[Literal["auto"], bool]] = "auto"
|
||||||
model_kwargs: Optional[Dict[str, Any]] = None
|
model_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
@field_validator("trust_remote_code")
|
@field_validator("trust_remote_code")
|
||||||
@@ -679,7 +679,6 @@ class AxolotlInputConfig(
|
|||||||
default=False
|
default=False
|
||||||
)
|
)
|
||||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||||
activation_offloading: Optional[bool] = None
|
|
||||||
|
|
||||||
unfrozen_parameters: Optional[List[str]] = None
|
unfrozen_parameters: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|||||||
@@ -380,15 +380,6 @@ class ModelLoader:
|
|||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
plugin_manager.pre_model_load(self.cfg)
|
plugin_manager.pre_model_load(self.cfg)
|
||||||
|
|
||||||
if self.cfg.model_config_type == "llama":
|
|
||||||
from axolotl.monkeypatch.models.llama.modeling_llama import replace_auto_model
|
|
||||||
|
|
||||||
AxolotlLlamaForCausalLM = replace_auto_model()
|
|
||||||
|
|
||||||
AxolotlLlamaForCausalLM.patch_hf_ga()
|
|
||||||
if self.cfg.activation_offloading:
|
|
||||||
AxolotlLlamaForCausalLM.enable_act_offloading()
|
|
||||||
|
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
||||||
patch_training_loop_for_fsdp,
|
patch_training_loop_for_fsdp,
|
||||||
@@ -1192,15 +1183,19 @@ class ModelLoader:
|
|||||||
|
|
||||||
self.apply_lora_patch()
|
self.apply_lora_patch()
|
||||||
|
|
||||||
# self.apply_patches_to_model()
|
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
self.post_loading_set_env()
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
return self.model, lora_config
|
return self.model, lora_config
|
||||||
|
|
||||||
|
def post_loading_set_env(self):
|
||||||
|
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
|
||||||
|
os.environ["ACCELERATE_USE_TP"] = "true"
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
|
|||||||
Reference in New Issue
Block a user