Compare commits

..

9 Commits

Author SHA1 Message Date
Wing Lian
17330c05a3 shampoo checkpoint save workaround 2024-09-23 15:21:00 -04:00
Wing Lian
992ea517b7 setup precision config for bf16 2024-09-18 12:01:38 -07:00
Wing Lian
beaee36191 ddp shampoo 2024-09-18 10:50:46 -07:00
Wing Lian
69a29382e1 fix casting of optim args 2024-09-18 10:48:15 -07:00
Wing Lian
84dad0bd12 ensure epsilon is cast to float 2024-09-18 10:42:26 -07:00
Wing Lian
05f61a0ea5 remove accidental duplidcated line 2024-09-18 10:41:03 -07:00
Wing Lian
5334d0fc01 fixes 2024-09-18 10:38:59 -07:00
Wing Lian
52e6249d2e additional grafting config types and basic example doc 2024-09-18 08:16:11 -07:00
Wing Lian
eb3eab3450 wip shampoo optim support 2024-09-18 08:10:52 -07:00
48 changed files with 450 additions and 1447 deletions

View File

@@ -28,13 +28,7 @@ jobs:
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.4.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout

View File

@@ -27,7 +27,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -84,7 +84,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -26,7 +26,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -83,7 +83,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -25,7 +25,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1"]
pytorch_version: ["2.3.1", "2.4.0"]
timeout-minutes: 20
steps:
@@ -91,7 +91,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.4.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"

View File

@@ -36,7 +36,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1"]
pytorch_version: ["2.3.1", "2.4.0"]
timeout-minutes: 20
steps:
@@ -94,7 +94,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.4.0
num_gpus: 1
axolotl_extras:
steps:

View File

@@ -1,3 +1,3 @@
[settings]
profile=black
known_third_party=wandb,comet_ml
known_third_party=wandb

View File

