Compare commits
10 Commits
58d9896777
...
4a0ab11fcf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4a0ab11fcf | ||
|
|
73b6b0a580 | ||
|
|
9db5072407 | ||
|
|
42d3e36a6f | ||
|
|
b12d93bedf | ||
|
|
08ec9c0e5b | ||
|
|
9abac55f92 | ||
|
|
800e7fa41e | ||
|
|
5a1c1b82d4 | ||
|
|
efb3f70d38 |
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -207,7 +207,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
pytorch: 2.5.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
@@ -253,7 +253,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
|
||||
@@ -13,7 +13,12 @@ class PreprocessCliArgs:
|
||||
debug_num_examples: int = field(default=1)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
download: Optional[bool] = field(default=True)
|
||||
iterable: Optional[bool] = field(default=None, metadata={"help": "Use IterableDataset for streaming processing of large datasets"})
|
||||
iterable: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Use IterableDataset for streaming processing of large datasets"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
|
||||
@@ -63,7 +63,11 @@ def load_datasets(
|
||||
"""
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
||||
preprocess_iterable = hasattr(cli_args, "iterable") and cli_args.iterable is not None and cli_args.iterable
|
||||
preprocess_iterable = (
|
||||
hasattr(cli_args, "iterable")
|
||||
and cli_args.iterable is not None
|
||||
and cli_args.iterable
|
||||
)
|
||||
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||
cfg,
|
||||
|
||||
@@ -13,7 +13,6 @@ from functools import wraps
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import Dataset
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
@@ -76,7 +75,7 @@ class SchedulerMixin(Trainer):
|
||||
Mixin class for scheduler setup in CausalTrainer.
|
||||
"""
|
||||
|
||||
args = None # type: AxolotlTrainingArguments
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
@@ -162,7 +161,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
|
||||
args = None # type: AxolotlTrainingArguments
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
tag_names = ["axolotl"]
|
||||
|
||||
def __init__(
|
||||
@@ -202,12 +201,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
and self.args.embedding_lr is None
|
||||
and self.args.alternate_optimizer
|
||||
not in [
|
||||
"optimi_adamw",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
]
|
||||
"optimi_adamw",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
]
|
||||
):
|
||||
return super().create_optimizer()
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import functools
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Union, Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from datasets import (
|
||||
Dataset,
|
||||
|
||||
@@ -4,10 +4,10 @@ e2e tests for kd trainer support in Axolotl
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from e2e.utils import check_tensorboard
|
||||
from e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -16,14 +16,15 @@ from axolotl.utils.dict import DictDefault
|
||||
@pytest.fixture(name="kd_min_cfg")
|
||||
def min_cfg(temp_dir):
|
||||
return {
|
||||
"base_model": "unsloth/Llama-3.2-1B",
|
||||
"base_model": "osllmai-community/Llama-3.2-1B",
|
||||
"tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
|
||||
"plugins": [
|
||||
"axolotl.integrations.kd.KDPlugin",
|
||||
"axolotl.integrations.liger.LigerPlugin",
|
||||
],
|
||||
"liger_rms_norm": True,
|
||||
"liger_glu_activation": True,
|
||||
"torch_compile": False,
|
||||
"torch_compile": True,
|
||||
"chat_template": "llama3",
|
||||
"kd_trainer": True,
|
||||
"kd_ce_alpha": 0.1,
|
||||
@@ -44,15 +45,15 @@ def min_cfg(temp_dir):
|
||||
},
|
||||
],
|
||||
"val_set_size": 0.0,
|
||||
"sequence_len": 4096,
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"micro_batch_size": 1,
|
||||
"num_epochs": 1,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"learning_rate": 0.0001,
|
||||
"learning_rate": 0.00001,
|
||||
"bf16": "auto",
|
||||
"gradient_checkpointing": True,
|
||||
"flash_attention": True,
|
||||
@@ -62,6 +63,8 @@ def min_cfg(temp_dir):
|
||||
},
|
||||
"max_steps": 5,
|
||||
"output_dir": temp_dir,
|
||||
"save_safetensors": True,
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
|
||||
|
||||
@@ -70,6 +73,9 @@ class TestKnowledgeDistillation:
|
||||
Test case for Knowledge Distillation
|
||||
"""
|
||||
|
||||
# While this will run on torch 2.4.x without torch_compile enabled
|
||||
# the VRAM requirement is higher than what is available in CI
|
||||
@require_torch_2_5_1
|
||||
def test_llama_kd(self, temp_dir, kd_min_cfg):
|
||||
cfg = DictDefault(kd_min_cfg)
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -78,8 +84,38 @@ class TestKnowledgeDistillation:
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"load_in_8bit",
|
||||
[True, False],
|
||||
)
|
||||
def test_llama_lora_kd(self, temp_dir, kd_min_cfg, load_in_8bit):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"load_in_8bit": load_in_8bit,
|
||||
"torch_compile": False,
|
||||
"adapter": "lora",
|
||||
"peft_use_dora": True,
|
||||
"lora_target_linear": True,
|
||||
"lora_r": 16,
|
||||
"lora_alpha": 32,
|
||||
"lora_dropout": 0.0,
|
||||
}
|
||||
| kd_min_cfg
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
)
|
||||
Reference in New Issue
Block a user