Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
da154e6d56 support for json data as completion 2023-11-25 16:05:04 -05:00
50 changed files with 119 additions and 479 deletions

View File

@@ -612,12 +612,6 @@ eval_sample_packing:
sample_packing_eff_est: sample_packing_eff_est:
total_num_tokens: total_num_tokens:
# Passed through to transformers when loading the model when launched without accelerate
# Use `sequential` when training w/ model parallelism to limit memory
device_map:
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
max_memory:
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora adapter: lora
# If you already have a lora model trained that you want to load, put that here. # If you already have a lora model trained that you want to load, put that here.
@@ -665,8 +659,7 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
wandb_project: # Your wandb project name wandb_project: # Your wandb project name
wandb_entity: # A wandb Team name if using a Team wandb_entity: # A wandb Team name if using a Team
wandb_watch: wandb_watch:
wandb_name: # Set the name of your wandb run wandb_run_id: # Set the name of your wandb run
wandb_run_id: # Set the ID of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# Where to save the full-finetuned model to # Where to save the full-finetuned model to
@@ -701,9 +694,6 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
# Save model as safetensors (require safetensors package) # Save model as safetensors (require safetensors package)
save_safetensors: save_safetensors:
@@ -962,7 +952,7 @@ wandb_mode:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
``` ```

View File

@@ -24,6 +24,16 @@
"weight_decay": "auto" "weight_decay": "auto"
} }
}, },
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",

View File

@@ -28,6 +28,16 @@
"weight_decay": "auto" "weight_decay": "auto"
} }
}, },
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",

View File

@@ -32,6 +32,16 @@
"weight_decay": "auto" "weight_decay": "auto"
} }
}, },
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",

View File

@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: btlm-out output_dir: btlm-out

View File

@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out
batch_size: 4 batch_size: 4

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./falcon-7b output_dir: ./falcon-7b
batch_size: 2 batch_size: 2

View File

@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out

View File

@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./falcon-7b output_dir: ./falcon-7b
batch_size: 2 batch_size: 2

View File

@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2

View File

@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./jeopardy-bot-7b output_dir: ./jeopardy-bot-7b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -32,7 +32,7 @@ lora_target_linear:
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./model-out output_dir: ./model-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -35,7 +35,7 @@ relora_cpu_offload: false
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -21,7 +21,7 @@ pad_to_sequence_len: true
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -38,7 +38,7 @@ lora_target_modules:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -62,9 +62,6 @@ logging_steps: 1
xformers_attention: xformers_attention:
flash_attention: true flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 0.05
eval_table_size: eval_table_size:

View File

@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b wandb_project: mpt-alpaca-7b
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./mpt-alpaca-7b output_dir: ./mpt-alpaca-7b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./openllama-out output_dir: ./openllama-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./lora-out output_dir: ./lora-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./pythia-12b output_dir: ./pythia-12b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./lora-alpaca-pythia output_dir: ./lora-alpaca-pythia
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -53,7 +53,7 @@ resume_from_checkpoint:
local_rank: local_rank:
logging_steps: 1 logging_steps: 1
xformers_attention: xformers_attention:
flash_attention: flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 0.05

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -53,7 +53,7 @@ resume_from_checkpoint:
local_rank: local_rank:
logging_steps: 1 logging_steps: 1
xformers_attention: xformers_attention:
flash_attention: flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 0.05

View File

@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b wandb_project: redpajama-alpaca-3b
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./redpajama-alpaca-3b output_dir: ./redpajama-alpaca-3b
batch_size: 4 batch_size: 4

View File

@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
wandb_project: lora-replit wandb_project: lora-replit
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./lora-replit output_dir: ./lora-replit
batch_size: 8 batch_size: 8

View File

@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out

View File