@@ -14,7 +14,7 @@ Features:
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb, mlflow or Comet
- Log results and optionally checkpoints to wandb or mlflow
- And more!
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
@@ -515,22 +515,6 @@ wandb_name:
wandb_log_model:
```
##### Comet Logging
Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
- wandb options
```yaml
use_comet:
comet_api_key:
comet_workspace:
comet_project_name:
comet_experiment_key:
comet_mode:
comet_online:
comet_experiment_config:
```
##### Special Tokens
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:

View File

@@ -90,7 +90,6 @@ datasets:
shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -268,18 +267,6 @@ mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Comet configuration if you're using it
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
use_comet: # Enable or disable Comet integration.
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
# Where to save the full-finetuned model to
output_dir: ./completed-model

View File

@@ -205,7 +205,7 @@ ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
hi there!. goodbye farewell</s>
```
We can check that the right tokens are ignored by comparing the labels
We can check that the right tokens are ingored by comparing the labels
to each token:
```python

View File

@@ -1,28 +0,0 @@
# MultiModal / Vision Language Models (BETA)
### Supported Models
- Mllama, i.e. llama with vision models
### Usage
Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA,
you'll need to use the following in YAML in combination with the rest of the required hyperparams.
```yaml
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
processor_type: AutoProcessor
skip_prepare_dataset: true
chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
remove_unused_columns: false
sample_packing: false
# only finetune the Language model, leave the vision model and vision tower frozen
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
```

17
docs/optimizers.qmd Normal file
View File

@@ -0,0 +1,17 @@
# Optimizers
## Shampoo
```yaml
optimizer: shampoo
optim_shampoo_betas: [0.9, 0.999]
optim_args:
epsilon: 1e-12
max_preconditioner_dim: 8192
precondition_frequency: 100
use_decoupled_weight_decay: true
optim_shampoo_grafting_config_type: adam
optim_shampoo_grafting_config_kwargs:
beta2: 0.999
epsilon: 1e-12
```

View File

@@ -1,63 +0,0 @@
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -1,9 +1,9 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.0
transformers==4.45.1
peft==0.12.0
transformers @ git+https://github.com/huggingface/transformers.git@0963229e287501bed52ae1dabc17922524de6992
tokenizers>=0.19.1
bitsandbytes==0.44.0
bitsandbytes==0.43.3
accelerate==0.34.2
datasets==2.21.0
deepspeed==0.14.4
@@ -16,7 +16,7 @@ flash-attn==2.6.3
sentencepiece
wandb
einops
xformers==0.0.28.post1
xformers==0.0.27
optimum==1.16.2
hf_transfer
colorama
@@ -34,7 +34,8 @@ tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
triton>=2.3.0
liger-kernel==0.3.0
liger-kernel==0.2.1
distributed_shampoo @ git+https://github.com/facebookresearch/optimizers.git@main
mamba-ssm==1.2.0.post1
@@ -46,9 +47,3 @@ gcsfs>=2024.5.0
trl==0.9.6
zstandard==0.22.0
fastcore
# lm eval harness
lm_eval==0.4.4
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2

View File

@@ -49,17 +49,10 @@ def parse_requirements():
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
if (major, minor) >= (2, 3):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")

View File

@@ -30,8 +30,6 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
@@ -41,7 +39,7 @@ from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -55,22 +53,8 @@ LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
AXOLOTL_LOGO = """
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
"""
def print_legacy_axolotl_text_art(suffix=None):
def print_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
@@ -83,13 +67,6 @@ def print_legacy_axolotl_text_art(suffix=None):
print_dep_versions()
def print_axolotl_text_art(
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
print(AXOLOTL_LOGO)
def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
@@ -257,8 +234,7 @@ def do_inference_gradio(
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
default_tokens: Dict[str, str] = {}
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
@@ -266,13 +242,10 @@ def do_inference_gradio(
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = chat_templates(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -286,24 +259,7 @@ def do_inference_gradio(
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
@@ -326,7 +282,6 @@ def do_inference_gradio(
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"attention_mask": batch["attention_mask"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
@@ -443,8 +398,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
return cfg
@@ -454,12 +407,9 @@ def load_datasets(
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,
tokenizer,
processor=processor,
cfg, tokenizer
)
if cli_args.debug or cfg.debug:

View File

@@ -3,11 +3,13 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Union
from typing import Tuple, Union
import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from axolotl.cli import (
check_accelerate_default_config,
@@ -18,7 +20,6 @@ from axolotl.cli import (
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.integrations.base import PluginManager
from axolotl.prompt_strategies.sharegpt import (
register_chatml_template,
register_llama3_template,
@@ -38,7 +39,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
return do_train(parsed_cfg, parsed_cli_args)
def do_train(cfg, cli_args) -> None:
def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
@@ -63,13 +64,7 @@ def do_train(cfg, cli_args) -> None:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
plugin_manager.post_train_unload(cfg)
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
if __name__ == "__main__":

View File

@@ -16,12 +16,11 @@ from collections import defaultdict
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Type, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
import torch
import transformers
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
@@ -46,9 +45,10 @@ from trl import (
)
from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils import is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
@@ -61,14 +61,12 @@ from axolotl.utils.callbacks import (
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
@@ -252,10 +250,11 @@ class AxolotlTrainingMixins:
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
optim_shampoo_grafting_config_type: Optional[
Literal["adam", "sgd", "adagrad"]
] = None
optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
optim_shampoo_betas: Optional[Tuple[float, float]] = None
@dataclass
@@ -428,7 +427,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"shampoo",
]
):
return super().create_optimizer()
@@ -462,14 +467,110 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
self.args, "loraplus_lr_embedding", None
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
elif self.args.alternate_optimizer == "shampoo":
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdamGraftingConfig,
CommunicationDType,
DDPShampooConfig,
FSDPShampooConfig,
PrecisionConfig,
SGDGraftingConfig,
)
from distributed_shampoo.utils.shampoo_fsdp_utils import (
compile_fsdp_parameter_metadata,
)
# parse args.optim_args
optim_args = {}
if self.args.optim_args:
for mapping in self.args.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optim_args[key] = value
optim_args["betas"] = self.args.optim_shampoo_betas
if "max_preconditioner_dim" in optim_args:
optim_args["max_preconditioner_dim"] = int(
optim_args["max_preconditioner_dim"]
)
if "precondition_frequency" in optim_args:
optim_args["precondition_frequency"] = int(
optim_args["precondition_frequency"]
)
if "use_decoupled_weight_decay" in optim_args:
optim_args["use_decoupled_weight_decay"] = bool(
optim_args["use_decoupled_weight_decay"]
)
if isinstance(optim_args["epsilon"], str):
optim_args["epsilon"] = float(optim_args["epsilon"])
optim_args["lr"] = self.args.learning_rate
optim_args["weight_decay"] = self.args.weight_decay
if "epsilon" in self.args.optim_shampoo_grafting_config_kwargs:
if isinstance(
self.args.optim_shampoo_grafting_config_kwargs["epsilon"], str
):
self.args.optim_shampoo_grafting_config_kwargs[
"epsilon"
] = float(
self.args.optim_shampoo_grafting_config_kwargs["epsilon"]
)
if self.args.optim_shampoo_grafting_config_type == "adam":
grafting_config = AdamGraftingConfig(
**self.args.optim_shampoo_grafting_config_kwargs
)
elif self.args.optim_shampoo_grafting_config_type == "sgd":
grafting_config = SGDGraftingConfig(
**self.args.optim_shampoo_grafting_config_kwargs
)
elif self.args.optim_shampoo_grafting_config_type == "adagrad":
grafting_config = AdaGradGraftingConfig(
**self.args.optim_shampoo_grafting_config_kwargs
)
distributed_config = None
if self.args.world_size > 1:
if self.args.fsdp and self.args.fsdp_config:
distributed_config = FSDPShampooConfig(
param_to_metadata=compile_fsdp_parameter_metadata(
self.model_wrapped
)
)
else:
distributed_config = DDPShampooConfig(
communication_dtype=CommunicationDType.BF16,
num_trainers_per_group=self.args.world_size,
communicate_params=False,
)
precision_config = None
if self.args.bf16:
precision_config = PrecisionConfig(
computation_dtype=torch.bfloat16,
factor_matrix_dtype=torch.bfloat16,
inv_factor_matrix_dtype=torch.bfloat16,
filtered_grad_dtype=torch.bfloat16,
momentum_dtype=torch.bfloat16,
grafting_state_dtype=torch.bfloat16,
)
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
DistributedShampoo(
optimizer_grouped_parameters,
grafting_config=grafting_config,
distributed_config=distributed_config,
precision_config=precision_config,
**optim_args,
)
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
@@ -876,7 +977,11 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, metrics=metrics)
try:
return super()._save_checkpoint(model, trial, metrics=metrics)
except NotImplementedError as exc:
LOG.warning(f"Failed to save checkpoint: {exc}")
return None
class AxolotlMambaTrainer(AxolotlTrainer):
@@ -975,9 +1080,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
if is_sagemaker_mp_enabled():
@@ -1049,11 +1154,10 @@ class TrainerBuilderBase(abc.ABC):
_model_ref = None
_peft_config = None
def __init__(self, cfg, model, tokenizer, processor=None):
def __init__(self, cfg, model, tokenizer):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
self.processor = processor
# in case the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
@@ -1111,12 +1215,6 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
return callbacks
@@ -1185,11 +1283,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer, self.tokenizer, "mlflow"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "comet_ml"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
@@ -1435,14 +1528,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.wandb_name:
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
training_arguments_kwargs["report_to"] = report_to
training_arguments_kwargs["run_name"] = (
@@ -1463,6 +1552,21 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
# shampoo optimizer config
if self.cfg.optim_shampoo_betas:
training_arguments_kwargs[
"optim_shampoo_betas"
] = self.cfg.optim_shampoo_betas
if self.cfg.optim_shampoo_grafting_config_type:
training_arguments_kwargs[
"optim_shampoo_grafting_config_type"
] = self.cfg.optim_shampoo_grafting_config_type
if self.cfg.optim_shampoo_grafting_config_kwargs:
training_arguments_kwargs[
"optim_shampoo_grafting_config_kwargs"
] = self.cfg.optim_shampoo_grafting_config_kwargs
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
@@ -1535,10 +1639,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = chat_templates(
self.cfg.chat_template
)
if self.cfg.rl == "orpo":
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
@@ -1551,10 +1651,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
if self.cfg.optimizer in [
# pylint: disable=duplicate-code
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"shampoo",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
@@ -1600,12 +1702,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
training_args = self.hook_post_create_training_args(training_args)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
@@ -1685,12 +1781,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
else:
collator = BatchSamplerDataCollatorForSeq2Seq
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
kwargs["processor"] = self.processor
kwargs["chat_template"] = training_args.chat_template
else:
collator = DataCollatorForSeq2Seq
collator = DataCollatorForSeq2Seq
return collator(
self.tokenizer,

View File

@@ -159,29 +159,6 @@ class BasePlugin:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
def post_train(self, cfg, model):
"""
Performs actions after training is complete.
Parameters:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Returns:
None
"""
def post_train_unload(self, cfg):
"""
Performs actions after training is complete and the model is unloaded.
Parameters:
cfg (dict): The configuration for the plugin.
Returns:
None
"""
def load_plugin(plugin_name: str) -> BasePlugin:
"""
@@ -404,17 +381,3 @@ class PluginManager:
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks
def post_train_unload(self, cfg):
"""
Calls the post_train_unload method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins:
plugin.post_train_unload(cfg)

View File

@@ -1,13 +0,0 @@
# LM Eval Harness
### Usage
```yaml
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
```

View File

@@ -1,42 +0,0 @@
"""
Module for the Plugin for LM Eval Harness
"""
import subprocess # nosec
from datetime import datetime
from axolotl.integrations.base import BasePlugin
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
class LMEvalPlugin(BasePlugin):
"""
Plugin for LM Evaluation Harness integraton with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg):
tasks = ",".join(cfg.lm_eval_tasks)
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
output_path = cfg.output_dir
output_path += "" if cfg.output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
subprocess.run( # nosec
[
"lm_eval",
"--model",
"hf",
"--model_args",
f"pretrained={cfg.output_dir}{fa2}{dtype}",
"--tasks",
tasks,
"--batch_size",
str(cfg.lm_eval_batch_size),
"--output_path",
output_path,
],
check=True,
)

View File

@@ -1,15 +0,0 @@
"""
Module for handling lm eval harness input arguments.
"""
from typing import List, Optional
from pydantic import BaseModel
class LMEvalArgs(BaseModel):
"""
Input args for lm eval harness
"""
lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8

133
src/axolotl/loraplus.py Normal file
View File

@@ -0,0 +1,133 @@
"""Module for LoRA+"""
# MIT License
#
# Copyright (c) 2024 nikhil-ghosh-berkeley
# https://github.com/nikhil-ghosh-berkeley/loraplus
import logging
from functools import reduce
from peft.tuners import lora
from torch import nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
LOG = logging.getLogger("axolotl.loraplus")
def get_module(name, opt_model):
"""
Retrieve a module from a model using its parameter name.
Args:
name (str): Full name of the parameter, typically including module path.
opt_model (torch.nn.Module): The model from which to retrieve the module.
Returns:
Module corresponding to the given name.
"""
parent_idx = 2 if "lora" in name else 1
module_names = name.split(sep=".")[:-parent_idx]
module = reduce(getattr, module_names, opt_model)
return module
def create_loraplus_optimizer(
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding=None,
):
"""
Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
Args:
opt_model (torch.nn.Module): The model for which the optimizer is being created.
optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
Returns:
An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
"""
assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
if loraplus_lr_embedding is None:
loraplus_lr_embedding = 1e-6
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
param_groups = {
"groupA": {},
"groupB": {},
"groupB_no_decay": {},
"embedding": {},
}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
module = get_module(name, opt_model)
if isinstance(module, lora.Embedding):
param_groups["embedding"][name] = param
elif "lora_B" in name or param.ndim == 1:
if name in decay_parameters:
param_groups["groupB"][name] = param
else:
param_groups["groupB_no_decay"][name] = param
else:
param_groups["groupA"][name] = param
assigned_param_groups = ""
for group, group_params in param_groups.items():
assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
LOG.info(assigned_param_groups)
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
optimizer_grouped_parameters = [
{
"params": list(param_groups["groupA"].values()),
"weight_decay": weight_decay,
"lr": lr,
},
{
"params": list(param_groups["embedding"].values()),
"weight_decay": weight_decay,
"lr": loraplus_lr_embedding,
},
{
"params": list(param_groups["groupB"].values()),
"weight_decay": weight_decay,
"lr": lr * loraplus_lr_ratio,
},
{
"params": list(param_groups["groupB_no_decay"].values()),
"weight_decay": 0.0,
"lr": lr * loraplus_lr_ratio,
},
]
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{p.data_ptr(): p.numel() for p in module.parameters()}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
return optimizer

View File

@@ -1,229 +0,0 @@
"""
Monkeypatch for Vision Llama for FA2 support
"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple
import torch
from flash_attn.flash_attn_interface import flash_attn_func
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.models.mllama.configuration_mllama import MllamaTextConfig
from transformers.models.mllama.modeling_mllama import (
MllamaTextCrossAttention,
MllamaTextSelfAttention,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import is_flash_attn_greater_or_equal_2_10
class MllamaTextCrossFlashAttention2(MllamaTextCrossAttention):
"""
Mllama flash cross-attention module. This module inherits from `MllamaTextCrossAttention` and
implements the forward pass using Flash Attention for improved performance.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check if flash attention version is greater or equal to 2.1
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[ # pylint: disable=unused-argument
torch.Tensor
] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
query_states = self.q_norm(query_states)
if cross_attention_states is not None:
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(
bsz, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
key_states = self.k_norm(key_states)
if past_key_value is not None:
key_states, value_states = past_key_value.update(
key_states,
value_states,
self.layer_idx,
{"cache_position": cache_position},
)
elif cache_position[0] != 0:
key_states, value_states = (
past_key_value.key_cache[self.layer_idx],
past_key_value.value_cache[self.layer_idx],
)
else:
raise ValueError(
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
)
# Transpose to get the expected layout for flash attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# Apply Flash Attention
dropout_rate = self.dropout if self.training else 0.0
output = flash_attn_func(
query_states,
key_states,
value_states,
dropout_p=dropout_rate,
softmax_scale=None,
causal=False,
return_attn_probs=output_attentions,
)
attn_output = output.contiguous().view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class MllamaTextSelfFlashAttention2(MllamaTextSelfAttention):
"""
Mllama flash self-attention module. This module inherits from `MllamaTextSelfAttention` and
implements the forward pass using Flash Attention for improved performance.
"""
def __init__(self, config: MllamaTextConfig, layer_idx: int, *args, **kwargs):
super().__init__(config, layer_idx, *args, **kwargs)
# Check if flash attention version is greater or equal to 2.1
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
past_key_value=None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# Transpose to get the expected layout for flash attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.dropout if self.training else 0.0
# Handle potential silent casting to float32
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = (
self.config._pre_quantization_dtype # pylint: disable=protected-access
)
else:
target_dtype = self.q_proj.weight.dtype
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=True,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def patch_mllama():
from transformers.models.mllama.modeling_mllama import (
MLLAMA_TEXT_ATTENTION_CLASSES,
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES,
MLLAMA_VISION_ATTENTION_CLASSES,
MllamaPreTrainedModel,
)
MllamaPreTrainedModel._supports_flash_attn_2 = ( # pylint: disable=protected-access
True
)
MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
"flash_attention_2"
] = MllamaTextCrossFlashAttention2
# fallback to SDPA
MLLAMA_VISION_ATTENTION_CLASSES[
"flash_attention_2"
] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]

View File

@@ -10,7 +10,6 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mllama_text_model",
"llama",
"mistral",
"mixtral",

View File

@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
def reset_optimizer(
optimizer: torch.optim.Optimizer,
*,
reset_params: List[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: List[str],
reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str],
prune_ratio: float = 0.9,
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)

View File

@@ -16,7 +16,6 @@
# This code is based off the following work:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# pylint: disable=duplicate-code
""" PyTorch StableLM Epoch model. """
import importlib
import math

View File

@@ -9,7 +9,7 @@ from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies")
def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
def load(strategy, tokenizer, cfg, ds_cfg):
try:
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
@@ -24,8 +24,6 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
if "processor" in sig.parameters:
load_kwargs["processor"] = processor
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None

View File

@@ -5,8 +5,6 @@ HF Chat Templates prompt strategy
import logging
from typing import Any, Dict, List, Optional
from transformers import ProcessorMixin
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import chat_templates
@@ -22,7 +20,6 @@ class ChatTemplatePrompter(Prompter):
def __init__(
self,
tokenizer,
processor=None,
chat_template=None,
max_length=2048,
message_field_role: str = "from",
@@ -47,12 +44,11 @@ class ChatTemplatePrompter(Prompter):
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.tokenizer = tokenizer
self.processor: ProcessorMixin = processor
self.chat_template = chat_template
self.max_length = max_length
self.drop_system_message = drop_system_message
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
def build_prompt(self, conversation, add_generation_prompt=False):
turns = [
{
"role": self.roles[t[self.message_field_role]],
@@ -65,28 +61,6 @@ class ChatTemplatePrompter(Prompter):
if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
if self.processor:
text = self.processor.apply_chat_template(
turns,
chat_template=self.chat_template,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)
batch = self.processor(
text=text,
images=images,
return_tensors="pt",
truncation=True,
max_length=self.max_length,
)
# workaround since processor works in batches instead of single examples
for k, val in batch.items():
if k in ["pixel_values"]:
batch[k] = val.tolist()
else:
batch[k] = val.squeeze().tolist()
return batch
return self.tokenizer.apply_chat_template(
turns,
truncation=True,
@@ -217,7 +191,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.roles_to_train = roles_to_train if roles_to_train is not None else []
self.train_on_eos = train_on_eos
self.images = "images"
@property
def messages(self):
@@ -236,21 +209,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
and not self.prompter.message_field_training_detail
):
turns = self.get_conversation_thread(prompt)
images = self.get_images(prompt)
prompt_ids = self.prompter.build_prompt(
turns[:-1],
add_generation_prompt=True,
images=images,
turns[:-1], add_generation_prompt=True
)
tokenized_res = self.prompter.build_prompt(turns, images=images)
tokenized_prompt = {}
if isinstance(tokenized_res, list):
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
tokenized_prompt["input_ids"] = input_ids
tokenized_prompt["attention_mask"] = [1] * len(input_ids)
else:
input_ids = tokenized_res["input_ids"]
tokenized_prompt = tokenized_res
input_ids = self.prompter.build_prompt(turns)
if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
@@ -258,9 +220,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
else:
labels = input_ids
tokenized_prompt["labels"] = labels
tokenized_prompt = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
}
return tokenized_prompt
LOG.info(self.roles_to_train)
LOG.info(self.train_on_eos)
LOG.info(self.prompter.message_field_training)
LOG.info(self.prompter.message_field_training_detail)
turns = prompt[self.messages]
input_ids = self.prompter.build_prompt(turns)
@@ -398,18 +368,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def get_conversation_thread(self, prompt):
return prompt[self.messages]
def get_images(self, prompt):
return prompt.get(self.images, None)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_role": ds_cfg.get("message_field_role", "from"),
"message_field_content": ds_cfg.get("message_field_content", "value"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail",
@@ -419,7 +386,6 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1,
"processor": processor,
}
strategy_params = {

View File

@@ -24,7 +24,7 @@ from axolotl.core.tokenizer_utils import fix_untrained_tokens
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
try:
@@ -69,9 +69,6 @@ def train(
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
processor = None
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
@@ -99,9 +96,7 @@ def train(
LOG.debug(msg)
# we wait unitl the last possible moment to setup Accelerator
Accelerator()
model, peft_config = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model.generation_config.do_sample = True
model_ref = None
@@ -127,7 +122,6 @@ def train(
eval_dataset,
(model, model_ref, peft_config),
tokenizer,
processor,
total_num_steps,
)

View File

@@ -1,12 +1,8 @@
"""
Basic utils for Axolotl
"""
import importlib.util
import importlib
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
def is_comet_available():
return importlib.util.find_spec("comet_ml") is not None

View File

@@ -29,7 +29,7 @@ from transformers import (
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils import is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -462,7 +462,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
references=[[r] for r in references],
predictions=predictions,
)
scores["eval_" + metric_name] = score
scores[metric_name] = score
return scores
def predict_with_generate():
@@ -747,15 +747,6 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri,
)
elif logger == "comet_ml" and is_comet_available():
import comet_ml
experiment = comet_ml.get_running_experiment()
if experiment:
experiment.log_table(
f"{name} - Predictions vs Ground Truth.csv",
pd.DataFrame(table_data),
)
if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)

View File

@@ -1,43 +0,0 @@
"""Comet module for trainer callbacks"""
import logging
from typing import TYPE_CHECKING
import comet_ml
from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
class SaveAxolotlConfigtoCometCallback(TrainerCallback):
"""Callback to save axolotl config to comet"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
comet_experiment = comet_ml.start(source="axolotl")
comet_experiment.log_other("Created from", "axolotl")
comet_experiment.log_asset(
self.axolotl_config_path,
file_name="axolotl-config",
)
LOG.info(
"The Axolotl config has been saved to the Comet Experiment under assets."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
return control

File diff suppressed because one or more lines are too long

View File

@@ -1,14 +1,17 @@
"""
DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Dict, Optional, Sequence, Union
import numpy as np
import torch
import transformers
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
IGNORE_INDEX = -100
@dataclass
class DataCollatorForSeq2Seq:
@@ -180,6 +183,34 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
return super().__call__(out_features, return_tensors=return_tensors)
@dataclass
class MambaDataCollator:
"""
Collator for State Space Models (Mamba)
"""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[torch.LongTensor(instance[key]) for instance in instances]
for key in ("input_ids", "labels")
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
return {
"input_ids": input_ids,
"labels": labels,
}
@dataclass
class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""

View File

@@ -1,10 +0,0 @@
"""
shared axolotl collators for multipack, mamba, multimodal
"""
from .batching import ( # noqa: F401
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
PretrainingBatchSamplerDataCollatorForSeq2Seq,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from .mamba import MambaDataCollator # noqa: F401

View File

@@ -1,4 +0,0 @@
"""
basic shared collator constants
"""
IGNORE_INDEX = -100

View File

@@ -1,38 +0,0 @@
"""
collators for Mamba
"""
from dataclasses import dataclass
from typing import Dict, Sequence
import torch
import transformers
from axolotl.utils.collators.core import IGNORE_INDEX
@dataclass
class MambaDataCollator:
"""
Collator for State Space Models (Mamba)
"""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[torch.LongTensor(instance[key]) for instance in instances]
for key in ("input_ids", "labels")
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
return {
"input_ids": input_ids,
"labels": labels,
}

View File

@@ -1,77 +0,0 @@
"""
Collators for multi-modal chat messages and packing
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from transformers import PreTrainedTokenizerBase, ProcessorMixin
from transformers.data.data_collator import DataCollatorMixin
from transformers.utils import PaddingStrategy
@dataclass
class MultiModalChatDataCollator(DataCollatorMixin):
"""
Collator for multi-modal chat messages
"""
tokenizer: PreTrainedTokenizerBase
processor: ProcessorMixin
return_tensors: str = "pt"
chat_template: Optional[str] = None
packing: bool = False
max_images: int = -1
padding: Union[bool, str, PaddingStrategy] = True
pad_to_multiple_of: Optional[int] = None
def __post_init__(self):
if self.packing:
raise ValueError("Packing is currently not supported.")
def torch_call(
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
) -> Dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.
return self.__class__.process_rows(
examples, self.processor, self.chat_template, self.max_images
)
@staticmethod
def process_rows(examples, processor, chat_template, max_images, length_only=False):
# HINT: use `_torch_collate_batch` to stack and pad tensors
# see also DataCollatorWithFlattening and DefaultDataCollator
# *** This is COPIED from the trl example sft_vlm.py code ***
# use this as a starting point
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(
example["messages"], chat_template=chat_template, tokenize=False
)
for example in examples
]
images = [example["images"] for example in examples]
if max_images > 0:
images = [img_batch[:max_images] for img_batch in images]
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100 #
# Ignore the image token index in the loss computation (model specific)
image_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.image_token
)
labels[labels == image_token_id] = -100
batch["labels"] = labels
if length_only:
return {
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
}
return batch

View File

@@ -1,93 +0,0 @@
"""Module for wandb utilities"""
import logging
import os
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.utils.comet_")
COMET_ENV_MAPPING_OVERRIDE = {
"comet_mode": "COMET_START_MODE",
"comet_online": "COMET_START_ONLINE",
}
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
"auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
"auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
"auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
"auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
"auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
"auto_log_co2": "COMET_AUTO_LOG_CO2",
"auto_metric_logging": "COMET_AUTO_LOG_METRICS",
"auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
"auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
"auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
"comet_disabled": "COMET_AUTO_LOG_DISABLE",
"display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
"distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
"log_code": "COMET_AUTO_LOG_CODE",
"log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
"log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
"log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
"log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
"log_env_host": "COMET_AUTO_LOG_ENV_HOST",
"log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
"log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
"log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
"log_graph": "COMET_AUTO_LOG_GRAPH",
"name": "COMET_START_EXPERIMENT_NAME",
"offline_directory": "COMET_OFFLINE_DIRECTORY",
"parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
"tags": "COMET_START_EXPERIMENT_TAGS",
}
def python_value_to_environ_value(python_value):
if isinstance(python_value, bool):
if python_value is True:
return "true"
return "false"
if isinstance(python_value, int):
return str(python_value)
if isinstance(python_value, list): # Comet only have one list of string parameter
return ",".join(map(str, python_value))
return python_value
def setup_comet_env_vars(cfg: DictDefault):
# TODO, we need to convert Axolotl configuration to environment variables
# as Transformers integration are call first and would create an
# Experiment first
for key in cfg.keys():
if key.startswith("comet_") and key != "comet_experiment_config":
value = cfg.get(key, "")
if value is not None and value != "":
env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
final_value = python_value_to_environ_value(value)
os.environ[env_variable_name] = final_value
if cfg.comet_experiment_config:
for key, value in cfg.comet_experiment_config.items():
if value is not None and value != "":
config_env_variable_name = (
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
)
if config_env_variable_name is None:
LOG.warning(
f"Unknown Comet Experiment Config name {key}, ignoring it"
)
continue
final_value = python_value_to_environ_value(value)
os.environ[config_env_variable_name] = final_value
# Enable comet if project name is present
if cfg.comet_project_name and len(cfg.comet_project_name) > 0:
cfg.use_comet = True

View File

@@ -121,36 +121,15 @@ def normalize_config(cfg):
cfg.base_model_config = cfg.base_model
model_config = load_model_config(cfg)
cfg.model_config_type = model_config.model_type
cfg.tokenizer_config = (
cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
)
cfg.is_multimodal = (
hasattr(model_config, "model_type")
and model_config.model_type in ["llava", "mllama"]
or any(
multimodal_name in cfg.base_model.lower()
for multimodal_name in [
"pixtral",
]
)
or cfg.is_multimodal
)
if cfg.is_multimodal:
cfg.processor_config = (
cfg.processor_config or cfg.base_model_config or cfg.base_model
)
model_config = model_config.text_config
cfg.model_config_type = model_config.model_type
# figure out if the model is llama
cfg.is_llama_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type == ["llama", "mllama_text_model"]
)
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())

View File

@@ -188,7 +188,6 @@ class ChatTemplate(str, Enum):
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
@@ -229,12 +228,11 @@ class LoraConfig(BaseModel):
lora_r: Optional[int] = None
lora_alpha: Optional[int] = None
lora_fan_in_fan_out: Optional[bool] = None
lora_target_modules: Optional[Union[str, List[str]]] = None
lora_target_modules: Optional[List[str]] = None
lora_target_linear: Optional[bool] = None
lora_modules_to_save: Optional[List[str]] = None
lora_dropout: Optional[float] = 0.0
peft_layers_to_transform: Optional[List[int]] = None
peft_layers_pattern: Optional[List[str]] = None
peft: Optional[PeftConfig] = None
peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None
@@ -300,13 +298,6 @@ class LoraConfig(BaseModel):
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self
@field_validator("loraplus_lr_embedding")
@classmethod
def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
loraplus_lr_embedding = float(loraplus_lr_embedding)
return loraplus_lr_embedding
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""
@@ -330,9 +321,6 @@ class ModelInputConfig(BaseModel):
tokenizer_type: Optional[str] = Field(
default=None, metadata={"help": "transformers tokenizer class"}
)
processor_type: Optional[str] = Field(
default=None, metadata={"help": "transformers processor class"}
)
trust_remote_code: Optional[bool] = None
model_kwargs: Optional[Dict[str, Any]] = None
@@ -384,6 +372,7 @@ class HyperparametersConfig(BaseModel):
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"shampoo",
],
]
] = OptimizerNames.ADAMW_HF.value
@@ -396,6 +385,12 @@ class HyperparametersConfig(BaseModel):
"help": "The target modules to optimize, i.e. the module names that you would like to train."
},
)
optim_shampoo_grafting_config_type: Optional[
Literal["adam", "sgd", "adagrad"]
] = None
optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
optim_shampoo_betas: Optional[Tuple[float, float]] = None
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
@@ -489,19 +484,6 @@ class WandbConfig(BaseModel):
return data
class CometConfig(BaseModel):
"""Comet configuration subset"""
use_comet: Optional[bool] = None
comet_api_key: Optional[str] = None
comet_workspace: Optional[str] = None
comet_project_name: Optional[str] = None
comet_experiment_key: Optional[str] = None
comet_mode: Optional[str] = None
comet_online: Optional[bool] = None
comet_experiment_config: Optional[Dict[str, Any]] = None
class GradioConfig(BaseModel):
"""Gradio configuration subset"""
@@ -522,7 +504,6 @@ class AxolotlInputConfig(
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
CometConfig,
LISAConfig,
GradioConfig,
RemappedParameters,
@@ -549,7 +530,6 @@ class AxolotlInputConfig(
dataset_prepared_path: Optional[str] = None
dataset_shard_num: Optional[int] = None
dataset_shard_idx: Optional[int] = None
skip_prepare_dataset: Optional[bool] = False
pretraining_dataset: Optional[ # type: ignore
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
@@ -980,26 +960,6 @@ class AxolotlInputConfig(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if data.get("do_bench_eval") and not (
data.get("evals_per_epoch") or data.get("eval_steps")
):
raise ValueError(
"do_bench_eval requires evals_per_epoch or eval_steps to be set."
)
return data
@model_validator(mode="before")
@classmethod
def check_test_datasets_bench(cls, data):
if (
data.get("do_bench_eval")
and not data.get("test_datasets")
and not data.get("val_set_size")
):
LOG.warning(
"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
)
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
return data
@model_validator(mode="before")
@@ -1037,18 +997,6 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_mm_prepare(cls, data):
if data.get("skip_prepare_dataset"):
if data.get("remove_unused_columns") is None:
LOG.info(
"setting `remove_unused_columns: false` for skip_prepare_dataset"
)
data["remove_unused_columns"] = False
return data
@model_validator(mode="before")
@classmethod
def check_warmup(cls, data):
@@ -1076,20 +1024,12 @@ class AxolotlInputConfig(
return neftune_noise_alpha
@model_validator(mode="after")
def check_rl_beta(self):
def check(self):
if self.dpo_beta and not self.rl_beta:
self.rl_beta = self.dpo_beta
del self.dpo_beta
return self
@model_validator(mode="after")
def check_simpo_warmup(self):
if self.rl == "simpo" and self.warmup_ratio:
raise ValueError(
"warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
)
return self
@model_validator(mode="before")
@classmethod
def check_frozen(cls, data):
@@ -1104,15 +1044,6 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_peft_layers_pattern(cls, data):
if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"):
raise ValueError(
"peft_layers_pattern requires peft_layers_to_transform to be set"
)
return data
@model_validator(mode="after")
def check_fft_possible_bad_config(self):
if (

View File

@@ -51,31 +51,20 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
def prepare_dataset(cfg, tokenizer, processor=None):
def prepare_dataset(cfg, tokenizer):
prompters = []
if not cfg.pretraining_dataset:
with zero_first(is_local_main_process()):
if cfg.test_datasets:
train_dataset, _, prompters = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
split="train",
processor=processor,
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
)
_, eval_dataset, _ = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
split="test",
processor=processor,
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
)
else:
train_dataset, eval_dataset, prompters = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
processor=processor,
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
path = cfg.pretraining_dataset
@@ -134,7 +123,6 @@ def load_tokenized_prepared_datasets(
cfg,
default_dataset_prepared_path,
split="train",
processor=None,
) -> Tuple[DatasetDict, List[Prompter]]:
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
tokenizer_name = cfg.tokenizer_config
@@ -192,7 +180,6 @@ def load_tokenized_prepared_datasets(
cfg.dataset_prepared_path
and any(prepared_ds_path.glob("*"))
and not cfg.is_preprocess
and not cfg.skip_prepare_dataset
):
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))
@@ -242,7 +229,6 @@ def load_tokenized_prepared_datasets(
name=config_dataset.name,
streaming=True,
token=use_auth_token,
revision=config_dataset.revision,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -347,7 +333,6 @@ def load_tokenized_prepared_datasets(
streaming=False,
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
@@ -382,7 +367,6 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
@@ -392,7 +376,6 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
@@ -437,19 +420,15 @@ def load_tokenized_prepared_datasets(
config_dataset=config_dataset,
tokenizer=tokenizer,
cfg=cfg,
d_base_type=d_base_type,
dataset=ds,
d_base_type=d_base_type,
d_prompt_style=d_prompt_style,
processor=processor,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
if len(datasets) == 1:
dataset = datasets[0]
else:
LOG.info("merging datasets")
dataset = concatenate_datasets(datasets)
LOG.info("merging datasets")
dataset = concatenate_datasets(datasets)
if len(datasets) > 1:
if cfg.shuffle_merged_datasets:
@@ -458,10 +437,9 @@ def load_tokenized_prepared_datasets(
else:
LOG.debug("NOT shuffling merged datasets")
if not cfg.skip_prepare_dataset:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub:
@@ -500,14 +478,9 @@ def load_prepare_datasets(
cfg,
default_dataset_prepared_path,
split="train",
processor=None,
) -> Tuple[Dataset, Dataset, List[Prompter]]:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer,
cfg,
default_dataset_prepared_path,
split=split,
processor=processor,
tokenizer, cfg, default_dataset_prepared_path, split=split
)
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
@@ -573,7 +546,6 @@ def get_dataset_wrapper(
d_base_type,
dataset,
d_prompt_style=None,
processor=None,
):
dataset_wrapper = None
dataset_prompter = None
@@ -606,11 +578,7 @@ def get_dataset_wrapper(
dataset,
**ds_kwargs,
)
elif cfg.skip_prepare_dataset:
dataset_wrapper = dataset
elif ds_strategy := load(
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
):
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
ds_strategy,

View File

@@ -28,17 +28,12 @@ from transformers import ( # noqa: F401
AddedToken,
AutoConfig,
AutoModelForCausalLM,
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer,
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
LlavaForConditionalGeneration,
MllamaForConditionalGeneration,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
@@ -85,9 +80,6 @@ def get_module_class_from_name(module, name):
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
if cfg.is_multimodal:
model_config = model_config.text_config
quant_config_exists = (
hasattr(model_config, "quantization_config")
and model_config.quantization_config
@@ -307,31 +299,11 @@ def load_tokenizer(cfg):
return tokenizer
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
processor_kwargs: Dict[str, Any] = {} # do we actually need this?
processor_cls = AutoProcessor
if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type)
processor = processor_cls.from_pretrained(
cfg.processor_config,
trust_remote_code=cfg.trust_remote_code or False,
tokenizer=tokenizer,
**processor_kwargs,
)
return processor
def load_model(
cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase,
*,
processor: ProcessorMixin = None, # pylint: disable=unused-argument
inference: bool = False,
reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
"""
Load a model for a given configuration and tokenizer.
@@ -347,23 +319,12 @@ def load_model(
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(cfg)
if cfg.is_multimodal:
text_model_config = model_config.text_config
else:
text_model_config = model_config
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
if cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if hasattr(model_config, "model_type") and model_config.model_type == "mllama":
if cfg.flash_attention:
from axolotl.monkeypatch.attention.mllama import patch_mllama
patch_mllama()
if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
if cfg.flash_attention:
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
@@ -500,19 +461,6 @@ def load_model(
max_memory = cfg.max_memory
device_map = cfg.device_map
AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
if cfg.is_multimodal:
if model_config.model_type == "llava":
AutoModelLoader = ( # pylint: disable=invalid-name
LlavaForConditionalGeneration
)
elif model_config.model_type == "mllama":
AutoModelLoader = ( # pylint: disable=invalid-name
MllamaForConditionalGeneration
)
else:
AutoModelLoader = AutoModelForVision2Seq # pylint: disable=invalid-name
if cfg.gpu_memory_limit:
gpu_memory_limit = (
str(cfg.gpu_memory_limit) + "GiB"
@@ -530,7 +478,7 @@ def load_model(
from accelerate import infer_auto_device_map
with init_empty_weights():
model_canvas = AutoModelLoader.from_config(
model_canvas = AutoModelForCausalLM.from_config(
model_config, trust_remote_code=cfg.trust_remote_code or False
)
model_canvas.tie_weights()
@@ -685,8 +633,6 @@ def load_model(
quantization_config = (
quantization_config or model_kwargs["quantization_config"]
)
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = load_sharded_model_quant(
base_model,
model_config,
@@ -705,9 +651,7 @@ def load_model(
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = AutoModelLoader.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
**model_kwargs,
@@ -746,17 +690,13 @@ def load_model(
and not cfg.trust_remote_code
):
if cfg.gptq:
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = AutoModelLoader.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = getattr(transformers, model_type).from_pretrained(
base_model,
config=model_config,
@@ -767,23 +707,21 @@ def load_model(
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
if (
hasattr(text_model_config, "max_seq_len")
and text_model_config.max_seq_len
hasattr(model_config, "max_seq_len")
and model_config.max_seq_len
and cfg.sequence_len > model_config.max_seq_len
):
text_model_config.max_seq_len = cfg.sequence_len
model_config.max_seq_len = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
elif (
hasattr(text_model_config, "max_sequence_length")
and text_model_config.max_sequence_length
and cfg.sequence_len > text_model_config.max_sequence_length
hasattr(model_config, "max_sequence_length")
and model_config.max_sequence_length
and cfg.sequence_len > model_config.max_sequence_length
):
text_model_config.max_sequence_length = cfg.sequence_len
model_config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
if cfg.gptq:
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = AutoModelLoader.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
@@ -796,9 +734,7 @@ def load_model(
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = AutoModelLoader.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
@@ -1080,17 +1016,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
from peft import LoraConfig, get_peft_model
lora_target_modules = cfg.lora_target_modules or []
lora_target_modules = list(cfg.lora_target_modules or [])
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
lora_target_modules_as_list = (
lora_target_modules
if isinstance(lora_target_modules, list)
else [lora_target_modules]
)
lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
lora_target_modules = list(set(lora_target_modules + linear_names))
lora_config_kwargs = {}
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
@@ -1109,7 +1040,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
layers_to_transform=cfg.peft_layers_to_transform,
layers_pattern=cfg.peft_layers_pattern,
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,

View File

@@ -306,7 +306,7 @@ def process_pretraining_datasets_for_packing(
def calculate_total_num_steps(cfg, train_dataset, update=True):
if not cfg.total_num_tokens and not cfg.skip_prepare_dataset:
if not cfg.total_num_tokens:
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
@@ -319,11 +319,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
skip_estimates = cfg.model_config_type == "mamba"
if (
not skip_estimates
and not cfg.total_supervised_tokens
and not cfg.skip_prepare_dataset
):
if not skip_estimates and not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
@@ -482,15 +478,13 @@ def prepare_opinionated_env(cfg):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]
else:
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset

View File

@@ -73,7 +73,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=chat_templates("llama3"),
chat_templates("llama3"),
message_field_role="role",
message_field_content="content",
roles={
@@ -113,7 +113,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
phi35_tokenizer,
chat_template=chat_templates("phi_35"),
chat_templates("phi_35"),
message_field_role="role",
message_field_content="content",
roles={
@@ -171,7 +171,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=chat_templates("llama3"),
chat_templates("llama3"),
message_field_role="role",
message_field_content="content",
message_field_training="training",
@@ -227,11 +227,8 @@ class TestSharegptChatTemplateLlama3:
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
),
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
@@ -280,11 +277,8 @@ class TestSharegptChatTemplateLlama3:
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
),
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
@@ -333,11 +327,8 @@ class TestSharegptChatTemplateLlama3:
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
),
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",

View File

@@ -34,9 +34,7 @@ class TestChatTemplateConfigurations:
def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_inputs=True")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
),
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
sequence_len=512,
@@ -79,9 +77,7 @@ class TestChatTemplateConfigurations:
def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_inputs=False")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
),
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
@@ -122,9 +118,7 @@ class TestChatTemplateConfigurations:
def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing roles_to_train with assistant only")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
),
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
@@ -150,9 +144,7 @@ class TestChatTemplateConfigurations:
def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing roles_to_train with all roles")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
),
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
sequence_len=512,
@@ -183,9 +175,7 @@ class TestChatTemplateConfigurations:
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with empty roles_to_train")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
),
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
@@ -204,9 +194,7 @@ class TestChatTemplateConfigurations:
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='all'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
),
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
@@ -231,9 +219,7 @@ class TestChatTemplateConfigurations:
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='turn'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
),
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
@@ -281,9 +267,7 @@ class TestChatTemplateConfigurations:
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='last'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
),
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
@@ -314,9 +298,7 @@ class TestChatTemplateConfigurations:
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='none'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
),
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
@@ -342,9 +324,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with drop_system_message=True")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=chat_templates("llama3"),
drop_system_message=True,
llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -370,9 +350,7 @@ class TestChatTemplateConfigurations:
}
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=chat_templates("llama3"),
roles=custom_roles,
llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -424,7 +402,7 @@ class TestChatTemplateConfigurations:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=chat_templates("llama3"),
chat_templates("llama3"),
message_field_training="train",
message_field_training_detail="train_detail",
),

View File

@@ -267,74 +267,6 @@ class TestDatasetPreparation(unittest.TestCase):
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
def test_load_hub_with_revision(self):
"""Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
def test_load_local_hub_with_revision(self):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
if __name__ == "__main__":
unittest.main()

View File

@@ -9,7 +9,6 @@ from typing import Optional
import pytest
from pydantic import ValidationError
from axolotl.utils import is_comet_available
from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
from axolotl.utils.dict import DictDefault
@@ -1330,105 +1329,3 @@ class TestValidationWandb(BaseValidation):
os.environ.pop("WANDB_PROJECT", None)
os.environ.pop("WANDB_DISABLED", None)
@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed")
class TestValidationComet(BaseValidation):
"""
Validation test for comet
"""
def test_comet_sets_env(self, minimal_cfg):
from axolotl.utils.comet_ import setup_comet_env_vars
comet_config = {
"comet_api_key": "foo",
"comet_workspace": "some_workspace",
"comet_project_name": "some_project",
"comet_experiment_key": "some_experiment_key",
"comet_mode": "get_or_create",
"comet_online": False,
"comet_experiment_config": {
"auto_histogram_activation_logging": False,
"auto_histogram_epoch_rate": 2,
"auto_histogram_gradient_logging": True,
"auto_histogram_tensorboard_logging": False,
"auto_histogram_weight_logging": True,
"auto_log_co2": False,
"auto_metric_logging": True,
"auto_metric_step_rate": 15,
"auto_output_logging": False,
"auto_param_logging": True,
"comet_disabled": False,
"display_summary_level": 2,
"distributed_node_identifier": "some_distributed_node_identifier",
"log_code": True,
"log_env_cpu": False,
"log_env_details": True,
"log_env_disk": False,
"log_env_gpu": True,
"log_env_host": False,
"log_env_network": True,
"log_git_metadata": False,
"log_git_patch": True,
"log_graph": False,
"name": "some_name",
"offline_directory": "some_offline_directory",
"parse_args": True,
"tags": ["tag1", "tag2"],
},
}
cfg = DictDefault(comet_config) | minimal_cfg
new_cfg = validate_config(cfg)
setup_comet_env_vars(new_cfg)
comet_env = {
key: value for key, value in os.environ.items() if key.startswith("COMET_")
}
assert (
len(comet_env)
== len(comet_config) + len(comet_config["comet_experiment_config"]) - 1
)
assert comet_env == {
"COMET_API_KEY": "foo",
"COMET_AUTO_LOG_CLI_ARGUMENTS": "true",
"COMET_AUTO_LOG_CO2": "false",
"COMET_AUTO_LOG_CODE": "true",
"COMET_AUTO_LOG_DISABLE": "false",
"COMET_AUTO_LOG_ENV_CPU": "false",
"COMET_AUTO_LOG_ENV_DETAILS": "true",
"COMET_AUTO_LOG_ENV_DISK": "false",
"COMET_AUTO_LOG_ENV_GPU": "true",
"COMET_AUTO_LOG_ENV_HOST": "false",
"COMET_AUTO_LOG_ENV_NETWORK": "true",
"COMET_AUTO_LOG_GIT_METADATA": "false",
"COMET_AUTO_LOG_GIT_PATCH": "true",
"COMET_AUTO_LOG_GRAPH": "false",
"COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS": "false",
"COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE": "2",
"COMET_AUTO_LOG_HISTOGRAM_GRADIENTS": "true",
"COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD": "false",
"COMET_AUTO_LOG_HISTOGRAM_WEIGHTS": "true",
"COMET_AUTO_LOG_METRIC_STEP_RATE": "15",
"COMET_AUTO_LOG_METRICS": "true",
"COMET_AUTO_LOG_OUTPUT_LOGGER": "false",
"COMET_AUTO_LOG_PARAMETERS": "true",
"COMET_DISPLAY_SUMMARY_LEVEL": "2",
"COMET_DISTRIBUTED_NODE_IDENTIFIER": "some_distributed_node_identifier",
"COMET_EXPERIMENT_KEY": "some_experiment_key",
"COMET_OFFLINE_DIRECTORY": "some_offline_directory",
"COMET_PROJECT_NAME": "some_project",
"COMET_START_EXPERIMENT_NAME": "some_name",
"COMET_START_EXPERIMENT_TAGS": "tag1,tag2",
"COMET_START_MODE": "get_or_create",
"COMET_START_ONLINE": "false",
"COMET_WORKSPACE": "some_workspace",
}
for key in comet_env.keys():
os.environ.pop(key, None)