Compare commits

..

11 Commits

Author SHA1 Message Date
bursteratom
60c98a4353 stuff 2024-12-13 15:44:51 -05:00
bursteratom
c760d2b815 test accelerator 2024-12-12 12:29:35 -05:00
bursteratom
2014f58181 set os environ RANK 2024-12-11 11:45:07 -05:00
bursteratom
b5f9dd44f2 set os environ RANK 2024-12-11 11:40:20 -05:00
bursteratom
b17b1aada7 initialise process group for tp 2024-12-11 11:37:21 -05:00
bursteratom
85381b6b15 initialise process group for tp 2024-12-11 11:35:16 -05:00
bursteratom
acde081321 test lora tp 2024-12-11 11:19:34 -05:00
bursteratom
e4c68a0cbc test lora tp 2024-12-11 11:11:52 -05:00
bursteratom
3855f5c3d3 tp example tp auto 2024-12-11 11:03:39 -05:00
bursteratom
5dd566dc63 tp example 2024-12-11 11:01:23 -05:00
bursteratom
42389c1f78 enable tensor parallel 2024-12-11 10:38:14 -05:00
9 changed files with 146 additions and 198 deletions

View File

@@ -44,11 +44,6 @@ jobs:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu

View File

@@ -0,0 +1,58 @@
base_model: NousResearch/Meta-Llama-3.1-8B
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
tensor_parallel: 'auto'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -0,0 +1,73 @@
base_model: NousResearch/Meta-Llama-3.1-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
tensor_parallel: 'auto'
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -996,15 +996,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
def _evaluate(self, *args, **kwargs):
metrics = super()._evaluate(*args, **kwargs)
# cleanup memory after evals
gc.collect()
torch.cuda.empty_cache()
return metrics
class AxolotlMambaTrainer(AxolotlTrainer):
"""
@@ -1328,6 +1319,10 @@ class TrainerBuilderBase(abc.ABC):
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
os.environ["ACCELERATE_USE_TP"] = "true"
# self.model =
@property
def model_ref(self):
return self._model_ref

View File

@@ -1,170 +0,0 @@
import contextlib
import inspect
import types
from torchtune.training import OffloadActivations
from transformers import LlamaConfig, LlamaForCausalLM
from axolotl.monkeypatch.unsloth_ import detab_code
HF_MODEL_OUTPUTS = """
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
""".lstrip()
PATCHED_HF_MODEL_OUTPUTS = """
with self.act_offloading_ctx_manager:
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
""".lstrip()
LCE_MODEL_OUTPUTS = """
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
""".lstrip()
PATCHED_LCE_OUTPUTS = """
with self.act_offloading_ctx_manager:
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
""".lstrip()
HF_GA_FORWARD_1 = """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
""".lstrip()
PATCHED_HF_GA_FORWARD_1 = """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
""".lstrip()
HF_GA_FORWARD_2 = """
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
""".lstrip()
PATCHED_HF_GA_FORWARD_2 = """
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
""".lstrip()
class AxolotlLlamaForCausalLM(LlamaForCausalLM):
act_offloading_ctx_manager = contextlib.nullcontext()
def __init__(self, config: LlamaConfig):
super().__init__(config)
@classmethod
def set_forward(cls):
forward_source = inspect.getsource(LlamaForCausalLM.forward)
forward_source, _ = detab_code(forward_source)
cls.forward = types.MethodType(
compile(forward_source, "<forward>", "exec"), cls
)
@classmethod
def enable_act_offloading(cls):
forward_source = inspect.getsource(cls.forward)
forward_source = forward_source.replace(
HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS
)
forward_source, _ = detab_code(forward_source)
# replace forward method with patched version
cls.forward = types.MethodType(
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), cls
)
cls.act_offloading_ctx_manager = OffloadActivations()
@classmethod
def enable_liger_fce(cls, enable_act_offloading=True):
from liger_kernel.transformers.model.llama import (
lce_forward as llama_lce_forward,
)
if enable_act_offloading:
lce_source = inspect.getsource(llama_lce_forward)
lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS)
# replace forward method with patched version
cls.forward = types.MethodType(
compile(lce_source, "<llama_lce_forward_w_act_offloading>", "exec"),
cls,
)
else:
cls.forward = types.methodType(llama_lce_forward, cls)
@classmethod
def patch_hf_ga(cls):
# bugfix patch for gradient accumulation
forward_source = inspect.getsource(cls.forward)
forward_source = forward_source.replace(
HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1
)
forward_source = forward_source.replace(
HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
)
forward_source, _ = detab_code(forward_source)
# replace forward method with patched version
cls.forward = types.MethodType(
compile(forward_source, "<llama_forward_ga_fix>", "exec"), cls
)
def replace_auto_model():
from transformers import LlamaConfig
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
AxolotlLlamaForCausalLM.set_forward()
return AxolotlLlamaForCausalLM

View File

@@ -66,7 +66,10 @@ class EvalFirstStepCallback(
control: TrainerControl,
**kwargs,
):
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
if (
args.evaluation_strategy == IntervalStrategy.STEPS
and state.global_step == 1
):
control.should_evaluate = True
return control

View File

@@ -393,7 +393,7 @@ class ModelInputConfig(BaseModel):
default=None, json_schema_extra={"description": "transformers processor class"}
)
trust_remote_code: Optional[bool] = None
tensor_parallel: Optional[Union[Literal["auto"], bool]] = "auto"
model_kwargs: Optional[Dict[str, Any]] = None
@field_validator("trust_remote_code")
@@ -679,7 +679,6 @@ class AxolotlInputConfig(
default=False
)
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
activation_offloading: Optional[bool] = None
unfrozen_parameters: Optional[List[str]] = None

View File

@@ -380,15 +380,6 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)
if self.cfg.model_config_type == "llama":
from axolotl.monkeypatch.models.llama.modeling_llama import replace_auto_model
AxolotlLlamaForCausalLM = replace_auto_model()
AxolotlLlamaForCausalLM.patch_hf_ga()
if self.cfg.activation_offloading:
AxolotlLlamaForCausalLM.enable_act_offloading()
if self.cfg.fsdp:
from axolotl.monkeypatch.trainer_fsdp_optim import (
patch_training_loop_for_fsdp,
@@ -1192,15 +1183,19 @@ class ModelLoader:
self.apply_lora_patch()
# self.apply_patches_to_model()
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
self.post_loading_set_env()
# TODO resume_from_checkpoint handling
return self.model, lora_config
def post_loading_set_env(self):
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
os.environ["ACCELERATE_USE_TP"] = "true"
def load_model(
cfg: DictDefault,