Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
f8bb4185bc skip s2 attention test due to timeout 2024-04-08 18:33:33 -04:00
11 changed files with 55 additions and 170 deletions

View File

@@ -16,7 +16,7 @@ sequence_len: 1024 # supports up to 32k
sample_packing: false sample_packing: false
pad_to_sequence_len: false pad_to_sequence_len: false
adapter: qlora adapter: lora
lora_model_dir: lora_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16

View File

@@ -24,7 +24,6 @@ from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import _is_package_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
@@ -63,20 +62,6 @@ def print_axolotl_text_art(suffix=None):
if is_main_process(): if is_main_process():
print(ascii_art) print(ascii_art)
print_dep_versions()
def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
if is_main_process():
print("*" * 40)
print("**** Axolotl Dependency Versions *****")
for pkg in packages:
version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {version[1]: <15}")
print("*" * 40)
def check_remote_config(config: Union[str, Path]): def check_remote_config(config: Union[str, Path]):
# Check if the config is a valid HTTPS URL to a .yml or .yaml file # Check if the config is a valid HTTPS URL to a .yml or .yaml file

View File

@@ -8,7 +8,6 @@ import transformers
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.dict import DictDefault
def do_cli(config: Path = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), **kwargs):
@@ -28,26 +27,21 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
flash_attention=False, flash_attention=False,
**kwargs, **kwargs,
) )
cfg = modify_cfg_for_merge(parsed_cfg)
do_merge_lora(cfg=cfg, cli_args=parsed_cli_args) if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir:
parsed_cfg.lora_model_dir = parsed_cfg.output_dir
if not Path(parsed_cfg.lora_model_dir).exists():
def modify_cfg_for_merge(cfg: DictDefault) -> DictDefault:
if not cfg.lora_model_dir and cfg.output_dir:
cfg.lora_model_dir = cfg.output_dir
if not Path(cfg.lora_model_dir).exists():
raise ValueError( raise ValueError(
f"Target directory for merge: `{cfg.lora_model_dir}` does not exist." f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
) )
cfg.load_in_4bit = False parsed_cfg.load_in_4bit = False
cfg.load_in_8bit = False parsed_cfg.load_in_8bit = False
cfg.flash_attention = False parsed_cfg.flash_attention = False
cfg.deepspeed = None parsed_cfg.deepspeed = None
cfg.fsdp = None parsed_cfg.fsdp = None
return cfg do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -36,7 +36,6 @@ from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_mlflow_available
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
EvalFirstStepCallback, EvalFirstStepCallback,
GPUStatsCallback, GPUStatsCallback,
@@ -72,6 +71,10 @@ except ImportError:
LOG = logging.getLogger("axolotl.core.trainer_builder") LOG = logging.getLogger("axolotl.core.trainer_builder")
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None): def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str): if isinstance(tag_names, str):
tag_names = [tag_names] tag_names = [tag_names]
@@ -940,16 +943,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
callbacks = [] callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0: if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory( LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb" trainer, self.tokenizer
)
callbacks.append(LogPredictionCallback(self.cfg))
if (
self.cfg.use_mlflow
and is_mlflow_available()
and self.cfg.eval_table_size > 0
):
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "mlflow"
) )
callbacks.append(LogPredictionCallback(self.cfg)) callbacks.append(LogPredictionCallback(self.cfg))

View File

@@ -1,8 +0,0 @@
"""
Basic utils for Axolotl
"""
import importlib
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None

View File

