Compare commits


1 Commit

Author     SHA1        Message                                         Date
Wing Lian  05f7034288  use deterministic seed for random LISA layers   2024-04-04 18:16:55 -07:00
14 changed files with 76 additions and 212 deletions

View File

@@ -16,7 +16,7 @@ sequence_len: 1024 # supports up to 32k
sample_packing: false
pad_to_sequence_len: false
adapter: qlora
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16

View File

@@ -24,7 +24,6 @@ from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import _is_package_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
@@ -63,20 +62,6 @@ def print_axolotl_text_art(suffix=None):
if is_main_process():
print(ascii_art)
print_dep_versions()
def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
if is_main_process():
print("*" * 40)
print("**** Axolotl Dependency Versions *****")
for pkg in packages:
version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {version[1]: <15}")
print("*" * 40)
def check_remote_config(config: Union[str, Path]):
# Check if the config is a valid HTTPS URL to a .yml or .yaml file

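For reference, a minimal standard-library sketch of the dependency-version banner above; `_is_package_available` is a private transformers helper, so this version substitutes `importlib.metadata` (the package list is copied from the diff):

from importlib.metadata import PackageNotFoundError, version

def print_dep_versions():
    packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
    max_len = max(len(pkg) for pkg in packages)
    print("*" * 40)
    print("**** Axolotl Dependency Versions *****")
    for pkg in packages:
        try:
            pkg_version = version(pkg)  # raises if the package is not installed
        except PackageNotFoundError:
            pkg_version = "not installed"
        print(f"{pkg: >{max_len}}: {pkg_version: <15}")
    print("*" * 40)
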
View File

@@ -8,7 +8,6 @@ import transformers
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.dict import DictDefault
def do_cli(config: Path = Path("examples/"), **kwargs):
@@ -28,26 +27,19 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
flash_attention=False,
**kwargs,
)
cfg = modify_cfg_for_merge(parsed_cfg)
do_merge_lora(cfg=cfg, cli_args=parsed_cli_args)
def modify_cfg_for_merge(cfg: DictDefault) -> DictDefault:
if not cfg.lora_model_dir and cfg.output_dir:
cfg.lora_model_dir = cfg.output_dir
if not Path(cfg.lora_model_dir).exists():
if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir:
parsed_cfg.lora_model_dir = parsed_cfg.output_dir
if not Path(parsed_cfg.lora_model_dir).exists():
raise ValueError(
f"Target directory for merge: `{cfg.lora_model_dir}` does not exist."
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
)
cfg.load_in_4bit = False
cfg.load_in_8bit = False
cfg.flash_attention = False
cfg.deepspeed = None
cfg.fsdp = None
parsed_cfg.load_in_4bit = False
parsed_cfg.load_in_8bit = False
parsed_cfg.flash_attention = False
return cfg
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":

View File

@@ -23,7 +23,6 @@ from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
EarlyStoppingCallback,
PreTrainedModel,
Trainer,
TrainerCallback,
TrainingArguments,
@@ -36,7 +35,6 @@ from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
@@ -72,6 +70,10 @@ except ImportError:
LOG = logging.getLogger("axolotl.core.trainer_builder")
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
@@ -800,15 +802,6 @@ class AxolotlDPOTrainer(DPOTrainer):
return super().push_to_hub(*args, **kwargs)
def tokenize_row(
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
) -> Dict:
res = super().tokenize_row(feature, model=model)
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
return res
class TrainerBuilderBase(abc.ABC):
"""
@@ -940,16 +933,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
)
callbacks.append(LogPredictionCallback(self.cfg))
if (
self.cfg.use_mlflow
and is_mlflow_available()
and self.cfg.eval_table_size > 0
):
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "mlflow"
trainer, self.tokenizer
)
callbacks.append(LogPredictionCallback(self.cfg))

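The `is_mlflow_available` helper this hunk moves inline is the stdlib `find_spec` idiom for gating optional dependencies; a generic sketch of the same pattern:

import importlib.util

def is_package_available(name: str) -> bool:
    # find_spec returns None when the module cannot be found,
    # without importing it (and without its import side effects)
    return importlib.util.find_spec(name) is not None

if is_package_available("mlflow"):
    import mlflow  # only imported when actually present
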
View File

@@ -1,8 +0,0 @@
"""
Basic utils for Axolotl
"""
import importlib
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None

View File

@@ -6,7 +6,7 @@ import logging
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Dict, List
import evaluate
import numpy as np
@@ -27,9 +27,7 @@ from transformers import (
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils import is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
from axolotl.utils.distributed import (
barrier,
broadcast_dict,
@@ -542,7 +540,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
return CausalLMBenchEvalCallback
def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""
@@ -599,13 +597,15 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
return ranges
def log_table_from_dataloader(name: str, table_dataloader):
table_data: Dict[str, List[Any]] = {
"id": [],
"Prompt": [],
"Correct Completion": [],
"Predicted Completion (model.generate)": [],
"Predicted Completion (trainer.prediction_step)": [],
}
table = wandb.Table( # type: ignore[attr-defined]
columns=[
"id",
"Prompt",
"Correct Completion",
"Predicted Completion (model.generate)",
"Predicted Completion (trainer.prediction_step)",
]
)
row_index = 0
for batch in tqdm(table_dataloader):
@@ -709,29 +709,16 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
) in zip(
prompt_texts, completion_texts, predicted_texts, pred_step_texts
):
table_data["id"].append(row_index)
table_data["Prompt"].append(prompt_text)
table_data["Correct Completion"].append(completion_text)
table_data["Predicted Completion (model.generate)"].append(
prediction_text
table.add_data(
row_index,
prompt_text,
completion_text,
prediction_text,
pred_step_text,
)
table_data[
"Predicted Completion (trainer.prediction_step)"
].append(pred_step_text)
row_index += 1
if logger == "wandb":
wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined]
elif logger == "mlflow" and is_mlflow_available():
import mlflow
tracking_uri = AxolotlInputConfig(
**self.cfg.to_dict()
).mlflow_tracking_uri
mlflow.log_table(
data=table_data,
artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri,
)
wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) # type: ignore[attr-defined]
if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)
@@ -761,11 +748,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
artifact = wandb.Artifact(
f"config-{wandb.run.id}", type="axolotl-config"
)
artifact.add_file(temp_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_file.name)
LOG.info(
"The Axolotl config has been saved to the WandB run under files."

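Both variants in the hunk above log the same rows, one as a dict-of-lists turned into a `pd.DataFrame`, the other as a `wandb.Table`. A minimal sketch of the `wandb.Table` flow, assuming a logged-in wandb session and a hypothetical project name:

import wandb

run = wandb.init(project="axolotl-demo")  # hypothetical project
table = wandb.Table(columns=["id", "Prompt", "Correct Completion"])
table.add_data(0, "What is 2 + 2?", "4")
run.log({"Eval - Predictions vs Ground Truth": table})
run.finish()
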
View File

@@ -54,23 +54,33 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
for param in layer.parameters():
param.requires_grad = False
def on_train_begin(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
self.switch_active_layers(state)
def on_step_begin(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
# Check if it's time to switch active layers, including at step 0
if state.global_step % self.step_interval == 0 or state.global_step == 1:
self.switch_active_layers()
if state.global_step % self.step_interval == 0:
self.switch_active_layers(state)
def switch_active_layers(self):
def switch_active_layers(self, state):
# First, disable gradients for all layers
self.freeze_all_layers()
deterministic_seed = state.global_step
np.random.seed(deterministic_seed)
# Randomly select n_layers to activate
layers = reduce(
getattr, self.layers_attribute.split("."), self.trainer.model
)
self.active_layers_indices = np.random.choice(
range(self.total_layers), self.n_layers, replace=False
range(self.total_layers),
self.n_layers,
replace=False,
)
LOG.info(
f"Activating layers at indices: {self.active_layers_indices} for the next steps."

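This file carries the substance of the commit: seeding NumPy with `state.global_step` before `np.random.choice` makes the LISA layer selection reproducible across ranks and across resumed runs. A standalone sketch, with assumed values for the layer counts:

import numpy as np

def pick_active_layers(global_step: int, total_layers: int = 32, n_layers: int = 4):
    np.random.seed(global_step)  # deterministic per training step
    return np.random.choice(range(total_layers), n_layers, replace=False)

# Same step, same layers: on every process and after every restart.
assert (pick_active_layers(100) == pick_active_layers(100)).all()
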
View File

@@ -23,7 +23,6 @@ def chat_templates(user_choice: str):
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
}
if user_choice in templates:

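Any of these template strings can be exercised through a tokenizer's `apply_chat_template`; a quick sketch, where the tokenizer name is only an illustrative placeholder:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer works here
tokenizer.chat_template = chat_templates("chatml")  # helper defined in this file
messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
]
print(tokenizer.apply_chat_template(messages, tokenize=False))
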
View File

@@ -1,7 +1,6 @@
"""
Module for pydantic models for configuration
"""
# pylint: disable=too-many-lines
import logging
@@ -98,7 +97,6 @@ class SFTDataset(BaseModel):
ds_type: Optional[str] = None
train_on_split: Optional[str] = None
field: Optional[str] = None
field_human: Optional[str] = None
field_model: Optional[str] = None
@@ -142,7 +140,6 @@ class ChatTemplate(str, Enum):
chatml = "chatml" # pylint: disable=invalid-name
inst = "inst" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
@@ -243,6 +240,17 @@ class LoraConfig(BaseModel):
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self
@model_validator(mode="before")
@classmethod
def validate_quantized_dora(cls, data):
if data.get("peft_use_dora") and (
data.get("load_in_8bit") or data.get("load_in_4bit")
):
raise ValueError(
"`peft_use_dora` is not currently compatible with quantized weights."
)
return data
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""
@@ -646,20 +654,6 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_wo_flash(cls, data):
if (
data.get("sample_packing")
and not data.get("flash_attention")
and not data.get("sdp_attention")
):
LOG.warning(
"sample_packing without flash_attention or sdp_attention does not handle cross-attention."
)
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_rl(cls, data):

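The new `validate_quantized_dora` check uses pydantic v2's `model_validator(mode="before")`, which receives the raw input before field parsing. A self-contained sketch of the pattern (the class name is hypothetical, and dict input is assumed, as the `.get` calls above assume):

from pydantic import BaseModel, model_validator

class PeftConfig(BaseModel):
    peft_use_dora: bool = False
    load_in_8bit: bool = False
    load_in_4bit: bool = False

    @model_validator(mode="before")
    @classmethod
    def validate_quantized_dora(cls, data):
        if data.get("peft_use_dora") and (
            data.get("load_in_8bit") or data.get("load_in_4bit")
        ):
            raise ValueError(
                "`peft_use_dora` is not currently compatible with quantized weights."
            )
        return data

PeftConfig(peft_use_dora=True, load_in_8bit=True)  # raises ValidationError
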
View File

@@ -379,15 +379,14 @@ def load_tokenized_prepared_datasets(
d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if isinstance(ds, DatasetDict):
if config_dataset.split and config_dataset.split in ds:
ds = ds[config_dataset.split]
elif split in ds:
ds = ds[split]
else:
raise ValueError(
f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
)
if config_dataset.split and config_dataset.split in ds:
ds = ds[config_dataset.split]
elif split in ds:
ds = ds[split]
elif isinstance(ds, DatasetDict):
raise ValueError(
f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
)
# support for using a subset of the data
if config_dataset.shards:

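The split-selection logic above can be read as a small standalone function; a loose sketch assuming `datasets.DatasetDict` input, with a hypothetical `preferred` argument standing in for `config_dataset.split`:

from datasets import DatasetDict

def select_split(ds, preferred=None, default="train"):
    if isinstance(ds, DatasetDict):
        if preferred and preferred in ds:
            return ds[preferred]
        if default in ds:
            return ds[default]
        raise ValueError(f"no {default} split found, you may specify a split with 'split:'")
    return ds  # already a single Dataset, nothing to select
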
View File

@@ -902,12 +902,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
model = get_peft_model(model, lora_config)
if rank == 0:
try:
model.print_trainable_parameters()
except AttributeError as exc:
LOG.warning(
"Exception caught during model.print_trainable_parameters(): %s", exc
)
model.print_trainable_parameters()
elif cfg.fsdp and cfg.adapter == "qlora":
setup_quantized_peft_meta_for_training(model)

View File

@@ -198,7 +198,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True)
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
if update:
cfg.total_num_tokens = total_num_tokens
@@ -212,7 +212,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.sum()
)
LOG.debug(
f"`total_supervised_tokens: {total_supervised_tokens:_}`",
f"`total_supervised_tokens: {total_supervised_tokens}`",
main_process_only=True,
)
if update:
@@ -239,7 +239,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
* cfg.num_epochs
)
LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}",
main_process_only=True,
)
else:

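The only difference between the paired lines in these hunks is the `:_` format spec, Python's underscore thousands separator for f-strings:

total_num_tokens = 1234567
print(f"{total_num_tokens:_}")  # 1_234_567
print(f"{total_num_tokens}")    # 1234567
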
View File

@@ -1,16 +1,13 @@
"""
E2E tests for lora llama
"""
import json
import logging
import os
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import do_merge_lora, load_datasets
from axolotl.cli.merge_lora import modify_cfg_for_merge
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
@@ -42,6 +39,11 @@ class TestLoraLlama(unittest.TestCase):
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
@@ -55,7 +57,6 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 10,
}
)
normalize_config(cfg)
@@ -64,67 +65,3 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_lora_merge(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 10,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
cfg.lora_model_dir = cfg.output_dir
cfg.load_in_4bit = False
cfg.load_in_8bit = False
cfg.flash_attention = False
cfg.deepspeed = None
cfg.fsdp = None
cfg = modify_cfg_for_merge(cfg)
cfg.merge_lora = True
cli_args = TrainerCliArgs(merge_lora=True)
do_merge_lora(cfg=cfg, cli_args=cli_args)
assert (Path(temp_dir) / "merged/pytorch_model.bin").exists()
with open(
Path(temp_dir) / "merged/config.json", "r", encoding="utf-8"
) as f_handle:
config = f_handle.read()
config = json.loads(config)
if is_torch_bf16_gpu_available():
assert config["torch_dtype"] == "bfloat16"
else:
assert config["torch_dtype"] == "float16"

View File

@@ -600,7 +600,6 @@ class TestValidation(BaseValidation):
{
"sample_packing": True,
"pad_to_sequence_len": None,
"flash_attention": True,
}
)
| minimal_cfg
@@ -902,7 +901,6 @@ class TestValidation(BaseValidation):
{
"sample_packing": True,
"eval_table_size": 100,
"flash_attention": True,
}
)
| minimal_cfg
@@ -918,7 +916,6 @@ class TestValidation(BaseValidation):
{
"sample_packing": True,
"eval_sample_packing": False,
"flash_attention": True,
}
)
| minimal_cfg
@@ -931,7 +928,6 @@ class TestValidation(BaseValidation):
{
"sample_packing": False,
"eval_table_size": 100,
"flash_attention": True,
}
)
| minimal_cfg
@@ -945,7 +941,6 @@ class TestValidation(BaseValidation):
"sample_packing": True,
"eval_table_size": 100,
"eval_sample_packing": False,
"flash_attention": True,
}
)
| minimal_cfg