Compare commits

8 Commits

21 changed files with 154 additions and 169 deletions

View File

@@ -1,16 +0,0 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US"
early_access: false
reviews:
profile: "chill"
request_changes_workflow: false
high_level_summary: true
review_status: true
collapse_walkthrough: true
poem: false
sequence_diagrams: false
auto_review:
enabled: true
drafts: false
chat:
auto_reply: true

View File

@@ -87,6 +87,7 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -97,7 +98,6 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"

View File

@@ -106,13 +106,6 @@ jobs:
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -137,45 +130,3 @@ jobs:
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
docker-e2e-multigpu-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
needs: [pre-commit, pytest, docker-e2e-tests]
strategy:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 2
axolotl_extras:
nightly_build: "true"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.multigpu

View File

@@ -22,7 +22,6 @@ coverage:
only_pulls: true
flags: null
paths: null
informational: true
patch:
default:
# basic

View File

@@ -15,7 +15,7 @@ huggingface_hub>=0.33.0
peft==0.16.0
transformers==4.53.2
tokenizers>=0.21.1
accelerate==1.9.0
accelerate==1.8.1
datasets==4.0.0
deepspeed>=0.17.0
trl==0.19.1

View File

@@ -43,7 +43,7 @@ def do_quantize(
"No quantization configuration found. Please specify either qat or quantization in your config file."
)
model_path = cli_args.get("base_model") or cfg.output_dir
model_path = cli_args.get("model_path") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchIntDType[weight_dtype]
else:

View File

@@ -2,6 +2,7 @@
chat dataset module
"""
import os
from typing import Callable, Optional, Union
from datasets import Dataset
@@ -40,10 +41,14 @@ class TokenizedChatDataset(Dataset):
)
return ex.tokenized(model_transform)
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(32, process_or_cpu_count)
features = data.features.keys()
tokenized_data = data.map(
map_fn,
num_proc=process_count,
num_proc=num_proc,
keep_in_memory=keep_in_memory,
remove_columns=features,
desc="Tokenizing Chats",

View File

@@ -148,7 +148,7 @@ class GRPOStrategy:
@classmethod
def get_blocklist_args_kwargs(cls) -> list[str]:
return ["dataset_num_proc", "max_length", "include_tokens_per_second"]
return ["dataset_num_proc", "max_length"]
@classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:

View File

@@ -1,5 +1,7 @@
"""Module containing Dataset functionality"""
import os
import torch
from datasets import Dataset, IterableDataset
@@ -44,6 +46,7 @@ class TokenizedPromptDataset(Dataset):
def process(self, dataset):
features = dataset.features.keys()
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
@@ -56,13 +59,13 @@ class TokenizedPromptDataset(Dataset):
):
dataset = dataset.filter(
self.prompt_tokenizer.filter_rows,
num_proc=self.process_count,
num_proc=num_proc,
desc="Strategy Filtering Rows",
)
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=self.process_count,
num_proc=num_proc,
remove_columns=features,
keep_in_memory=self.keep_in_memory,
desc="Tokenizing Prompts",

View File

@@ -41,13 +41,3 @@ class CutCrossEntropyArgs(BaseModel):
)
return data
@model_validator(mode="before")
@classmethod
def check_chunked_cross_entropy_not_set(cls, data):
if data.get("chunked_cross_entropy"):
raise ValueError(
"Cut Cross Entropy does not support chunked cross entropy. "
"Please set `chunked_cross_entropy` to `False` or disable Cut Cross Entropy."
)
return data

View File

@@ -163,6 +163,15 @@ class ModelLoader:
# Build the model
PLUGIN_MANAGER.pre_model_load(self.cfg)
skip_move_to_device = self._build_model()
# Check if the model is a GraniteConfig object
if hasattr(self, 'model') and self.model.__class__.__name__ == "GraniteConfig":
LOG.error("The model loaded is a GraniteConfig object, not a proper model.")
LOG.error("This is likely because the model type 'GraniteConfig' is not supported.")
LOG.error("Please use a different model type or ensure the model is properly configured.")
LOG.error("Setting trust_remote_code=True might help if the model requires custom code.")
raise ValueError("Model loaded is a GraniteConfig object, not a proper model. Use a supported model type or set trust_remote_code=True.")
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
# Post-build model configuration
@@ -216,15 +225,27 @@ class ModelLoader:
def _resize_token_embeddings(self):
"""Resize token embeddings if needed."""
# Skip if model doesn't have the necessary methods
if not hasattr(self.model, "get_input_embeddings"):
LOG.warning("Model does not have get_input_embeddings method, skipping token embedding resize")
return
# Check if get_input_embeddings returns None
input_embeddings = self.model.get_input_embeddings()
if input_embeddings is None:
LOG.warning("Model's get_input_embeddings returned None, skipping token embedding resize")
return
embeddings_len = (
math.ceil(len(self.tokenizer) / 32) * 32
if self.cfg.resize_token_embeddings_to_32x
else len(self.tokenizer)
)
if hasattr(self.model, "get_input_embeddings") and (
self.model.get_input_embeddings().num_embeddings < embeddings_len
if hasattr(input_embeddings, "num_embeddings") and (
input_embeddings.num_embeddings < embeddings_len
or (
self.model.get_input_embeddings().num_embeddings > embeddings_len
input_embeddings.num_embeddings > embeddings_len
and self.cfg.shrink_embeddings
)
):
@@ -233,14 +254,24 @@ class ModelLoader:
self.model_config.model_type != "llava"
):
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
if hasattr(self.model, "resize_token_embeddings"):
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
else:
LOG.warning("Model does not have resize_token_embeddings method, skipping resize")
else:
self.model.tie_weights()
if hasattr(self.model, "tie_weights"):
self.model.tie_weights()
def _adjust_model_config(self):
# Skip if model doesn't have config attribute
if not hasattr(self.model, "config"):
LOG.warning("Model does not have config attribute, skipping model config adjustments")
return
# Handle max_position_embeddings
if (
hasattr(self.model, "config")
and hasattr(self.model.config, "max_position_embeddings")
hasattr(self.model.config, "max_position_embeddings")
and self.model.config.max_position_embeddings
and self.cfg.sequence_len > self.model.config.max_position_embeddings
):
@@ -250,17 +281,17 @@ class ModelLoader:
)
self.model.config.max_position_embeddings = self.cfg.sequence_len
# Handle bos_token_id
if (
hasattr(self.model, "config")
and hasattr(self.model.config, "bos_token_id")
hasattr(self.model.config, "bos_token_id")
and self.model.config.bos_token_id
and self.model.config.bos_token_id != self.tokenizer.bos_token_id
):
self.model.config.bos_token_id = self.tokenizer.bos_token_id
# Handle eos_token_id
if (
hasattr(self.model, "config")
and hasattr(self.model.config, "eos_token_id")
hasattr(self.model.config, "eos_token_id")
and self.model.config.eos_token_id
and self.model.config.eos_token_id != self.tokenizer.eos_token_id
):
@@ -292,9 +323,12 @@ class ModelLoader:
if self.cfg.adapter in ["lora", "qlora"]:
needs_fa2_dtype = True
if self.cfg.gradient_checkpointing:
self.model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
)
if hasattr(self.model, "gradient_checkpointing_enable"):
self.model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
)
else:
LOG.warning("Model does not have gradient_checkpointing_enable method, skipping gradient checkpointing")
self._prepare_model_for_quantization()
@@ -371,11 +405,14 @@ class ModelLoader:
self.model.is_parallelizable = True
self.model.model_parallel = True
if not any(
param.requires_grad
for _, param in self.model.named_parameters(recurse=True)
):
LOG.warning("There are no parameters that require gradient updates")
if hasattr(self.model, "named_parameters"):
if not any(
param.requires_grad
for _, param in self.model.named_parameters(recurse=True)
):
LOG.warning("There are no parameters that require gradient updates")
else:
LOG.warning("Model does not have named_parameters attribute, skipping gradient check")
if self.cfg.flash_optimum:
from optimum.bettertransformer import BetterTransformer
@@ -383,7 +420,10 @@ class ModelLoader:
self.model = BetterTransformer.transform(self.model)
if self.cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
if hasattr(self.model, "device"):
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
else:
LOG.warning("Model does not have device attribute, skipping memory usage logging")
for _ in range(3):
gc.collect()
@@ -700,6 +740,10 @@ class ModelLoader:
and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code
):
if self.model_type == "GraniteSpeechConfig" and not hasattr(self.model_config, 'vocab_size'):
# Set vocab_size from tokenizer or use a reasonable default
self.model_config.vocab_size = getattr(self.model_config, 'vocab_size', 50257)
if self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
@@ -707,7 +751,21 @@ class ModelLoader:
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
elif self.model_type == "GraniteSpeechConfig":
# Use the actual model class for Granite Speech
self.model = transformers.GraniteSpeechForCausalLM.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
if not hasattr(self.model_config, 'vocab_size'):
LOG.warning("Model config does not have vocab_size attribute, setting to 50257")
self.model_config.vocab_size = 50257
self.model = getattr(transformers, self.model_type).from_pretrained(
self.base_model,
config=self.model_config,
@@ -791,13 +849,19 @@ class ModelLoader:
dest = {"dtype": dist_dtype}
if self.cfg.lora_on_cpu:
dest["device"] = "cpu"
# Check if the model has named_modules attribute
if not hasattr(self.model, "named_modules"):
LOG.warning("Model does not have named_modules attribute, skipping embedding dtype conversion")
return
for name, module in self.model.named_modules():
if "norm" in name:
module.to(dist_dtype)
if before_kbit_train_or_finetune:
if name.endswith(".gate"):
module.to(dist_dtype)
if self.model_config.model_type == "btlm":
if self.model_config.model_type == "btlm" and "lm_head" in name:
# don't upcast lm_head for btlm
continue
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
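
The guards added throughout this file follow one pattern: probe for the optional attribute or method, use it when present, and otherwise log a warning and skip. A minimal sketch of that pattern (stand-in model and logger, not the ModelLoader members):

import logging

LOG = logging.getLogger(__name__)

def enable_gradient_checkpointing(model, gradient_checkpointing_kwargs=None):
    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
        )
    else:
        LOG.warning(
            "Model does not have gradient_checkpointing_enable method, "
            "skipping gradient checkpointing"
        )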

View File

@@ -188,8 +188,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
tokenizer.padding_side = "left"
# Qwen base only has single token, so we need to set the special tokens
# the following check is for Qwen1 base models
if cfg.is_qwen_derived_model and hasattr(tokenizer, "eod_id"):
if cfg.is_qwen_derived_model:
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
for attr_name in token_ids:
if getattr(tokenizer, attr_name) is None:

View File

@@ -113,7 +113,7 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"loggers": {
"axolotl": {
"handlers": ["color_console"],
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL).upper(),
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL),
"propagate": False,
},
},
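
One behavioral note on the dropped .upper(): the standard-library logging machinery only accepts level names in their canonical upper-case form, so with this change a lower-case AXOLOTL_LOG_LEVEL value would no longer be normalized. A small sketch of the underlying stdlib behavior:

import logging

logging.getLogger("demo").setLevel("DEBUG")      # accepted
try:
    logging.getLogger("demo").setLevel("debug")  # lower-case name
except ValueError as exc:
    print(exc)                                   # Unknown level: 'debug'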

View File

@@ -151,11 +151,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return MllamaTextSelfAttention
if model_type == "llama4":
from transformers.models.llama4.modeling_llama4 import Llama4TextAttention
return Llama4TextAttention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"

View File

@@ -80,7 +80,15 @@ def setup_model_and_tokenizer(
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
model, peft_config = model_loader.load()
if model.generation_config is not None:
# Check if model is actually a GraniteConfig object
if model.__class__.__name__ == "GraniteConfig":
LOG.error("The model loaded is a GraniteConfig object, not a proper model.")
LOG.error("This is likely because the model type 'GraniteConfig' is not supported.")
LOG.error("Please use a different model type or ensure the model is properly configured.")
raise ValueError("Model loaded is a GraniteConfig object, not a proper model. Use a supported model type.")
if hasattr(model, "generation_config") and model.generation_config is not None:
model.generation_config.do_sample = True
# Apply freezing if specified
@@ -90,7 +98,10 @@ def setup_model_and_tokenizer(
any(embed in param for embed in ["lm_head", "embed_tokens"])
for param in cfg.unfrozen_parameters
):
model.enable_input_require_grads()
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
LOG.warning("Model does not have enable_input_require_grads method, skipping")
return model, tokenizer, peft_config, processor
@@ -246,9 +257,12 @@ def save_trained_model(
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
# Post training module hooks
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
if hasattr(model, "named_modules"):
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
else:
LOG.warning("Model does not have named_modules attribute, skipping post training hooks")
# handle QAT
if cfg.qat:
@@ -308,11 +322,17 @@ def save_trained_model(
model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
if hasattr(trainer.model, "save_pretrained"):
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
else:
LOG.warning("Trainer model does not have save_pretrained method, skipping save")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if hasattr(model, "save_pretrained"):
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
else:
LOG.warning("Model does not have save_pretrained method, skipping save")
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin
@@ -398,7 +418,10 @@ def save_initial_configs(
tokenizer.save_pretrained(str(output_dir))
if hasattr(model, "config"):
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
model.config.save_pretrained(str(output_dir))
if hasattr(model.config, "save_pretrained"):
model.config.save_pretrained(str(output_dir))
else:
LOG.warning("Model config does not have save_pretrained method, skipping config save")
if processor:
LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
@@ -461,9 +484,12 @@ def handle_untrained_tokens_fix(
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
if hasattr(model, "save_pretrained"):
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
else:
LOG.warning("Model does not have save_pretrained method, skipping save")
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[

View File

@@ -798,7 +798,7 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if state.is_world_process_zero:
if is_main_process():
try:
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
with NamedTemporaryFile(
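
The is_main_process() call used here typically reduces to a rank-zero check; a stand-in sketch of such a helper (the real one lives in axolotl's distributed utilities and may differ):

import torch.distributed as dist

def is_main_process() -> bool:
    # Treat non-distributed runs as rank zero; otherwise only rank 0 is "main".
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0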

View File

@@ -148,6 +148,8 @@ def normalize_config(cfg):
f"Invalid value for eval_steps ({eval_steps}) from evals_per_epoch and/or num_epochs. Skipping evaluations."
)
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
if not cfg.base_model_config:
cfg.base_model_config = cfg.base_model

View File

@@ -410,8 +410,9 @@ def save_preprocessed_dataset(
) -> None:
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
num_workers = cfg.dataset_processes
if isinstance(dataset, IterableDataset):
num_workers = cfg.dataset_processes
ds_from_iter = Dataset.from_generator(
functools.partial(_generate_from_iterable_dataset, dataset),
features=dataset.features,
@@ -422,20 +423,10 @@ def save_preprocessed_dataset(
"num_workers": [num_workers] * num_workers,
},
)
ds_from_iter.save_to_disk(
str(prepared_ds_path),
num_proc=num_workers,
max_shard_size=None,
num_shards=cfg.num_dataset_shards_to_save,
)
ds_from_iter.save_to_disk(str(prepared_ds_path))
else:
os.makedirs(prepared_ds_path, exist_ok=True)
dataset.save_to_disk(
str(prepared_ds_path),
num_proc=num_workers,
max_shard_size=None,
num_shards=cfg.num_dataset_shards_to_save,
)
dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub:
LOG.info(
"Pushing merged prepared dataset to Huggingface hub at "
@@ -469,13 +460,13 @@ def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset |
):
LOG.info(
f"Loading prepared dataset from disk at {prepared_ds_path}...",
main_process_only=True,
main_process_only=False,
)
return load_from_disk(str(prepared_ds_path))
LOG.info(
f"Unable to find prepared dataset in {prepared_ds_path}",
main_process_only=True,
main_process_only=False,
)
return None
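
For reference, the two save paths touched in this hunk differ only in how sharding is controlled when writing to disk; a toy sketch with Hugging Face datasets (paths and shard counts are made up):

from datasets import Dataset

ds = Dataset.from_dict({"idx": list(range(100))})

# simplified form: let `datasets` choose the shard layout
ds.save_to_disk("/tmp/prepared")

# explicit form, as in the multi-argument calls in this hunk:
# fixed shard count and parallel writer processes
ds.save_to_disk("/tmp/prepared_sharded", max_shard_size=None, num_shards=4, num_proc=2)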

View File

@@ -20,7 +20,6 @@ from torchao.quantization.quant_api import (
UIntXWeightOnlyConfig,
_is_linear,
)
from transformers import TorchAoConfig
from axolotl.utils.schemas.enums import TorchIntDType
@@ -150,9 +149,7 @@ def quantize_model_for_ptq(
group_size=group_size,
)
quantize_(model, linear_ptq_config)
quantization_config = TorchAoConfig(linear_ptq_config)
if quantize_embedding:
quantization_config.include_input_output_embeddings = True
embedding_quantize_config = get_ptq_config(
weight_dtype=weight_dtype,
activation_dtype=None,
@@ -163,7 +160,6 @@ def quantize_model_for_ptq(
embedding_quantize_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
model.config.quantization_config = quantization_config
def convert_qat_model_for_ptq(

View File

@@ -193,12 +193,6 @@ class AxolotlInputConfig(
json_schema_extra={"description": "Index of shard to use for whole dataset"},
)
skip_prepare_dataset: bool | None = False
num_dataset_shards_to_save: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of shards to save the prepared dataset"
},
)
pretraining_dataset: (
Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None
@@ -209,12 +203,11 @@ class AxolotlInputConfig(
},
)
dataset_processes: int | None = Field(
default=None,
default=min(
int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count()
), # type: ignore[type-var]
json_schema_extra={
"description": (
"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
"For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
)
"description": "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set."
},
)
dataset_exact_deduplication: bool | None = Field(
@@ -1206,16 +1199,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
data["dataloader_prefetch_factor"] = 256
return data
@model_validator(mode="before")
@classmethod
def default_dataset_processes(cls, data):
if data.get("dataset_processes") is None:
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
data["dataset_processes"] = int(axolotl_dataset_processes)
elif runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
data["dataset_processes"] = int(runpod_cpu_count)
else:
data["dataset_processes"] = os.cpu_count()
return data
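
The net effect of this last change: the default moves from a "before" validator (re-evaluated on every config load, with a RUNPOD_CPU_COUNT fallback) to a Field default computed once when the schema module is imported. A toy sketch of that difference (not the real axolotl schema):

import os
from pydantic import BaseModel, Field, model_validator

class EagerDefault(BaseModel):
    # evaluated once, at class-definition time
    dataset_processes: int | None = Field(
        default=min(int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count() or 1)
    )

class LazyDefault(BaseModel):
    dataset_processes: int | None = None

    @model_validator(mode="before")
    @classmethod
    def default_dataset_processes(cls, data):
        # re-reads the environment every time a config is validated
        if data.get("dataset_processes") is None:
            data["dataset_processes"] = int(
                os.environ.get("AXOLOTL_DATASET_PROCESSES", os.cpu_count() or 1)
            )
        return data

os.environ["AXOLOTL_DATASET_PROCESSES"] = "4"
print(EagerDefault().dataset_processes)  # whatever was captured at import time
print(LazyDefault().dataset_processes)   # 4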