Compare commits

...

10 Commits

Author SHA1 Message Date
Wing Lian
4a0ab11fcf chore: lint 2025-01-13 14:05:56 -05:00
Wing Lian
73b6b0a580 chore: lint 2025-01-13 13:56:16 -05:00
Wing Lian
9db5072407 make sure to use tensorboard to capture loss for checks 2025-01-13 13:56:16 -05:00
Wing Lian
42d3e36a6f fix adapter model check 2025-01-13 13:56:15 -05:00
Wing Lian
b12d93bedf make sure to use the correct tokenizer 2025-01-13 13:56:15 -05:00
Wing Lian
08ec9c0e5b make sure to set tokenizer from l3 70b and save safetensors 2025-01-13 13:56:15 -05:00
Wing Lian
9abac55f92 lower lr 2025-01-13 13:56:15 -05:00
Wing Lian
800e7fa41e set lora_dropout explicitly 2025-01-13 13:56:15 -05:00
Wing Lian
5a1c1b82d4 make the kd e2e fit in vram for ci and add lora version 2025-01-13 13:56:15 -05:00
Wing Lian
efb3f70d38 rename test files so it gets picked up 2025-01-13 13:56:15 -05:00
7 changed files with 68 additions and 24 deletions

View File

@@ -207,7 +207,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.5.1
num_gpus: 1
axolotl_extras:
steps:
@@ -253,7 +253,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
steps:

View File

@@ -13,7 +13,12 @@ class PreprocessCliArgs:
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(default=None, metadata={"help": "Use IterableDataset for streaming processing of large datasets"})
iterable: Optional[bool] = field(
default=None,
metadata={
"help": "Use IterableDataset for streaming processing of large datasets"
},
)
@dataclass

View File

@@ -3,7 +3,7 @@
import logging
import warnings
from pathlib import Path
from typing import Optional, Union
from typing import Union
import fire
import transformers

View File

@@ -63,7 +63,11 @@ def load_datasets(
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = hasattr(cli_args, "iterable") and cli_args.iterable is not None and cli_args.iterable
preprocess_iterable = (
hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,

View File

@@ -13,7 +13,6 @@ from functools import wraps
from typing import Any, Dict, Literal, Optional, Union
import torch
import transformers
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
@@ -76,7 +75,7 @@ class SchedulerMixin(Trainer):
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: AxolotlTrainingArguments
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
@@ -162,7 +161,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def __init__(
@@ -202,12 +201,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
and self.args.embedding_lr is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
):
return super().create_optimizer()

View File

@@ -3,7 +3,7 @@
import functools
import logging
from pathlib import Path
from typing import List, Tuple, Union, Optional
from typing import List, Optional, Tuple, Union
from datasets import (
Dataset,

View File

@@ -4,10 +4,10 @@ e2e tests for kd trainer support in Axolotl
from pathlib import Path
import pytest
from e2e.utils import check_tensorboard
from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
@@ -16,14 +16,15 @@ from axolotl.utils.dict import DictDefault
@pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir):
return {
"base_model": "unsloth/Llama-3.2-1B",
"base_model": "osllmai-community/Llama-3.2-1B",
"tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
"plugins": [
"axolotl.integrations.kd.KDPlugin",
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rms_norm": True,
"liger_glu_activation": True,
"torch_compile": False,
"torch_compile": True,
"chat_template": "llama3",
"kd_trainer": True,
"kd_ce_alpha": 0.1,
@@ -44,15 +45,15 @@ def min_cfg(temp_dir):
},
],
"val_set_size": 0.0,
"sequence_len": 4096,
"sequence_len": 2048,
"sample_packing": True,
"pad_to_sequence_len": True,
"gradient_accumulation_steps": 2,
"micro_batch_size": 2,
"micro_batch_size": 1,
"num_epochs": 1,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"learning_rate": 0.0001,
"learning_rate": 0.00001,
"bf16": "auto",
"gradient_checkpointing": True,
"flash_attention": True,
@@ -62,6 +63,8 @@ def min_cfg(temp_dir):
},
"max_steps": 5,
"output_dir": temp_dir,
"save_safetensors": True,
"use_tensorboard": True,
}
@@ -70,6 +73,9 @@ class TestKnowledgeDistillation:
Test case for Knowledge Distillation
"""
# While this will run on torch 2.4.x without torch_compile enabled
# the VRAM requirement is higher than what is available in CI
@require_torch_2_5_1
def test_llama_kd(self, temp_dir, kd_min_cfg):
cfg = DictDefault(kd_min_cfg)
# pylint: disable=duplicate-code
@@ -78,8 +84,38 @@ class TestKnowledgeDistillation:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)
@pytest.mark.parametrize(
"load_in_8bit",
[True, False],
)
def test_llama_lora_kd(self, temp_dir, kd_min_cfg, load_in_8bit):
cfg = DictDefault(
{
"load_in_8bit": load_in_8bit,
"torch_compile": False,
"adapter": "lora",
"peft_use_dora": True,
"lora_target_linear": True,
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.0,
}
| kd_min_cfg
)
# pylint: disable=duplicate-code
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)