@@ -2,15 +2,14 @@
auto-gptq==0.5.1 auto-gptq==0.5.1
packaging packaging
peft==0.6.0 peft==0.6.0
transformers==4.35.2 transformers==4.35.1
tokenizers==0.15.0
bitsandbytes>=0.41.1 bitsandbytes>=0.41.1
accelerate==0.24.1 accelerate==0.24.1
deepspeed deepspeed
addict addict
fire fire
PyYAML>=6.0 PyYAML>=6.0
datasets>=2.15.0 datasets>=2.14.0
flash-attn==2.3.3 flash-attn==2.3.3
sentencepiece sentencepiece
wandb wandb
@@ -30,7 +29,7 @@ scikit-learn==1.2.2
pynvml pynvml
art art
fschat==0.2.29 fschat==0.2.29
gradio==3.50.2 gradio
tensorboard tensorboard
# remote filesystems # remote filesystems

View File

@@ -29,7 +29,6 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -72,7 +71,7 @@ def do_merge_lora(
LOG.info("running merge of LoRA with base model") LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload() model = model.merge_and_unload()
model.to(dtype=cfg.torch_dtype) model.to(dtype=torch.float16)
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}") LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
@@ -297,8 +296,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
validate_config(cfg) validate_config(cfg)
prepare_optim_env(cfg)
normalize_config(cfg) normalize_config(cfg)
setup_wandb_env_vars(cfg) setup_wandb_env_vars(cfg)

View File

@@ -25,7 +25,6 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
EvalFirstStepCallback, EvalFirstStepCallback,
GPUStatsCallback, GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback, SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
@@ -431,9 +430,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
) )
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
@@ -647,7 +643,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
training_arguments_kwargs["run_name"] = ( training_arguments_kwargs["run_name"] = (
self.cfg.wandb_name if self.cfg.use_wandb else None self.cfg.wandb_run_id if self.cfg.use_wandb else None
) )
training_arguments_kwargs["optim"] = ( training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"

View File

@@ -1,168 +0,0 @@
# Adapted from Unsloth
# https://github.com/unslothai/unsloth/blob/4b97a810b509c93f44be4c037c7aa18fb8922884/unsloth/kernels/cross_entropy_loss.py
import triton
import triton.language as tl
import torch
MAX_FUSED_SIZE = 65536
def calculate_settings(n):
    """
    Pick Triton launch settings for a row of ``n`` elements.

    Returns:
        (BLOCK_SIZE, num_warps): ``BLOCK_SIZE`` is ``n`` rounded up to the next
        power of two; ``num_warps`` grows with ``BLOCK_SIZE`` (4, 8, 16 or 32).

    Raises:
        RuntimeError: if the rounded block size exceeds ``MAX_FUSED_SIZE``.
    """
    block_size = triton.next_power_of_2(n)
    # A whole row is handled by one block, so it has to fit under the cap.
    if block_size > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds "
            f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}."
        )

    # (minimum block size, warps), checked from largest to smallest.
    warp_table = ((32768, 32), (8192, 16), (2048, 8))
    for min_block, warps in warp_table:
        if block_size >= min_block:
            return block_size, warps
    return block_size, 4
