Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
744f7082f5 fix for fsdp for models that aren't qwen2 or jamba 2024-04-05 17:02:54 -07:00
9 changed files with 50 additions and 107 deletions

View File

@@ -16,7 +16,7 @@ sequence_len: 1024 # supports up to 32k
sample_packing: false sample_packing: false
pad_to_sequence_len: false pad_to_sequence_len: false
adapter: qlora adapter: lora
lora_model_dir: lora_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16

View File

@@ -24,7 +24,6 @@ from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import _is_package_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
@@ -63,20 +62,6 @@ def print_axolotl_text_art(suffix=None):
if is_main_process(): if is_main_process():
print(ascii_art) print(ascii_art)
print_dep_versions()
def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
if is_main_process():
print("*" * 40)
print("**** Axolotl Dependency Versions *****")
for pkg in packages:
version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {version[1]: <15}")
print("*" * 40)
def check_remote_config(config: Union[str, Path]): def check_remote_config(config: Union[str, Path]):
# Check if the config is a valid HTTPS URL to a .yml or .yaml file # Check if the config is a valid HTTPS URL to a .yml or .yaml file

View File

@@ -23,7 +23,6 @@ from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import ( from transformers import (
EarlyStoppingCallback, EarlyStoppingCallback,
PreTrainedModel,
Trainer, Trainer,
TrainerCallback, TrainerCallback,
TrainingArguments, TrainingArguments,
@@ -36,7 +35,6 @@ from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_mlflow_available
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
EvalFirstStepCallback, EvalFirstStepCallback,
GPUStatsCallback, GPUStatsCallback,
@@ -72,6 +70,10 @@ except ImportError:
LOG = logging.getLogger("axolotl.core.trainer_builder") LOG = logging.getLogger("axolotl.core.trainer_builder")
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None): def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str): if isinstance(tag_names, str):
tag_names = [tag_names] tag_names = [tag_names]
@@ -800,15 +802,6 @@ class AxolotlDPOTrainer(DPOTrainer):
return super().push_to_hub(*args, **kwargs) return super().push_to_hub(*args, **kwargs)
def tokenize_row(
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
) -> Dict:
res = super().tokenize_row(feature, model=model)
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
return res
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
""" """
@@ -940,16 +933,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
callbacks = [] callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0: if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory( LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb" trainer, self.tokenizer
)
callbacks.append(LogPredictionCallback(self.cfg))
if (
self.cfg.use_mlflow
and is_mlflow_available()
and self.cfg.eval_table_size > 0
):
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "mlflow"
) )
callbacks.append(LogPredictionCallback(self.cfg)) callbacks.append(LogPredictionCallback(self.cfg))
@@ -1058,9 +1042,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.save_safetensors is not None: if self.cfg.save_safetensors is not None:
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.save_only_model is not None:
training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
if self.cfg.sample_packing_eff_est: if self.cfg.sample_packing_eff_est:
training_arguments_kwargs[ training_arguments_kwargs[
"sample_packing_efficiency" "sample_packing_efficiency"

View File

@@ -1,8 +0,0 @@
"""
Basic utils for Axolotl
"""
import importlib
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None

View File

@@ -6,7 +6,7 @@ import logging
import os import os
from shutil import copyfile from shutil import copyfile
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Dict, List from typing import TYPE_CHECKING, Dict, List
import evaluate import evaluate
import numpy as np import numpy as np
@@ -27,9 +27,7 @@ from transformers import (
) )
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils import is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
from axolotl.utils.distributed import ( from axolotl.utils.distributed import (
barrier, barrier,
broadcast_dict, broadcast_dict,
@@ -542,7 +540,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
return CausalLMBenchEvalCallback return CausalLMBenchEvalCallback
def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str): def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback): class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation""" """Callback to log prediction values during each evaluation"""
@@ -599,13 +597,15 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
return ranges return ranges
def log_table_from_dataloader(name: str, table_dataloader): def log_table_from_dataloader(name: str, table_dataloader):
table_data: Dict[str, List[Any]] = { table = wandb.Table( # type: ignore[attr-defined]
"id": [], columns=[
"Prompt": [], "id",
"Correct Completion": [], "Prompt",
"Predicted Completion (model.generate)": [], "Correct Completion",
"Predicted Completion (trainer.prediction_step)": [], "Predicted Completion (model.generate)",
} "Predicted Completion (trainer.prediction_step)",
]
)
row_index = 0 row_index = 0
for batch in tqdm(table_dataloader): for batch in tqdm(table_dataloader):
@@ -709,29 +709,16 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
) in zip( ) in zip(
prompt_texts, completion_texts, predicted_texts, pred_step_texts prompt_texts, completion_texts, predicted_texts, pred_step_texts
): ):
table_data["id"].append(row_index) table.add_data(
table_data["Prompt"].append(prompt_text) row_index,
table_data["Correct Completion"].append(completion_text) prompt_text,
table_data["Predicted Completion (model.generate)"].append( completion_text,
prediction_text prediction_text,
pred_step_text,
) )
table_data[
"Predicted Completion (trainer.prediction_step)"
].append(pred_step_text)
row_index += 1 row_index += 1
if logger == "wandb":
wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined]
elif logger == "mlflow" and is_mlflow_available():
import mlflow
tracking_uri = AxolotlInputConfig( wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) # type: ignore[attr-defined]
**self.cfg.to_dict()
).mlflow_tracking_uri
mlflow.log_table(
data=table_data,
artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri,
)
if is_main_process(): if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader) log_table_from_dataloader("Eval", eval_dataloader)
@@ -761,11 +748,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file: ) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name) copyfile(self.axolotl_config_path, temp_file.name)
artifact = wandb.Artifact(
f"config-{wandb.run.id}", type="axolotl-config"
)
artifact.add_file(temp_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_file.name) wandb.save(temp_file.name)
LOG.info( LOG.info(
"The Axolotl config has been saved to the WandB run under files." "The Axolotl config has been saved to the WandB run under files."

View File

@@ -98,7 +98,6 @@ class SFTDataset(BaseModel):
ds_type: Optional[str] = None ds_type: Optional[str] = None
train_on_split: Optional[str] = None train_on_split: Optional[str] = None
field: Optional[str] = None
field_human: Optional[str] = None field_human: Optional[str] = None
field_model: Optional[str] = None field_model: Optional[str] = None
@@ -243,6 +242,17 @@ class LoraConfig(BaseModel):
raise ValueError("Require cfg.load_in_4bit to be True for qlora") raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self return self
@model_validator(mode="before")
@classmethod
def validate_quantized_dora(cls, data):
if data.get("peft_use_dora") and (
data.get("load_in_8bit") or data.get("load_in_4bit")
):
raise ValueError(
"`peft_use_dora` is not currently compatible with quantized weights."
)
return data
class ReLoRAConfig(BaseModel): class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset""" """ReLoRA configuration subset"""
@@ -355,7 +365,6 @@ class ModelOutputConfig(BaseModel):
hub_model_id: Optional[str] = None hub_model_id: Optional[str] = None
hub_strategy: Optional[str] = None hub_strategy: Optional[str] = None
save_safetensors: Optional[bool] = None save_safetensors: Optional[bool] = None
save_only_model: Optional[bool] = None
class MLFlowConfig(BaseModel): class MLFlowConfig(BaseModel):
@@ -655,8 +664,8 @@ class AxolotlInputConfig(
and not data.get("flash_attention") and not data.get("flash_attention")
and not data.get("sdp_attention") and not data.get("sdp_attention")
): ):
LOG.warning( raise ValueError(
"sample_packing without flash_attention or sdp_attention does not handle cross-attention." "sample_packing requires flash_attention or sdp_attention to be set to true"
) )
return data return data