@@ -6,7 +6,7 @@ import logging
import os import os
from shutil import copyfile from shutil import copyfile
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Dict, List from typing import TYPE_CHECKING, Dict, List
import evaluate import evaluate
import numpy as np import numpy as np
@@ -27,9 +27,7 @@ from transformers import (
) )
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils import is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
from axolotl.utils.distributed import ( from axolotl.utils.distributed import (
barrier, barrier,
broadcast_dict, broadcast_dict,
@@ -542,7 +540,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
return CausalLMBenchEvalCallback return CausalLMBenchEvalCallback
def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str): def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback): class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation""" """Callback to log prediction values during each evaluation"""
@@ -599,13 +597,15 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
return ranges return ranges
def log_table_from_dataloader(name: str, table_dataloader): def log_table_from_dataloader(name: str, table_dataloader):
table_data: Dict[str, List[Any]] = { table = wandb.Table( # type: ignore[attr-defined]
"id": [], columns=[
"Prompt": [], "id",
"Correct Completion": [], "Prompt",
"Predicted Completion (model.generate)": [], "Correct Completion",
"Predicted Completion (trainer.prediction_step)": [], "Predicted Completion (model.generate)",
} "Predicted Completion (trainer.prediction_step)",
]
)
row_index = 0 row_index = 0
for batch in tqdm(table_dataloader): for batch in tqdm(table_dataloader):
@@ -709,29 +709,16 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
) in zip( ) in zip(
prompt_texts, completion_texts, predicted_texts, pred_step_texts prompt_texts, completion_texts, predicted_texts, pred_step_texts
): ):
table_data["id"].append(row_index) table.add_data(
table_data["Prompt"].append(prompt_text) row_index,
table_data["Correct Completion"].append(completion_text) prompt_text,
table_data["Predicted Completion (model.generate)"].append( completion_text,
prediction_text prediction_text,
pred_step_text,
) )
table_data[
"Predicted Completion (trainer.prediction_step)"
].append(pred_step_text)
row_index += 1 row_index += 1
if logger == "wandb":
wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined]
elif logger == "mlflow" and is_mlflow_available():
import mlflow
tracking_uri = AxolotlInputConfig( wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) # type: ignore[attr-defined]
**self.cfg.to_dict()
).mlflow_tracking_uri
mlflow.log_table(
data=table_data,
artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri,
)
if is_main_process(): if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader) log_table_from_dataloader("Eval", eval_dataloader)
@@ -761,11 +748,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file: ) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name) copyfile(self.axolotl_config_path, temp_file.name)
artifact = wandb.Artifact(
f"config-{wandb.run.id}", type="axolotl-config"
)
artifact.add_file(temp_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_file.name) wandb.save(temp_file.name)
LOG.info( LOG.info(
"The Axolotl config has been saved to the WandB run under files." "The Axolotl config has been saved to the WandB run under files."

View File

@@ -98,7 +98,6 @@ class SFTDataset(BaseModel):
ds_type: Optional[str] = None ds_type: Optional[str] = None
train_on_split: Optional[str] = None train_on_split: Optional[str] = None
field: Optional[str] = None
field_human: Optional[str] = None field_human: Optional[str] = None
field_model: Optional[str] = None field_model: Optional[str] = None

View File

@@ -379,15 +379,14 @@ def load_tokenized_prepared_datasets(
d_base_type = d_type_split[0] d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if isinstance(ds, DatasetDict): if config_dataset.split and config_dataset.split in ds:
if config_dataset.split and config_dataset.split in ds: ds = ds[config_dataset.split]
ds = ds[config_dataset.split] elif split in ds:
elif split in ds: ds = ds[split]
ds = ds[split] elif isinstance(ds, DatasetDict):
else: raise ValueError(
raise ValueError( f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `" )
)
# support for using a subset of the data # support for using a subset of the data
if config_dataset.shards: if config_dataset.shards:

View File

@@ -198,7 +198,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda .apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values .values
) )
LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True) LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
if update: if update:
cfg.total_num_tokens = total_num_tokens cfg.total_num_tokens = total_num_tokens
@@ -212,7 +212,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.sum() .sum()
) )
LOG.debug( LOG.debug(
f"`total_supervised_tokens: {total_supervised_tokens:_}`", f"`total_supervised_tokens: {total_supervised_tokens}`",
main_process_only=True, main_process_only=True,
) )
if update: if update:
@@ -239,7 +239,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
* cfg.num_epochs * cfg.num_epochs
) )
LOG.debug( LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}", f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}",
main_process_only=True, main_process_only=True,
) )
else: else:

View File

@@ -7,6 +7,8 @@ import os
import unittest import unittest
from pathlib import Path from pathlib import Path
import pytest
from axolotl.cli import load_datasets from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train from axolotl.train import train
@@ -19,6 +21,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip("Skipping test due to timeout.")
class TestLlamaShiftedSparseAttention(unittest.TestCase): class TestLlamaShiftedSparseAttention(unittest.TestCase):
""" """
Test case for Llama models using S2 Attn Test case for Llama models using S2 Attn

View File

@@ -1,16 +1,13 @@
""" """
E2E tests for lora llama E2E tests for lora llama
""" """
import json
import logging import logging
import os import os
import unittest import unittest
from pathlib import Path from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available from axolotl.cli import load_datasets
from axolotl.cli import do_merge_lora, load_datasets
from axolotl.cli.merge_lora import modify_cfg_for_merge
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
@@ -42,6 +39,11 @@ class TestLoraLlama(unittest.TestCase):
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.1, "val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [ "datasets": [
{ {
"path": "mhenrichsen/alpaca_2k_test", "path": "mhenrichsen/alpaca_2k_test",
@@ -55,7 +57,6 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"max_steps": 10,
} }
) )
normalize_config(cfg) normalize_config(cfg)
@@ -64,67 +65,3 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_lora_merge(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 10,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
cfg.lora_model_dir = cfg.output_dir
cfg.load_in_4bit = False
cfg.load_in_8bit = False
cfg.flash_attention = False
cfg.deepspeed = None
cfg.fsdp = None
cfg = modify_cfg_for_merge(cfg)
cfg.merge_lora = True
cli_args = TrainerCliArgs(merge_lora=True)
do_merge_lora(cfg=cfg, cli_args=cli_args)
assert (Path(temp_dir) / "merged/pytorch_model.bin").exists()
with open(
Path(temp_dir) / "merged/config.json", "r", encoding="utf-8"
) as f_handle:
config = f_handle.read()
config = json.loads(config)
if is_torch_bf16_gpu_available():
assert config["torch_dtype"] == "bfloat16"
else:
assert config["torch_dtype"] == "float16"