@triton.jit
def _cross_entropy_forward(logits_ptr, logits_row_stride,
                           loss_ptr,
                           lse_ptr,
                           labels_ptr,
                           n_cols,
                           BLOCK_SIZE: tl.constexpr,):
    """
    Per-row cross entropy forward pass; each program instance handles one row.

    Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
    Pi = exp(xi) / sum(exp(xi))
    CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
         = -y [ x - log[sum(exp(x))] ]
         = y * (log[sum(exp(x))] - x)
    If y == 0: CE_i = 0
    If y == 1: CE_i = logsumexp - x

    For the row selected by `tl.program_id(0)` this stores:
      * the row's logsumexp into `lse_ptr[row]`
      * `logsumexp - logits[label]` into `loss_ptr[row]`, or 0.0 when the
        label equals -100 (the ignore value).
    """
    row_idx = tl.program_id(0)
    # Advance every pointer to this program's row.
    logits_ptr += row_idx * logits_row_stride
    loss_ptr += row_idx
    lse_ptr += row_idx
    labels_ptr += row_idx

    col_offsets = tl.arange(0, BLOCK_SIZE)
    # BLOCK_SIZE may exceed n_cols; mask off the padding columns.
    mask = col_offsets < n_cols

    # TODO: Fixup int32 locations to int64
    label_idx = tl.load(labels_ptr).to(tl.int32)
    # Masked-out columns load as -inf so they contribute exp(-inf) = 0 to the sum.
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
    max_logits = tl.max(logits, 0)
    # Subtracting the row maximum before exp() stops overflow
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    # Saved so the backward kernel can recompute softmax without another reduction.
    tl.store(lse_ptr, lse)

    if label_idx != -100:
        logits_label = tl.load(logits_ptr + label_idx).to(tl.float32)
        loss = lse - logits_label
    else:
        # Ignored position: contributes no loss.
        loss = 0.0
    tl.store(loss_ptr, loss)
    pass
@triton.jit
def _cross_entropy_backward(logits_ptr, logits_row_stride,
                            dloss_ptr, dloss_row_stride,
                            lse_ptr,
                            labels_ptr,
                            n_cols,
                            BLOCK_SIZE: tl.constexpr,):
    """
    Per-row cross entropy backward pass; each program instance handles one row.

    CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
    dC/dx = d/dx (y * log[sum(exp(x))] - x * y)

    From https://en.wikipedia.org/wiki/LogSumExp
    d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)

    dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
    dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] - d/dx (x * y)   using x = exp(log(x)) trick
    dC/dx = y * exp[x - logsumexp] - d/dx (x * y)

    If y == 0: dC/dx = 0
    If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
    If y == 1 and x != label: dC/dx = exp[x - logsumexp]

    NOTE: the gradient is written back through `logits_ptr`, i.e. the logits
    buffer is overwritten in place with d(loss)/d(logits).
    """
    row_idx = tl.program_id(0)

    # Advance to this program's row of logits and its upstream gradient.
    logits_ptr += row_idx * logits_row_stride
    dloss_ptr += row_idx * dloss_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    # BLOCK_SIZE may exceed n_cols; mask off the padding columns.
    mask = col_offsets < n_cols
    # TODO: Fixup int32 locations to int64
    label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)

    if label_idx != -100:
        dloss = tl.load(dloss_ptr)
    else:
        # Ignored position: zero upstream gradient zeroes the whole row.
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = 0).to(tl.float32)
    lse = tl.load(lse_ptr + row_idx)
    # softmax(x) recomputed from the logsumexp saved by the forward kernel.
    probs = tl.exp(logits - lse)

    # Subtract the one-hot target: only the label column gets the "- 1".
    probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(logits_ptr + col_offsets, dloss * probs, mask = mask)