View File

@@ -379,15 +379,14 @@ def load_tokenized_prepared_datasets(
d_base_type = d_type_split[0] d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if isinstance(ds, DatasetDict): if config_dataset.split and config_dataset.split in ds:
if config_dataset.split and config_dataset.split in ds: ds = ds[config_dataset.split]
ds = ds[config_dataset.split] elif split in ds:
elif split in ds: ds = ds[split]
ds = ds[split] elif isinstance(ds, DatasetDict):
else: raise ValueError(
raise ValueError( f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `" )
)
# support for using a subset of the data # support for using a subset of the data
if config_dataset.shards: if config_dataset.shards:

View File

@@ -459,7 +459,7 @@ def load_model(
"bnb_4bit_quant_type": "nf4", "bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16, "bnb_4bit_quant_storage": torch.bfloat16,
} }
if not cfg.deepspeed: if not cfg.deepspeed and cfg.model_config_type in ("jamba", "qwen2_moe"):
# for some reason, this causes the loss to be off by an order of magnitude # for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16 # but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32 bnb_config["bnb_4bit_quant_storage"] = torch.float32
@@ -902,12 +902,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
model = get_peft_model(model, lora_config) model = get_peft_model(model, lora_config)
if rank == 0: if rank == 0:
try: model.print_trainable_parameters()
model.print_trainable_parameters()
except AttributeError as exc:
LOG.warning(
"Exception caught during model.print_trainable_parameters(): %s", exc
)
elif cfg.fsdp and cfg.adapter == "qlora": elif cfg.fsdp and cfg.adapter == "qlora":
setup_quantized_peft_meta_for_training(model) setup_quantized_peft_meta_for_training(model)

View File

@@ -198,7 +198,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda .apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values .values
) )
LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True) LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
if update: if update:
cfg.total_num_tokens = total_num_tokens cfg.total_num_tokens = total_num_tokens
@@ -212,7 +212,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.sum() .sum()
) )
LOG.debug( LOG.debug(
f"`total_supervised_tokens: {total_supervised_tokens:_}`", f"`total_supervised_tokens: {total_supervised_tokens}`",
main_process_only=True, main_process_only=True,
) )
if update: if update:
@@ -239,7 +239,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
* cfg.num_epochs * cfg.num_epochs
) )
LOG.debug( LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}", f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}",
main_process_only=True, main_process_only=True,
) )
else: else: