Compare commits

..

9 Commits

Author SHA1 Message Date
Wing Lian
4c92b51cd5 fix the torch dtype check 2024-04-11 08:56:46 -04:00
Wing Lian
5767eea874 add tests for merging lora and validating the dtype 2024-04-10 13:00:37 -04:00
Thomas Capelle
5ed29393e3 Update SaveAxolotlConfigtoWandBCallback to use artifact instead of save (#1483)
* deprecated wandb.save

* also use wandb.save for axolotl yaml

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-04-09 18:58:38 -04:00
Wing Lian
da9b1a3196 use locale-agnostic separator to make large nums easier to read (#1503) 2024-04-09 17:28:43 -04:00
DavidFarago
057fa44191 WIP: Support table logging for mlflow, too (#1506)
* WIP: Support table logging for mlflow, too

Create a `LogPredictionCallback` for both "wandb" and "mlflow" if
specified.

In `log_prediction_callback_factory`, build a generic table and make it
logger-specific only when the newly added `logger` argument is set to
"wandb" or "mlflow".

See https://github.com/OpenAccess-AI-Collective/axolotl/issues/1505

* chore: lint

* add additional clause for mlflow as it's optional

* Fix circular imports

---------

Co-authored-by: Dave Farago <dfarago@innoopract.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-04-09 17:28:27 -04:00
Scott Fleming
8fa0785f74 Correctly handle splits for datasets.arrow_dataset.Dataset objects (#1504)
* Correctly handle splits for datasets.arrow_dataset.Dataset objects

The `load_tokenized_prepared_datasets` function currently has logic for loading a dataset from a local path that always checks whether a split is in the dataset. The problem is that if the dataset is loaded using `load_from_disk` and is an Arrow-based dataset, *there is no* split information. Instead, calling `split in ds` presumably searches through all the rows and columns of the Arrow dataset object to find e.g. 'train' (assuming `split == 'train'`). This causes the program to hang.

See https://chat.openai.com/share/0d567dbd-d60b-4079-9040-e1de58a4dff3 for context.

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-04-09 16:40:26 -04:00
Wing Lian
4313b1a6a0 Print versions (#1496)
* print out dependency versions for easier debugging

* improve readability
2024-04-09 11:05:15 -04:00
Maziyar Panahi
7f17eff81a Fix the wrong adapter in qwen2-moe-qlora example (#1501) [skip ci]
It should be `qlora` instead of `lora`
2024-04-09 10:57:24 -04:00
Wing Lian
ff01c45127 add field to sft dataset pydantic for completion support (#1497) 2024-04-08 21:37:54 -04:00
11 changed files with 170 additions and 55 deletions

View File

@@ -16,7 +16,7 @@ sequence_len: 1024 # supports up to 32k
 sample_packing: false
 pad_to_sequence_len: false
-adapter: lora
+adapter: qlora
 lora_model_dir:
 lora_r: 32
 lora_alpha: 16

View File

@@ -24,6 +24,7 @@ from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
 from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
 from transformers.utils import is_torch_bf16_gpu_available
+from transformers.utils.import_utils import _is_package_available

 from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
 from axolotl.logging_config import configure_logging
@@ -62,6 +63,20 @@ def print_axolotl_text_art(suffix=None):
     if is_main_process():
         print(ascii_art)

+    print_dep_versions()
+
+
+def print_dep_versions():
+    packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
+    max_len = max(len(pkg) for pkg in packages)
+    if is_main_process():
+        print("*" * 40)
+        print("**** Axolotl Dependency Versions *****")
+        for pkg in packages:
+            version = _is_package_available(pkg, return_version=True)
+            print(f"{pkg: >{max_len}}: {version[1]: <15}")
+        print("*" * 40)
+
+
 def check_remote_config(config: Union[str, Path]):
     # Check if the config is a valid HTTPS URL to a .yml or .yaml file
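For context on the helper used above: `_is_package_available` is a private transformers utility, so the sketch below of its return shape is an assumption inferred from the call site (with `return_version=True` it appears to return a `(found, version)` tuple).

from transformers.utils.import_utils import _is_package_available

# Assumed behavior, inferred from the diff above: with return_version=True
# the helper returns a (bool, str) tuple such as (True, "2.2.2").
found, version = _is_package_available("torch", return_version=True)
print(f"torch available: {found}, version: {version}")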

View File

@@ -8,6 +8,7 @@ import transformers
 from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
 from axolotl.common.cli import TrainerCliArgs
+from axolotl.utils.dict import DictDefault


 def do_cli(config: Path = Path("examples/"), **kwargs):
@@ -27,21 +28,26 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
         flash_attention=False,
         **kwargs,
     )
+    cfg = modify_cfg_for_merge(parsed_cfg)

-    if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir:
-        parsed_cfg.lora_model_dir = parsed_cfg.output_dir
+    do_merge_lora(cfg=cfg, cli_args=parsed_cli_args)

-    if not Path(parsed_cfg.lora_model_dir).exists():
+
+def modify_cfg_for_merge(cfg: DictDefault) -> DictDefault:
+    if not cfg.lora_model_dir and cfg.output_dir:
+        cfg.lora_model_dir = cfg.output_dir
+
+    if not Path(cfg.lora_model_dir).exists():
         raise ValueError(
-            f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
+            f"Target directory for merge: `{cfg.lora_model_dir}` does not exist."
         )

-    parsed_cfg.load_in_4bit = False
-    parsed_cfg.load_in_8bit = False
-    parsed_cfg.flash_attention = False
-    parsed_cfg.deepspeed = None
-    parsed_cfg.fsdp = None
+    cfg.load_in_4bit = False
+    cfg.load_in_8bit = False
+    cfg.flash_attention = False
+    cfg.deepspeed = None
+    cfg.fsdp = None
+
+    return cfg


 if __name__ == "__main__":
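Pulling the merge preparation into `modify_cfg_for_merge` lets callers outside the CLI reuse it — the new e2e test further down does exactly that. A rough sketch of the call pattern, assuming a hypothetical config path:

from axolotl.cli import do_merge_lora, load_cfg
from axolotl.cli.merge_lora import modify_cfg_for_merge
from axolotl.common.cli import TrainerCliArgs

# "my-run.yml" is a hypothetical config whose output_dir holds trained
# LoRA adapter weights
cfg = load_cfg("my-run.yml", merge_lora=True)
cfg = modify_cfg_for_merge(cfg)  # resolves lora_model_dir, disables 4/8-bit, FSDP, etc.
do_merge_lora(cfg=cfg, cli_args=TrainerCliArgs(merge_lora=True))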

View File

@@ -36,6 +36,7 @@ from trl.trainer.utils import pad_to_length
 from axolotl.loraplus import create_loraplus_optimizer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
+from axolotl.utils import is_mlflow_available
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
     GPUStatsCallback,
@@ -71,10 +72,6 @@ except ImportError:
 LOG = logging.getLogger("axolotl.core.trainer_builder")


-def is_mlflow_available():
-    return importlib.util.find_spec("mlflow") is not None
-
-
 def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
     if isinstance(tag_names, str):
         tag_names = [tag_names]
@@ -943,7 +940,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         callbacks = []
         if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
             LogPredictionCallback = log_prediction_callback_factory(
-                trainer, self.tokenizer
+                trainer, self.tokenizer, "wandb"
+            )
+            callbacks.append(LogPredictionCallback(self.cfg))
+        if (
+            self.cfg.use_mlflow
+            and is_mlflow_available()
+            and self.cfg.eval_table_size > 0
+        ):
+            LogPredictionCallback = log_prediction_callback_factory(
+                trainer, self.tokenizer, "mlflow"
             )
             callbacks.append(LogPredictionCallback(self.cfg))
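With this wiring, a single run can register the prediction-table callback for both backends. A hedged sketch of the config knobs the two branches above check (values illustrative):

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "use_wandb": True,     # first branch: log the table to Weights & Biases
        "use_mlflow": True,    # second branch: also log it to mlflow, if installed
        "eval_table_size": 5,  # both branches require a non-zero table size
    }
)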

View File

@@ -0,0 +1,8 @@
"""
Basic utils for Axolotl
"""
import importlib
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
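The helper keeps mlflow an optional dependency: callers probe for the package before importing it. A minimal usage sketch (the logged parameter is illustrative):

from axolotl.utils import is_mlflow_available

if is_mlflow_available():
    import mlflow  # safe: find_spec confirmed the package can be imported

    mlflow.log_param("adapter", "lora")  # illustrative call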

View File

@@ -6,7 +6,7 @@ import logging
 import os
 from shutil import copyfile
 from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List

 import evaluate
 import numpy as np
@@ -27,7 +27,9 @@ from transformers import (
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

+from axolotl.utils import is_mlflow_available
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 from axolotl.utils.distributed import (
     barrier,
     broadcast_dict,
@@ -540,7 +542,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
     return CausalLMBenchEvalCallback


-def log_prediction_callback_factory(trainer: Trainer, tokenizer):
+def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
     class LogPredictionCallback(TrainerCallback):
         """Callback to log prediction values during each evaluation"""
@@ -597,15 +599,13 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
                 return ranges

             def log_table_from_dataloader(name: str, table_dataloader):
-                table = wandb.Table(  # type: ignore[attr-defined]
-                    columns=[
-                        "id",
-                        "Prompt",
-                        "Correct Completion",
-                        "Predicted Completion (model.generate)",
-                        "Predicted Completion (trainer.prediction_step)",
-                    ]
-                )
+                table_data: Dict[str, List[Any]] = {
+                    "id": [],
+                    "Prompt": [],
+                    "Correct Completion": [],
+                    "Predicted Completion (model.generate)": [],
+                    "Predicted Completion (trainer.prediction_step)": [],
+                }
                 row_index = 0

                 for batch in tqdm(table_dataloader):
@@ -709,16 +709,29 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
                     ) in zip(
                         prompt_texts, completion_texts, predicted_texts, pred_step_texts
                     ):
-                        table.add_data(
-                            row_index,
-                            prompt_text,
-                            completion_text,
-                            prediction_text,
-                            pred_step_text,
-                        )
+                        table_data["id"].append(row_index)
+                        table_data["Prompt"].append(prompt_text)
+                        table_data["Correct Completion"].append(completion_text)
+                        table_data["Predicted Completion (model.generate)"].append(
+                            prediction_text
+                        )
+                        table_data[
+                            "Predicted Completion (trainer.prediction_step)"
+                        ].append(pred_step_text)
                         row_index += 1

-                wandb.run.log({f"{name} - Predictions vs Ground Truth": table})  # type: ignore[attr-defined]
+                if logger == "wandb":
+                    wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)})  # type: ignore[attr-defined]
+                elif logger == "mlflow" and is_mlflow_available():
+                    import mlflow
+
+                    tracking_uri = AxolotlInputConfig(
+                        **self.cfg.to_dict()
+                    ).mlflow_tracking_uri
+                    mlflow.log_table(
+                        data=table_data,
+                        artifact_file="PredictionsVsGroundTruth.json",
+                        tracking_uri=tracking_uri,
+                    )

             if is_main_process():
                 log_table_from_dataloader("Eval", eval_dataloader)
@@ -748,6 +761,11 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
                     mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
                 ) as temp_file:
                     copyfile(self.axolotl_config_path, temp_file.name)
+                    artifact = wandb.Artifact(
+                        f"config-{wandb.run.id}", type="axolotl-config"
+                    )
+                    artifact.add_file(temp_file.name)
+                    wandb.log_artifact(artifact)
                     wandb.save(temp_file.name)
                 LOG.info(
                     "The Axolotl config has been saved to the WandB run under files."

View File

@@ -98,6 +98,7 @@ class SFTDataset(BaseModel):
     ds_type: Optional[str] = None
     train_on_split: Optional[str] = None

+    field: Optional[str] = None
     field_human: Optional[str] = None
     field_model: Optional[str] = None
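Per the commit message, the new `field` option supports completion-style datasets; the sketch below assumes it names the column to read text from (dataset path and column are hypothetical):

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "datasets": [
            {
                "path": "my-org/raw-text-corpus",  # hypothetical dataset
                "type": "completion",
                "field": "content",  # assumed: column holding the completion text
            },
        ],
    }
)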

View File

@@ -379,14 +379,15 @@ def load_tokenized_prepared_datasets(
             d_base_type = d_type_split[0]
             d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None

-            if config_dataset.split and config_dataset.split in ds:
-                ds = ds[config_dataset.split]
-            elif split in ds:
-                ds = ds[split]
-            elif isinstance(ds, DatasetDict):
-                raise ValueError(
-                    f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
-                )
+            if isinstance(ds, DatasetDict):
+                if config_dataset.split and config_dataset.split in ds:
+                    ds = ds[config_dataset.split]
+                elif split in ds:
+                    ds = ds[split]
+                else:
+                    raise ValueError(
+                        f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
+                    )

             # support for using a subset of the data
             if config_dataset.shards:
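A short sketch of why the `isinstance` guard matters, per commit 8fa0785f74: `in` means key lookup on a `DatasetDict` but falls back to row iteration on a bare `Dataset`:

from datasets import Dataset, DatasetDict

data = {"text": ["a", "b", "c"]}

ds_dict = DatasetDict({"train": Dataset.from_dict(data)})
print("train" in ds_dict)  # True: dict-style lookup over split names

ds = Dataset.from_dict(data)
# With no split information, `in` iterates over examples, comparing
# "train" to each row dict; on a large on-disk dataset this full scan
# is what made the program appear to hang.
print("train" in ds)  # False, but only after scanning every row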

View File

@@ -198,7 +198,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             .apply(lambda x: len(x))  # pylint: disable=unnecessary-lambda
             .values
         )
-        LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
+        LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True)
         if update:
             cfg.total_num_tokens = total_num_tokens
@@ -212,7 +212,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             .sum()
         )
         LOG.debug(
-            f"`total_supervised_tokens: {total_supervised_tokens}`",
+            f"`total_supervised_tokens: {total_supervised_tokens:_}`",
             main_process_only=True,
         )
         if update:
@@ -239,7 +239,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 * cfg.num_epochs
             )
             LOG.debug(
-                f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}",
+                f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
                 main_process_only=True,
             )
         else:
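The `:_` format spec (Python 3.6+) groups digits with underscores regardless of locale, which is what makes the log lines above easier to read:

total_num_tokens = 1234567
print(f"total_num_tokens: {total_num_tokens:_}")  # -> total_num_tokens: 1_234_567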

View File

@@ -7,8 +7,6 @@ import os
 import unittest
 from pathlib import Path

-import pytest
-
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
@@ -21,7 +19,6 @@ LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"


-@pytest.mark.skip("Skipping test due to timeout.")
 class TestLlamaShiftedSparseAttention(unittest.TestCase):
     """
     Test case for Llama models using S2 Attn

View File

@@ -1,13 +1,16 @@
""" """
E2E tests for lora llama E2E tests for lora llama
""" """
import json
import logging import logging
import os import os
import unittest import unittest
from pathlib import Path from pathlib import Path
from axolotl.cli import load_datasets from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import do_merge_lora, load_datasets
from axolotl.cli.merge_lora import modify_cfg_for_merge
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
@@ -39,11 +42,6 @@ class TestLoraLlama(unittest.TestCase):
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.1, "val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [ "datasets": [
{ {
"path": "mhenrichsen/alpaca_2k_test", "path": "mhenrichsen/alpaca_2k_test",
@@ -57,6 +55,7 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"max_steps": 10,
} }
) )
normalize_config(cfg) normalize_config(cfg)
@@ -65,3 +64,67 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_lora_merge(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 10,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
cfg.lora_model_dir = cfg.output_dir
cfg.load_in_4bit = False
cfg.load_in_8bit = False
cfg.flash_attention = False
cfg.deepspeed = None
cfg.fsdp = None
cfg = modify_cfg_for_merge(cfg)
cfg.merge_lora = True
cli_args = TrainerCliArgs(merge_lora=True)
do_merge_lora(cfg=cfg, cli_args=cli_args)
assert (Path(temp_dir) / "merged/pytorch_model.bin").exists()
with open(
Path(temp_dir) / "merged/config.json", "r", encoding="utf-8"
) as f_handle:
config = f_handle.read()
config = json.loads(config)
if is_torch_bf16_gpu_available():
assert config["torch_dtype"] == "bfloat16"
else:
assert config["torch_dtype"] == "float16"