class CrossEntropyLoss(torch.autograd.Function):
    """
    Autograd wrapper around the Triton cross entropy kernels.

    ``forward`` takes 2-D ``logits`` (rows x classes) and 1-D ``labels`` and
    returns one float32 loss per row. ``backward`` overwrites the saved
    ``logits`` tensor in place with the gradient and returns it, so the logits
    must not be reused by the caller after the backward pass.
    """

    @staticmethod
    def forward(ctx, logits, labels):
        """Launch the forward kernel and return the per-row losses."""
        n_rows, n_cols = logits.shape
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)

        # Allocate outputs next to the inputs. A bare "cuda" resolves to the
        # *current* device, which is not necessarily where `logits` lives when
        # more than one GPU is visible.
        losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
        logsumexp = torch.empty(n_rows, dtype=torch.float32, device=logits.device)

        # One program instance per row.
        _cross_entropy_forward[(n_rows,)](
            logits, logits.stride(0),
            losses,
            logsumexp,
            labels,
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        # Reuse the same launch configuration in backward.
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.save_for_backward(logits, logsumexp, labels)
        return losses

    @staticmethod
    def backward(ctx, dlosses):
        """Launch the backward kernel; returns gradients for (logits, labels)."""
        logits, logsumexp, labels = ctx.saved_tensors
        n_rows, n_cols = logits.shape

        # The kernel writes d(loss)/d(logits) into the `logits` buffer itself.
        _cross_entropy_backward[(n_rows,)](
            logits, logits.stride(0),
            dlosses, dlosses.stride(0),
            logsumexp,
            labels,
            n_cols,
            BLOCK_SIZE=ctx.BLOCK_SIZE,
            num_warps=ctx.num_warps,
        )
        # One gradient per forward input: logits (now holding the gradient)
        # and labels (not differentiable).
        return logits, None
def fast_cross_entropy_loss(logits, labels):
    """
    Mean cross entropy over all positions whose label is not -100.

    Arguments:
        logits: (batch, seq_len, vocab_size)
        labels: (batch, seq_len,)
    Returns:
        Scalar tensor: summed per-position loss divided by the number of
        labels different from -100.
    """
    n_batch, n_tokens, vocab = logits.shape
    assert labels.shape == (n_batch, n_tokens)

    # Flatten (batch, seq) into a single row axis for the row-wise kernel.
    flat_logits = logits.view(n_batch * n_tokens, vocab)
    flat_labels = labels.view(-1)
    per_position_loss = CrossEntropyLoss.apply(flat_logits, flat_labels)

    # Positions labelled -100 contribute 0 to the sum and are left out of the mean.
    counted_positions = torch.count_nonzero(labels != -100)
    return per_position_loss.sum() / counted_positions

View File

@@ -13,20 +13,16 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
flash_attn_varlen_kvpacked_func, flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_func,
) )
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import ( from transformers.models.mistral.modeling_mistral import (
MistralAttention as OriginalMistralAttention, MistralAttention as OriginalMistralAttention,
) )
from transformers.models.mistral.modeling_mistral import ( from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer, MistralDecoderLayer as OriginalMistralDecoderLayer,
) )
from transformers.models.mistral.modeling_mistral import (
MistralForCausalLM as OriginalMistralForCausalLM,
)
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.monkeypatch.cross_entropy import fast_cross_entropy_loss
LOG = logging.getLogger("axolotl.monkeypatch.mistral") LOG = logging.getLogger("axolotl.monkeypatch.mistral")
@@ -40,9 +36,6 @@ def replace_mistral_attn_with_flash_attn(
transformers.models.mistral.modeling_mistral.MistralAttention.forward = ( transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
flashattn_forward flashattn_forward
) )
transformers.models.mistral.modeling_mistral.MistralForCausalLM.forward = (
mistral_causallm_forward
)
if packed: if packed:
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = ( transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
MistralDecoderLayer MistralDecoderLayer
@@ -648,71 +641,3 @@ class MistralDecoderLayer(OriginalMistralDecoderLayer):
outputs += (present_key_value,) outputs += (present_key_value,)
return outputs return outputs
def mistral_causallm_forward(
    self: OriginalMistralForCausalLM,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    *args, **kwargs
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Causal-LM forward pass that computes the loss with `fast_cross_entropy_loss`.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Extra positional and keyword arguments are accepted and ignored.

    Returns:
        A `CausalLMOutputWithPast` when `return_dict` is true, otherwise a tuple
        `(loss?, logits, *model_outputs[1:])` where `loss` is present only if
        `labels` was given.
    """
    # Fall back to the model config for any output option left unset.
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        # The logits are used unshifted; instead the labels are shifted left by
        # one and the final position is filled with -100 so it is ignored.
        shift_logits = logits
        device_labels = labels.to(shift_logits.device)
        # Build the ignore column per call, sized by the actual batch and on
        # the labels' device, rather than slicing a buffer cached on the model.
        ignored_tail = device_labels.new_full((device_labels.shape[0], 1), -100)
        shift_labels = torch.cat((device_labels[..., 1:], ignored_tail), dim=-1)

        # FAST CROSS ENTROPY
        loss = fast_cross_entropy_loss(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

View File

@@ -1,6 +1,7 @@
""" """
Basic completion text Basic completion text
""" """
import json
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, Generator, Optional, Tuple from typing import Any, Dict, Generator, Optional, Tuple
@@ -64,6 +65,19 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
return next(iter(self.prompter.build_prompt(instruction, input, response))) return next(iter(self.prompter.build_prompt(instruction, input, response)))
class CompletionJSONPromptTokenizationStrategy(CompletionPromptTokenizingStrategy):
    """
    Completion strategy whose training text is the whole dataset row,
    serialized to a JSON string.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        # The serialized row goes in the first slot; the other two fields
        # are left empty.
        serialized_row = json.dumps(prompt)
        return serialized_row, "", ""
class CompletionPrompter: class CompletionPrompter:
""" """
Prompter for completion Prompter for completion
@@ -82,7 +96,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strat = CompletionPromptTokenizingStrategy( strat = CompletionPromptTokenizingStrategy(
CompletionPrompter(), CompletionPrompter(),
tokenizer, tokenizer,
cfg.train_on_inputs, True,
cfg.sequence_len, cfg.sequence_len,
max_length=cfg.sequence_len * 64, max_length=cfg.sequence_len * 64,
) )
@@ -90,3 +104,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strat.field = ds_cfg["field"] strat.field = ds_cfg["field"]
return strat return strat
def load_json(tokenizer, cfg):
    """
    Build the JSON completion tokenization strategy.

    The strategy is constructed with a fresh `CompletionPrompter`, the given
    tokenizer, a literal `True` as third positional argument (presumably the
    train-on-inputs flag -- confirm against the base class), `cfg.sequence_len`,
    and `max_length` set to 64 times `cfg.sequence_len`.
    """
    return CompletionJSONPromptTokenizationStrategy(
        CompletionPrompter(),
        tokenizer,
        True,
        cfg.sequence_len,
        max_length=cfg.sequence_len * 64,
    )

View File

@@ -124,36 +124,6 @@ class GPUStatsCallback(
return control return control
class LossWatchDogCallback(TrainerCallback):
    """
    Trainer callback that requests a stop once the most recently logged loss
    has been above `cfg.loss_watchdog_threshold` for `patience` checks in a row.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.logged = False
        # Consecutive checks on which the logged loss exceeded the threshold.
        self.violations = 0
        self.threshold = cfg.loss_watchdog_threshold
        # A missing (or zero) patience falls back to 3.
        self.patience = cfg.loss_watchdog_patience or 3

    def on_step_end(
        self,
        _args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **_kwargs,
    ):
        history = state.log_history
        # Nothing to judge until a log entry carrying a "loss" value exists.
        if len(history) == 0 or "loss" not in history[-1]:
            return control

        if history[-1]["loss"] > self.threshold:
            self.violations += 1
            if self.violations >= self.patience:
                LOG.warning(
                    "Loss is too high, stopping training (loss_watchdog_threshold)"
                )
                control.should_training_stop = True
        else:
            # Any acceptable loss resets the streak.
            self.violations = 0
        return control
def bench_eval_callback_factory(trainer, tokenizer): def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy") accuracy = evaluate.load("accuracy")
abcd_idx = [ abcd_idx = [

View File

@@ -27,7 +27,7 @@ def choose_device(cfg):
cfg.device = get_device() cfg.device = get_device()
if cfg.world_size == 1: if cfg.world_size == 1:
cfg.device_map = cfg.device_map or "auto" cfg.device_map = "auto"
else: else:
if cfg.device.startswith("cuda"): if cfg.device.startswith("cuda"):
cfg.device_map = {"": torch.cuda.current_device()} cfg.device_map = {"": torch.cuda.current_device()}
@@ -397,13 +397,6 @@ def validate_config(cfg):
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch." "Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
) )
if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id
LOG.warning(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
)
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -28,27 +28,6 @@ from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
    """
    Validate that the `gptq` flag agrees with the model's quantization config.

    A `quantization_config` attribute that is missing, `None`, or empty is
    treated as "not quantized".

    Raises:
        ValueError: if `cfg.gptq` is set but the model config does not declare
            `quant_method == "gptq"`, or if the model config carries a
            quantization config while `cfg.gptq` is not set.
    """
    # hasattr() alone is not enough: the attribute may exist but be None, and
    # the membership test below would then raise TypeError.
    quant_config = getattr(model_config, "quantization_config", None)
    quant_config_exists = bool(quant_config)
    quant_config_method_is_gptq = (
        quant_config_exists
        and "quant_method" in quant_config
        and quant_config["quant_method"] == "gptq"
    )

    if cfg.gptq and not quant_config_method_is_gptq:
        raise ValueError(
            "model_config.quantization_config is not set or quant_method is not set to gptq. "
            "Please make sure to point to a GPTQ model."
        )

    if not cfg.gptq and quant_config_exists:
        raise ValueError(
            "model_config.quantization_config is set but `gptq` flag is not. "
            "Please use the `gptq` flag to train quantized model or point to a non-quantized model."
        )
def load_model_config(cfg): def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code = cfg.trust_remote_code is True trust_remote_code = cfg.trust_remote_code is True
@@ -59,8 +38,6 @@ def load_model_config(cfg):
for key, val in cfg.model_config.items(): for key, val in cfg.model_config.items():
setattr(model_config, key, val) setattr(model_config, key, val)
check_model_config(cfg, model_config)
return model_config return model_config
@@ -239,7 +216,6 @@ def load_model(
model_kwargs = {} model_kwargs = {}
model_kwargs["device_map"] = cfg.device_map model_kwargs["device_map"] = cfg.device_map
model_kwargs["max_memory"] = cfg.max_memory
model_kwargs["torch_dtype"] = cfg.torch_dtype model_kwargs["torch_dtype"] = cfg.torch_dtype
if cfg.model_revision: if cfg.model_revision:
@@ -436,22 +412,15 @@ def load_model(
module.to(torch.float32) module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp needs_fa2_dtype = cfg.adapter or cfg.fsdp
skip_prepare_model_for_kbit_training = False
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
skip_prepare_model_for_kbit_training = True
if (cfg.adapter == "lora" and load_in_8bit) or ( if (cfg.adapter == "lora" and load_in_8bit) or (
cfg.adapter == "qlora" and cfg.load_in_4bit cfg.adapter == "qlora" and cfg.load_in_4bit
): ):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing: if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
if not skip_prepare_model_for_kbit_training: model = prepare_model_for_kbit_training(
model = prepare_model_for_kbit_training( model, use_gradient_checkpointing=cfg.gradient_checkpointing
model, use_gradient_checkpointing=cfg.gradient_checkpointing )
)
needs_fa2_dtype = True needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to

View File

@@ -267,14 +267,12 @@ def setup_fsdp_envs(cfg):
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
def prepare_optim_env(cfg): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.fsdp: if cfg.fsdp:
setup_fsdp_envs(cfg) setup_fsdp_envs(cfg)
elif cfg.deepspeed: elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer) trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
trainer_builder.train_dataset = train_dataset trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset trainer_builder.eval_dataset = eval_dataset

View File

@@ -2,20 +2,20 @@
import os import os
from axolotl.utils.dict import DictDefault
def setup_wandb_env_vars(cfg):
def setup_wandb_env_vars(cfg: DictDefault): if cfg.wandb_mode and cfg.wandb_mode == "offline":
for key in cfg.keys(): os.environ["WANDB_MODE"] = cfg.wandb_mode
if key.startswith("wandb_"): elif cfg.wandb_project and len(cfg.wandb_project) > 0:
value = cfg.get(key, "") os.environ["WANDB_PROJECT"] = cfg.wandb_project
if value and isinstance(value, str) and len(value) > 0:
os.environ[key.upper()] = value
# Enable wandb if project name is present
if cfg.wandb_project and len(cfg.wandb_project) > 0:
cfg.use_wandb = True cfg.use_wandb = True
os.environ.pop("WANDB_DISABLED", None) # Remove if present if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
else: else:
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -1,7 +1,6 @@
"""Module for testing the validation module""" """Module for testing the validation module"""
import logging import logging
import os
import unittest import unittest
from typing import Optional from typing import Optional
@@ -9,7 +8,6 @@ import pytest
from axolotl.utils.config import validate_config from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.wandb_ import setup_wandb_env_vars
class ValidationTest(unittest.TestCase): class ValidationTest(unittest.TestCase):
@@ -681,83 +679,3 @@ class ValidationTest(unittest.TestCase):
) )
validate_config(cfg) validate_config(cfg)
class ValidationWandbTest(ValidationTest):
    """
    Validation test for wandb
    """

    def setUp(self):
        """Snapshot the ambient WANDB_* variables before each test."""
        super().setUp()
        # setup_wandb_env_vars writes to os.environ; remember the prior state
        # so these tests do not leak wandb settings into the rest of the run.
        self._wandb_env_backup = {
            key: value
            for key, value in os.environ.items()
            if key.startswith("WANDB_")
        }

    def tearDown(self):
        """Restore the WANDB_* variables captured in setUp."""
        for key in [key for key in os.environ if key.startswith("WANDB_")]:
            del os.environ[key]
        os.environ.update(self._wandb_env_backup)
        super().tearDown()

    def test_wandb_set_run_id_to_name(self):
        """A bare `wandb_run_id` is copied to `wandb_name` with a warning."""
        cfg = DictDefault(
            {
                "wandb_run_id": "foo",
            }
        )

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert any(
                "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
                in record.message
                for record in self._caplog.records
            )

            assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"

        cfg = DictDefault(
            {
                "wandb_name": "foo",
            }
        )

        validate_config(cfg)

        # Setting only the name must not invent a run id.
        assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None

    def test_wandb_sets_env(self):
        """Every string `wandb_*` option is exported as an upper-cased env var."""
        cfg = DictDefault(
            {
                "wandb_project": "foo",
                "wandb_name": "bar",
                "wandb_run_id": "bat",
                "wandb_entity": "baz",
                "wandb_mode": "online",
                "wandb_watch": "false",
                "wandb_log_model": "checkpoint",
            }
        )

        validate_config(cfg)

        setup_wandb_env_vars(cfg)

        assert os.environ.get("WANDB_PROJECT", "") == "foo"
        assert os.environ.get("WANDB_NAME", "") == "bar"
        assert os.environ.get("WANDB_RUN_ID", "") == "bat"
        assert os.environ.get("WANDB_ENTITY", "") == "baz"
        assert os.environ.get("WANDB_MODE", "") == "online"
        assert os.environ.get("WANDB_WATCH", "") == "false"
        assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
        assert os.environ.get("WANDB_DISABLED", "") != "true"

    def test_wandb_set_disabled(self):
        """wandb is disabled without a project and enabled once one is set."""
        cfg = DictDefault({})

        validate_config(cfg)

        setup_wandb_env_vars(cfg)

        assert os.environ.get("WANDB_DISABLED", "") == "true"

        cfg = DictDefault(
            {
                "wandb_project": "foo",
            }
        )

        validate_config(cfg)

        setup_wandb_env_vars(cfg)

        assert os.environ.get("WANDB_DISABLED", "") != "true"