Compare commits
20 Commits
shampoo-lo
...
bbf5158e9c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bbf5158e9c | ||
|
|
ec70046a2b | ||
|
|
7fed41550e | ||
|
|
da3a941bc3 | ||
|
|
ad3c179a5a | ||
|
|
15e26b14eb | ||
|
|
33bbe9b222 | ||
|
|
1fddf45958 | ||
|
|
e42e319446 | ||
|
|
613f238e56 | ||
|
|
6b617a4fd5 | ||
|
|
6ac10de9ef | ||
|
|
1b8d439441 | ||
|
|
1ed351781a | ||
|
|
c2a48c3a1e | ||
|
|
415399b565 | ||
|
|
67c04133f2 | ||
|
|
4911d0952f | ||
|
|
1d7ab52161 | ||
|
|
fcdc6fee8b |
@@ -35,7 +35,3 @@ RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
RUN if [ "$PYTHON_VERSION" != "2.5.1" ] ; then \
|
||||
pip3 install flash-attn==2.6.3; \
|
||||
fi
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
packaging==23.2
|
||||
peft==0.13.2
|
||||
transformers==4.46.1
|
||||
transformers==4.46.2
|
||||
tokenizers>=0.20.1
|
||||
bitsandbytes==0.44.1
|
||||
accelerate==1.1.0
|
||||
|
||||
@@ -1,250 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributed._tensor import DTensor
|
||||
from torch.optim import Optimizer
|
||||
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
|
||||
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
|
||||
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
|
||||
|
||||
|
||||
class _ShampooBase(Optimizer):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-1,
|
||||
momentum=0.0,
|
||||
weight_decay=0.0,
|
||||
eps=1e-4,
|
||||
update_freq=1,
|
||||
*,
|
||||
block_size,
|
||||
quantization_bits,
|
||||
optimizer_state_class,
|
||||
):
|
||||
if lr <= 0.0:
|
||||
raise ValueError(f"Invalid learning rate: {lr}")
|
||||
if momentum < 0.0:
|
||||
raise ValueError(f"Invalid momentum value: {momentum}")
|
||||
if weight_decay < 0.0:
|
||||
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||
if eps < 0.0:
|
||||
raise ValueError(f"Invalid eps value: {eps}")
|
||||
if update_freq < 1:
|
||||
raise ValueError(f"Invalid update_freq value: {update_freq}")
|
||||
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
momentum=momentum,
|
||||
weight_decay=weight_decay,
|
||||
eps=eps,
|
||||
update_freq=update_freq,
|
||||
)
|
||||
super().__init__(params, defaults)
|
||||
self.block_size = block_size
|
||||
self.quantization_bits = quantization_bits
|
||||
self.optimizer_state_class = optimizer_state_class
|
||||
|
||||
def step(self, closure: Optional[callable] = None) -> Optional[float]:
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state["step"] = 0
|
||||
state["momentum_buffer"] = self._new_buffer(grad, True)
|
||||
state["preconds"] = []
|
||||
state["inv_preconds"] = []
|
||||
for dim in grad.size():
|
||||
state["preconds"].append(
|
||||
self.optimizer_state_class.zeros(
|
||||
(dim, dim),
|
||||
signed=False,
|
||||
block_size=self.block_size,
|
||||
device=grad.device,
|
||||
)
|
||||
)
|
||||
state["inv_preconds"].append(
|
||||
torch.zeros((dim, dim), device=grad.device)
|
||||
)
|
||||
|
||||
state["step"] += 1
|
||||
beta = group["momentum"]
|
||||
weight_decay = group["weight_decay"]
|
||||
lr = group["lr"]
|
||||
eps = group["eps"]
|
||||
update_freq = group["update_freq"]
|
||||
|
||||
# Apply momentum
|
||||
if beta > 0:
|
||||
state["momentum_buffer"].mul_(beta).add_(grad, alpha=1 - beta)
|
||||
grad = state["momentum_buffer"]
|
||||
|
||||
# Apply weight decay
|
||||
if weight_decay > 0:
|
||||
grad = grad.add(p.data, alpha=weight_decay)
|
||||
|
||||
# Preconditioning
|
||||
order = grad.ndimension()
|
||||
original_size = grad.size()
|
||||
for dim_id, dim in enumerate(grad.size()):
|
||||
precond = state["preconds"][dim_id]
|
||||
inv_precond = state["inv_preconds"][dim_id]
|
||||
|
||||
# Reshape grad
|
||||
grad = grad.transpose(0, dim_id).contiguous()
|
||||
transposed_size = grad.size()
|
||||
grad = grad.view(dim, -1)
|
||||
|
||||
grad_t = grad.t()
|
||||
|
||||
# Update preconditioner
|
||||
precond_fp32 = precond.dequantize()
|
||||
precond_update = grad @ grad_t
|
||||
precond_fp32.add_(precond_update)
|
||||
|
||||
# Quantize preconditioner back
|
||||
precond.copy_(precond_fp32)
|
||||
|
||||
# Update inverse preconditioner
|
||||
if state["step"] % update_freq == 0:
|
||||
inv_precond.copy_(
|
||||
self._compute_inv_precond(precond_fp32, eps, order)
|
||||
)
|
||||
|
||||
# Precondition grad
|
||||
if dim_id == order - 1:
|
||||
# Last dimension
|
||||
grad = grad_t @ inv_precond
|
||||
grad = grad.view(original_size)
|
||||
else:
|
||||
grad = inv_precond @ grad
|
||||
grad = grad.view(transposed_size)
|
||||
|
||||
# Update parameter
|
||||
p.data.add_(grad, alpha=-lr)
|
||||
|
||||
return loss
|
||||
|
||||
def _compute_inv_precond(self, precond: Tensor, eps: float, order: int):
|
||||
# Add eps for numerical stability
|
||||
precond = precond + torch.eye(precond.size(0), device=precond.device) * eps
|
||||
|
||||
# Compute matrix power
|
||||
inv_precond = self._matrix_power(precond, -1.0 / (2 * order))
|
||||
|
||||
return inv_precond
|
||||
|
||||
def _matrix_power(self, matrix: Tensor, power: float) -> Tensor:
|
||||
# Compute matrix power using SVD
|
||||
u, s, v = torch.svd(matrix)
|
||||
s_pow = s.pow(power)
|
||||
return u @ torch.diag(s_pow) @ v.t()
|
||||
|
||||
# bring your own function to create zero-filled subclass
|
||||
@staticmethod
|
||||
def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
|
||||
raise NotImplementedError
|
||||
|
||||
# follow bitsandbytes, only quantize tensors >= 4096 values
|
||||
# also wrap subclass in DTensor when needed
|
||||
def _new_buffer(self, p: Tensor, signed: bool):
|
||||
if p.numel() >= 4096 and p.numel() % self.block_size == 0:
|
||||
if isinstance(p, DTensor):
|
||||
out = DTensor.from_local(
|
||||
local_tensor=self._subclass_zeros(
|
||||
p.to_local(), signed, self.block_size
|
||||
),
|
||||
device_mesh=p.device_mesh,
|
||||
placements=p.placements,
|
||||
run_check=False,
|
||||
)
|
||||
else:
|
||||
out = self._subclass_zeros(p, signed, self.block_size)
|
||||
else:
|
||||
out = torch.zeros_like(p)
|
||||
return out
|
||||
|
||||
|
||||
class Shampoo8bit(_ShampooBase):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-1,
|
||||
momentum=0.0,
|
||||
weight_decay=0.0,
|
||||
eps=1e-4,
|
||||
update_freq=1,
|
||||
*,
|
||||
block_size=256,
|
||||
):
|
||||
super().__init__(
|
||||
params,
|
||||
lr,
|
||||
momentum,
|
||||
weight_decay,
|
||||
eps,
|
||||
update_freq,
|
||||
block_size=block_size,
|
||||
quantization_bits=8,
|
||||
optimizer_state_class=OptimState8bit,
|
||||
)
|
||||
|
||||
|
||||
class Shampoo4bit(_ShampooBase):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-1,
|
||||
momentum=0.0,
|
||||
weight_decay=0.0,
|
||||
eps=1e-4,
|
||||
update_freq=1,
|
||||
*,
|
||||
block_size=128,
|
||||
):
|
||||
super().__init__(
|
||||
params,
|
||||
lr,
|
||||
momentum,
|
||||
weight_decay,
|
||||
eps,
|
||||
update_freq,
|
||||
block_size=block_size,
|
||||
quantization_bits=4,
|
||||
optimizer_state_class=OptimState4bit,
|
||||
)
|
||||
|
||||
|
||||
class ShampooFp8(_ShampooBase):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-1,
|
||||
momentum=0.0,
|
||||
weight_decay=0.0,
|
||||
eps=1e-4,
|
||||
update_freq=1,
|
||||
*,
|
||||
block_size=256,
|
||||
):
|
||||
super().__init__(
|
||||
params,
|
||||
lr,
|
||||
momentum,
|
||||
weight_decay,
|
||||
eps,
|
||||
update_freq,
|
||||
block_size=block_size,
|
||||
quantization_bits=8, # FP8 uses 8 bits
|
||||
optimizer_state_class=OptimStateFp8,
|
||||
)
|
||||
76
test.yml
Normal file
76
test.yml
Normal file
@@ -0,0 +1,76 @@
|
||||
base_model: JackFram/llama-68m
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.5
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 1024
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 2
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_limit_all_gathers: true
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_offload_params: true
|
||||
fsdp_use_orig_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_backward_prefetch: BACKWARD_PRE
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
eos_token: <|eot_id|>
|
||||
@@ -63,6 +63,51 @@ class LigerIntegrationTestCase(unittest.TestCase):
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_llama_wo_flce2(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"plugins": [
|
||||
"axolotl.integrations.liger.LigerPlugin",
|
||||
],
|
||||
"liger_rope": True,
|
||||
"liger_rms_norm": True,
|
||||
"liger_swiglu": True,
|
||||
"liger_cross_entropy": True,
|
||||
"liger_fused_linear_cross_entropy": False,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_llama_w_flce(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
|
||||
Reference in New Issue
Block a user