Compare commits
7 Commits
shampoo-lo
...
upgrade_li
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d1bf20f990 | ||
|
|
bb648cbc63 | ||
|
|
8b0bca4842 | ||
|
|
d36baf44b1 | ||
|
|
16c8140d20 | ||
|
|
21c25cf7bc | ||
|
|
32288a5d3c |
@@ -121,7 +121,7 @@ Features:
|
|||||||
|
|
||||||
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
|
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
|
||||||
|
|
||||||
**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
|
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl
|
git clone https://github.com/axolotl-ai-cloud/axolotl
|
||||||
@@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
- typescript
|
- typescript
|
||||||
type: ... # unimplemented custom format
|
type: ... # unimplemented custom format
|
||||||
|
|
||||||
# fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
|
# fastchat conversation (deprecation soon, use chat_template)
|
||||||
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
- path: ...
|
- path: ...
|
||||||
type: sharegpt
|
type: sharegpt
|
||||||
|
|||||||
@@ -35,7 +35,3 @@ RUN git lfs install --skip-repo && \
|
|||||||
pip3 install awscli && \
|
pip3 install awscli && \
|
||||||
# The base image ships with `pydantic==1.8.2` which is not working
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||||
|
|
||||||
RUN if [ "$PYTHON_VERSION" != "2.5.1" ] ; then \
|
|
||||||
pip3 install flash-attn==2.6.3; \
|
|
||||||
fi
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ strict: false
|
|||||||
plugins:
|
plugins:
|
||||||
- axolotl.integrations.liger.LigerPlugin
|
- axolotl.integrations.liger.LigerPlugin
|
||||||
liger_rms_norm: true
|
liger_rms_norm: true
|
||||||
liger_glu_activation: true
|
liger_swiglu: true
|
||||||
liger_fused_linear_cross_entropy: true
|
liger_fused_linear_cross_entropy: true
|
||||||
|
|
||||||
chat_template: deepseek_v2
|
chat_template: deepseek_v2
|
||||||
|
|||||||
@@ -9,27 +9,21 @@ liger_fused_linear_cross_entropy: true
|
|||||||
|
|
||||||
strict: false
|
strict: false
|
||||||
|
|
||||||
chat_template: llama3
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: mlabonne/FineTome-100k
|
- path: tatsu-lab/alpaca
|
||||||
type: chat_template
|
type: alpaca
|
||||||
split: train[:20%]
|
|
||||||
field_messages: conversations
|
|
||||||
message_field_role: from
|
|
||||||
message_field_content: value
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.02
|
val_set_size: 0
|
||||||
output_dir: ./outputs/out
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
wandb_project:
|
wandb_project: check_liger_hf_GA_llama_fix-3
|
||||||
wandb_entity:
|
wandb_entity: axolotl-ai
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name: pr/fix333-tr4.46.1
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ peft==0.13.2
|
|||||||
transformers==4.46.1
|
transformers==4.46.1
|
||||||
tokenizers>=0.20.1
|
tokenizers>=0.20.1
|
||||||
bitsandbytes==0.44.1
|
bitsandbytes==0.44.1
|
||||||
accelerate==1.1.0
|
accelerate==1.0.1
|
||||||
datasets==3.0.1
|
datasets==3.0.1
|
||||||
deepspeed==0.15.3
|
deepspeed==0.15.3
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
@@ -33,8 +33,8 @@ gradio==3.50.2
|
|||||||
tensorboard
|
tensorboard
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
autoawq>=0.2.5
|
autoawq>=0.2.5
|
||||||
triton>=2.3.0
|
triton>=3.1.0
|
||||||
liger-kernel==0.4.0
|
liger-kernel==0.3.1
|
||||||
|
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
|
|
||||||
|
|||||||
@@ -896,13 +896,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
for key, value in metrics.items():
|
for key, value in metrics.items():
|
||||||
self._stored_metrics[train_eval][key].append(value)
|
self._stored_metrics[train_eval][key].append(value)
|
||||||
|
|
||||||
def _save_checkpoint(self, model, trial, **kwargs):
|
def _save_checkpoint(self, model, trial, metrics=None):
|
||||||
# make sure the checkpoint dir exists, since trainer is flakey
|
# make sure the checkpoint dir exists, since trainer is flakey
|
||||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||||
run_dir = self._get_output_dir(trial=trial)
|
run_dir = self._get_output_dir(trial=trial)
|
||||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, metrics=metrics)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
|
|||||||
@@ -1,250 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
from torch.distributed._tensor import DTensor
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
|
|
||||||
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
|
|
||||||
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
|
|
||||||
|
|
||||||
|
|
||||||
class _ShampooBase(Optimizer):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
params,
|
|
||||||
lr=1e-1,
|
|
||||||
momentum=0.0,
|
|
||||||
weight_decay=0.0,
|
|
||||||
eps=1e-4,
|
|
||||||
update_freq=1,
|
|
||||||
*,
|
|
||||||
block_size,
|
|
||||||
quantization_bits,
|
|
||||||
optimizer_state_class,
|
|
||||||
):
|
|
||||||
if lr <= 0.0:
|
|
||||||
raise ValueError(f"Invalid learning rate: {lr}")
|
|
||||||
if momentum < 0.0:
|
|
||||||
raise ValueError(f"Invalid momentum value: {momentum}")
|
|
||||||
if weight_decay < 0.0:
|
|
||||||
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
|
||||||
if eps < 0.0:
|
|
||||||
raise ValueError(f"Invalid eps value: {eps}")
|
|
||||||
if update_freq < 1:
|
|
||||||
raise ValueError(f"Invalid update_freq value: {update_freq}")
|
|
||||||
|
|
||||||
defaults = dict(
|
|
||||||
lr=lr,
|
|
||||||
momentum=momentum,
|
|
||||||
weight_decay=weight_decay,
|
|
||||||
eps=eps,
|
|
||||||
update_freq=update_freq,
|
|
||||||
)
|
|
||||||
super().__init__(params, defaults)
|
|
||||||
self.block_size = block_size
|
|
||||||
self.quantization_bits = quantization_bits
|
|
||||||
self.optimizer_state_class = optimizer_state_class
|
|
||||||
|
|
||||||
def step(self, closure: Optional[callable] = None) -> Optional[float]:
|
|
||||||
loss = None
|
|
||||||
if closure is not None:
|
|
||||||
loss = closure()
|
|
||||||
|
|
||||||
for group in self.param_groups:
|
|
||||||
for p in group["params"]:
|
|
||||||
if p.grad is None:
|
|
||||||
continue
|
|
||||||
grad = p.grad.data
|
|
||||||
state = self.state[p]
|
|
||||||
|
|
||||||
# State initialization
|
|
||||||
if len(state) == 0:
|
|
||||||
state["step"] = 0
|
|
||||||
state["momentum_buffer"] = self._new_buffer(grad, True)
|
|
||||||
state["preconds"] = []
|
|
||||||
state["inv_preconds"] = []
|
|
||||||
for dim in grad.size():
|
|
||||||
state["preconds"].append(
|
|
||||||
self.optimizer_state_class.zeros(
|
|
||||||
(dim, dim),
|
|
||||||
signed=False,
|
|
||||||
block_size=self.block_size,
|
|
||||||
device=grad.device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
state["inv_preconds"].append(
|
|
||||||
torch.zeros((dim, dim), device=grad.device)
|
|
||||||
)
|
|
||||||
|
|
||||||
state["step"] += 1
|
|
||||||
beta = group["momentum"]
|
|
||||||
weight_decay = group["weight_decay"]
|
|
||||||
lr = group["lr"]
|
|
||||||
eps = group["eps"]
|
|
||||||
update_freq = group["update_freq"]
|
|
||||||
|
|
||||||
# Apply momentum
|
|
||||||
if beta > 0:
|
|
||||||
state["momentum_buffer"].mul_(beta).add_(grad, alpha=1 - beta)
|
|
||||||
grad = state["momentum_buffer"]
|
|
||||||
|
|
||||||
# Apply weight decay
|
|
||||||
if weight_decay > 0:
|
|
||||||
grad = grad.add(p.data, alpha=weight_decay)
|
|
||||||
|
|
||||||
# Preconditioning
|
|
||||||
order = grad.ndimension()
|
|
||||||
original_size = grad.size()
|
|
||||||
for dim_id, dim in enumerate(grad.size()):
|
|
||||||
precond = state["preconds"][dim_id]
|
|
||||||
inv_precond = state["inv_preconds"][dim_id]
|
|
||||||
|
|
||||||
# Reshape grad
|
|
||||||
grad = grad.transpose(0, dim_id).contiguous()
|
|
||||||
transposed_size = grad.size()
|
|
||||||
grad = grad.view(dim, -1)
|
|
||||||
|
|
||||||
grad_t = grad.t()
|
|
||||||
|
|
||||||
# Update preconditioner
|
|
||||||
precond_fp32 = precond.dequantize()
|
|
||||||
precond_update = grad @ grad_t
|
|
||||||
precond_fp32.add_(precond_update)
|
|
||||||
|
|
||||||
# Quantize preconditioner back
|
|
||||||
precond.copy_(precond_fp32)
|
|
||||||
|
|
||||||
# Update inverse preconditioner
|
|
||||||
if state["step"] % update_freq == 0:
|
|
||||||
inv_precond.copy_(
|
|
||||||
self._compute_inv_precond(precond_fp32, eps, order)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Precondition grad
|
|
||||||
if dim_id == order - 1:
|
|
||||||
# Last dimension
|
|
||||||
grad = grad_t @ inv_precond
|
|
||||||
grad = grad.view(original_size)
|
|
||||||
else:
|
|
||||||
grad = inv_precond @ grad
|
|
||||||
grad = grad.view(transposed_size)
|
|
||||||
|
|
||||||
# Update parameter
|
|
||||||
p.data.add_(grad, alpha=-lr)
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def _compute_inv_precond(self, precond: Tensor, eps: float, order: int):
|
|
||||||
# Add eps for numerical stability
|
|
||||||
precond = precond + torch.eye(precond.size(0), device=precond.device) * eps
|
|
||||||
|
|
||||||
# Compute matrix power
|
|
||||||
inv_precond = self._matrix_power(precond, -1.0 / (2 * order))
|
|
||||||
|
|
||||||
return inv_precond
|
|
||||||
|
|
||||||
def _matrix_power(self, matrix: Tensor, power: float) -> Tensor:
|
|
||||||
# Compute matrix power using SVD
|
|
||||||
u, s, v = torch.svd(matrix)
|
|
||||||
s_pow = s.pow(power)
|
|
||||||
return u @ torch.diag(s_pow) @ v.t()
|
|
||||||
|
|
||||||
# bring your own function to create zero-filled subclass
|
|
||||||
@staticmethod
|
|
||||||
def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
# follow bitsandbytes, only quantize tensors >= 4096 values
|
|
||||||
# also wrap subclass in DTensor when needed
|
|
||||||
def _new_buffer(self, p: Tensor, signed: bool):
|
|
||||||
if p.numel() >= 4096 and p.numel() % self.block_size == 0:
|
|
||||||
if isinstance(p, DTensor):
|
|
||||||
out = DTensor.from_local(
|
|
||||||
local_tensor=self._subclass_zeros(
|
|
||||||
p.to_local(), signed, self.block_size
|
|
||||||
),
|
|
||||||
device_mesh=p.device_mesh,
|
|
||||||
placements=p.placements,
|
|
||||||
run_check=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
out = self._subclass_zeros(p, signed, self.block_size)
|
|
||||||
else:
|
|
||||||
out = torch.zeros_like(p)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class Shampoo8bit(_ShampooBase):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
params,
|
|
||||||
lr=1e-1,
|
|
||||||
momentum=0.0,
|
|
||||||
weight_decay=0.0,
|
|
||||||
eps=1e-4,
|
|
||||||
update_freq=1,
|
|
||||||
*,
|
|
||||||
block_size=256,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
params,
|
|
||||||
lr,
|
|
||||||
momentum,
|
|
||||||
weight_decay,
|
|
||||||
eps,
|
|
||||||
update_freq,
|
|
||||||
block_size=block_size,
|
|
||||||
quantization_bits=8,
|
|
||||||
optimizer_state_class=OptimState8bit,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Shampoo4bit(_ShampooBase):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
params,
|
|
||||||
lr=1e-1,
|
|
||||||
momentum=0.0,
|
|
||||||
weight_decay=0.0,
|
|
||||||
eps=1e-4,
|
|
||||||
update_freq=1,
|
|
||||||
*,
|
|
||||||
block_size=128,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
params,
|
|
||||||
lr,
|
|
||||||
momentum,
|
|
||||||
weight_decay,
|
|
||||||
eps,
|
|
||||||
update_freq,
|
|
||||||
block_size=block_size,
|
|
||||||
quantization_bits=4,
|
|
||||||
optimizer_state_class=OptimState4bit,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ShampooFp8(_ShampooBase):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
params,
|
|
||||||
lr=1e-1,
|
|
||||||
momentum=0.0,
|
|
||||||
weight_decay=0.0,
|
|
||||||
eps=1e-4,
|
|
||||||
update_freq=1,
|
|
||||||
*,
|
|
||||||
block_size=256,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
params,
|
|
||||||
lr,
|
|
||||||
momentum,
|
|
||||||
weight_decay,
|
|
||||||
eps,
|
|
||||||
update_freq,
|
|
||||||
block_size=block_size,
|
|
||||||
quantization_bits=8, # FP8 uses 8 bits
|
|
||||||
optimizer_state_class=OptimStateFp8,
|
|
||||||
)
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Simple end-to-end test for Liger integration
|
Simple end-to-end test for Liger integration
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|||||||
@@ -1,80 +0,0 @@
|
|||||||
"""
|
|
||||||
config validation tests for swiglu args
|
|
||||||
"""
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="minimal_base_cfg")
|
|
||||||
def fixture_cfg():
|
|
||||||
return DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
|
||||||
"learning_rate": 0.000001,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseValidation:
|
|
||||||
"""
|
|
||||||
Base validation module to setup the log capture
|
|
||||||
"""
|
|
||||||
|
|
||||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def inject_fixtures(self, caplog):
|
|
||||||
self._caplog = caplog
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-public-methods
|
|
||||||
class TestValidation(BaseValidation):
|
|
||||||
"""
|
|
||||||
Test the validation module for liger
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_deprecated_swiglu(self, minimal_cfg):
|
|
||||||
test_cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"liger_swiglu": False,
|
|
||||||
}
|
|
||||||
| minimal_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
updated_cfg = validate_config(test_cfg)
|
|
||||||
assert (
|
|
||||||
"The 'liger_swiglu' argument is deprecated"
|
|
||||||
in self._caplog.records[0].message
|
|
||||||
)
|
|
||||||
assert updated_cfg.liger_swiglu is None
|
|
||||||
assert updated_cfg.liger_glu_activations is False
|
|
||||||
|
|
||||||
def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
|
|
||||||
test_cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"liger_swiglu": False,
|
|
||||||
"liger_glu_activations": True,
|
|
||||||
}
|
|
||||||
| minimal_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
|
|
||||||
):
|
|
||||||
validate_config(test_cfg)
|
|
||||||
@@ -306,10 +306,6 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
"""Verify that processing data from the hub works with a specific revision"""
|
"""Verify that processing data from the hub works with a specific revision"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
|
|
||||||
# make sure prepared_path is empty
|
|
||||||
shutil.rmtree(prepared_path, ignore_errors=True)
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
|||||||
Reference in New Issue
Block a user