Compare commits

..

20 Commits

Author SHA1 Message Date
sunny
bbf5158e9c test 2024-11-07 11:06:28 -05:00
sunny
ec70046a2b test 2024-11-07 11:04:33 -05:00
sunny
7fed41550e test 2024-11-07 11:02:54 -05:00
sunny
da3a941bc3 test 2024-11-07 11:00:51 -05:00
sunny
ad3c179a5a test 2024-11-07 10:59:29 -05:00
sunny
15e26b14eb test 2024-11-07 10:54:48 -05:00
sunny
33bbe9b222 test 2024-11-07 10:52:52 -05:00
sunny
1fddf45958 test 2024-11-07 10:46:47 -05:00
Wing Lian
e42e319446 make sure prepared path is empty for test 2024-11-06 10:20:51 -05:00
Wing Lian
613f238e56 use kwargs to support patch release 2024-11-06 09:43:35 -05:00
Wing Lian
6b617a4fd5 also upgrade accelerate 2024-11-06 08:59:52 -05:00
Wing Lian
6ac10de9ef upgrade liger and transformers 2024-11-06 08:53:03 -05:00
Wing Lian
1b8d439441 add test case 2024-11-05 09:23:08 +07:00
Wing Lian
1ed351781a chore: lint 2024-11-05 09:23:08 +07:00
Wing Lian
c2a48c3a1e add logging 2024-11-05 09:23:08 +07:00
Wing Lian
415399b565 Update README.md
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-11-05 09:23:08 +07:00
Wing Lian
67c04133f2 Update src/axolotl/integrations/liger/args.py
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-11-05 09:23:08 +07:00
Wing Lian
4911d0952f skip duplicate code check 2024-11-05 09:23:08 +07:00
Wing Lian
1d7ab52161 update docs and example 2024-11-05 09:23:08 +07:00
Wing Lian
fcdc6fee8b upgrade liger to 0.3.1 2024-11-05 09:23:08 +07:00
10 changed files with 226 additions and 16 deletions

View File

@@ -121,7 +121,7 @@ Features:
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl
@@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# fastchat conversation (deprecation soon, use chat_template)
# fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- path: ...
type: sharegpt

View File

@@ -9,7 +9,7 @@ strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rms_norm: true
liger_swiglu: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
chat_template: deepseek_v2

View File

@@ -9,21 +9,27 @@ liger_fused_linear_cross_entropy: true
strict: false
chat_template: llama3
datasets:
- path: tatsu-lab/alpaca
type: alpaca
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0
val_set_size: 0.02
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project: check_liger_hf_GA_llama_fix-3
wandb_entity: axolotl-ai
wandb_project:
wandb_entity:
wandb_watch:
wandb_name: pr/fix333-tr4.46.1
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.2
transformers==4.46.1
transformers==4.46.2
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.0.1
accelerate==1.1.0
datasets==3.0.1
deepspeed==0.15.3
pydantic==2.6.3
@@ -33,8 +33,8 @@ gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
triton>=3.1.0
liger-kernel==0.3.1
triton>=2.3.0
liger-kernel==0.4.0
mamba-ssm==1.2.0.post1

View File

@@ -896,13 +896,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, metrics=None):
def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, metrics=metrics)
return super()._save_checkpoint(model, trial, **kwargs)
class AxolotlMambaTrainer(AxolotlTrainer):

76
test.yml Normal file
View File

@@ -0,0 +1,76 @@
base_model: JackFram/llama-68m
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.5
output_dir: ./outputs/out
sequence_len: 1024
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_backward_prefetch: BACKWARD_PRE
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>

View File

@@ -1,7 +1,6 @@
"""
Simple end-to-end test for Liger integration
"""
import unittest
from pathlib import Path
@@ -64,6 +63,51 @@ class LigerIntegrationTestCase(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
@with_temp_dir
def test_llama_wo_flce2(self, temp_dir):
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"plugins": [
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rope": True,
"liger_rms_norm": True,
"liger_swiglu": True,
"liger_cross_entropy": True,
"liger_fused_linear_cross_entropy": False,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
@with_temp_dir
def test_llama_w_flce(self, temp_dir):
cfg = DictDefault(

View File

View File

@@ -0,0 +1,80 @@
"""
config validation tests for swiglu args
"""
# pylint: disable=duplicate-code
import logging
from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="minimal_base_cfg")
def fixture_cfg():
return DictDefault(
{
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)
class BaseValidation:
"""
Base validation module to setup the log capture
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
# pylint: disable=too-many-public-methods
class TestValidation(BaseValidation):
"""
Test the validation module for liger
"""
def test_deprecated_swiglu(self, minimal_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
}
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
updated_cfg = validate_config(test_cfg)
assert (
"The 'liger_swiglu' argument is deprecated"
in self._caplog.records[0].message
)
assert updated_cfg.liger_swiglu is None
assert updated_cfg.liger_glu_activations is False
def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
"liger_glu_activations": True,
}
| minimal_cfg
)
with pytest.raises(
ValueError,
match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
):
validate_config(test_cfg)

View File

@@ -306,6 +306,10 @@ class TestDatasetPreparation(unittest.TestCase):
"""Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
# make sure prepared_path is empty
shutil.rmtree(prepared_path, ignore_errors=True)
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",