Compare commits
11 Commits
fix/granit
...
quantize-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5a51852af1 | ||
|
|
170322a1f0 | ||
|
|
5f5ae76213 | ||
|
|
a798975b7c | ||
|
|
d23f972602 | ||
|
|
8e41317250 | ||
|
|
9f2bb188a4 | ||
|
|
9dde9e1b71 | ||
|
|
f2474ef941 | ||
|
|
8a4bcacdb2 | ||
|
|
d2c3d5a954 |
16
.coderabbit.yaml
Normal file
16
.coderabbit.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||||
|
language: "en-US"
|
||||||
|
early_access: false
|
||||||
|
reviews:
|
||||||
|
profile: "chill"
|
||||||
|
request_changes_workflow: false
|
||||||
|
high_level_summary: true
|
||||||
|
review_status: true
|
||||||
|
collapse_walkthrough: true
|
||||||
|
poem: false
|
||||||
|
sequence_diagrams: false
|
||||||
|
auto_review:
|
||||||
|
enabled: true
|
||||||
|
drafts: false
|
||||||
|
chat:
|
||||||
|
auto_reply: true
|
||||||
2
.github/workflows/main.yml
vendored
2
.github/workflows/main.yml
vendored
@@ -87,7 +87,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -98,6 +97,7 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.7.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
|
|||||||
49
.github/workflows/tests-nightly.yml
vendored
49
.github/workflows/tests-nightly.yml
vendored
@@ -106,6 +106,13 @@ jobs:
|
|||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
|
- cuda: 126
|
||||||
|
cuda_version: 12.6.3
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.7.1
|
||||||
|
num_gpus: 1
|
||||||
|
axolotl_extras:
|
||||||
|
nightly_build: "true"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -130,3 +137,45 @@ jobs:
|
|||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.e2e_tests
|
modal run cicd.e2e_tests
|
||||||
|
docker-e2e-multigpu-tests:
|
||||||
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
|
runs-on: [self-hosted, modal]
|
||||||
|
timeout-minutes: 120
|
||||||
|
needs: [pre-commit, pytest, docker-e2e-tests]
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- cuda: 126
|
||||||
|
cuda_version: 12.6.3
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.7.1
|
||||||
|
num_gpus: 2
|
||||||
|
axolotl_extras:
|
||||||
|
nightly_build: "true"
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Install Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
- name: Install Modal
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install modal==1.0.2 jinja2
|
||||||
|
- name: Update env vars
|
||||||
|
run: |
|
||||||
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
|
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||||
|
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||||
|
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||||
|
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||||
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
|
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||||
|
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||||
|
- name: Run tests job on Modal
|
||||||
|
run: |
|
||||||
|
modal run cicd.multigpu
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ coverage:
|
|||||||
only_pulls: true
|
only_pulls: true
|
||||||
flags: null
|
flags: null
|
||||||
paths: null
|
paths: null
|
||||||
|
informational: true
|
||||||
patch:
|
patch:
|
||||||
default:
|
default:
|
||||||
# basic
|
# basic
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ huggingface_hub>=0.33.0
|
|||||||
peft==0.16.0
|
peft==0.16.0
|
||||||
transformers==4.53.2
|
transformers==4.53.2
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.8.1
|
accelerate==1.9.0
|
||||||
datasets==4.0.0
|
datasets==4.0.0
|
||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.19.1
|
trl==0.19.1
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ def do_quantize(
|
|||||||
"No quantization configuration found. Please specify either qat or quantization in your config file."
|
"No quantization configuration found. Please specify either qat or quantization in your config file."
|
||||||
)
|
)
|
||||||
|
|
||||||
model_path = cli_args.get("model_path") or cfg.output_dir
|
model_path = cli_args.get("base_model") or cfg.output_dir
|
||||||
if weight_dtype := cli_args.get("weight_dtype"):
|
if weight_dtype := cli_args.get("weight_dtype"):
|
||||||
weight_dtype = TorchIntDType[weight_dtype]
|
weight_dtype = TorchIntDType[weight_dtype]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
chat dataset module
|
chat dataset module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Callable, Optional, Union
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
@@ -41,14 +40,10 @@ class TokenizedChatDataset(Dataset):
|
|||||||
)
|
)
|
||||||
return ex.tokenized(model_transform)
|
return ex.tokenized(model_transform)
|
||||||
|
|
||||||
process_or_cpu_count: int = (
|
|
||||||
process_count or os.cpu_count() # type: ignore[assignment]
|
|
||||||
)
|
|
||||||
num_proc = min(32, process_or_cpu_count)
|
|
||||||
features = data.features.keys()
|
features = data.features.keys()
|
||||||
tokenized_data = data.map(
|
tokenized_data = data.map(
|
||||||
map_fn,
|
map_fn,
|
||||||
num_proc=num_proc,
|
num_proc=process_count,
|
||||||
keep_in_memory=keep_in_memory,
|
keep_in_memory=keep_in_memory,
|
||||||
remove_columns=features,
|
remove_columns=features,
|
||||||
desc="Tokenizing Chats",
|
desc="Tokenizing Chats",
|
||||||
|
|||||||
@@ -148,7 +148,7 @@ class GRPOStrategy:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_blocklist_args_kwargs(cls) -> list[str]:
|
def get_blocklist_args_kwargs(cls) -> list[str]:
|
||||||
return ["dataset_num_proc", "max_length"]
|
return ["dataset_num_proc", "max_length", "include_tokens_per_second"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
|
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
"""Module containing Dataset functionality"""
|
"""Module containing Dataset functionality"""
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
|
|
||||||
@@ -46,7 +44,6 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
|
|
||||||
def process(self, dataset):
|
def process(self, dataset):
|
||||||
features = dataset.features.keys()
|
features = dataset.features.keys()
|
||||||
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
|
||||||
|
|
||||||
map_kwargs = {}
|
map_kwargs = {}
|
||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
@@ -59,13 +56,13 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
):
|
):
|
||||||
dataset = dataset.filter(
|
dataset = dataset.filter(
|
||||||
self.prompt_tokenizer.filter_rows,
|
self.prompt_tokenizer.filter_rows,
|
||||||
num_proc=num_proc,
|
num_proc=self.process_count,
|
||||||
desc="Strategy Filtering Rows",
|
desc="Strategy Filtering Rows",
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataset.map(
|
return dataset.map(
|
||||||
self.prompt_tokenizer.tokenize_prompt,
|
self.prompt_tokenizer.tokenize_prompt,
|
||||||
num_proc=num_proc,
|
num_proc=self.process_count,
|
||||||
remove_columns=features,
|
remove_columns=features,
|
||||||
keep_in_memory=self.keep_in_memory,
|
keep_in_memory=self.keep_in_memory,
|
||||||
desc="Tokenizing Prompts",
|
desc="Tokenizing Prompts",
|
||||||
|
|||||||
@@ -41,3 +41,13 @@ class CutCrossEntropyArgs(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_chunked_cross_entropy_not_set(cls, data):
|
||||||
|
if data.get("chunked_cross_entropy"):
|
||||||
|
raise ValueError(
|
||||||
|
"Cut Cross Entropy does not support chunked cross entropy. "
|
||||||
|
"Please set `chunked_cross_entropy` to `False` or disable Cut Cross Entropy."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|||||||
0
src/axolotl/loaders/adapters/__init__.py
Normal file
0
src/axolotl/loaders/adapters/__init__.py
Normal file
@@ -163,15 +163,6 @@ class ModelLoader:
|
|||||||
# Build the model
|
# Build the model
|
||||||
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
||||||
skip_move_to_device = self._build_model()
|
skip_move_to_device = self._build_model()
|
||||||
|
|
||||||
# Check if the model is a GraniteConfig object
|
|
||||||
if hasattr(self, 'model') and self.model.__class__.__name__ == "GraniteConfig":
|
|
||||||
LOG.error("The model loaded is a GraniteConfig object, not a proper model.")
|
|
||||||
LOG.error("This is likely because the model type 'GraniteConfig' is not supported.")
|
|
||||||
LOG.error("Please use a different model type or ensure the model is properly configured.")
|
|
||||||
LOG.error("Setting trust_remote_code=True might help if the model requires custom code.")
|
|
||||||
raise ValueError("Model loaded is a GraniteConfig object, not a proper model. Use a supported model type or set trust_remote_code=True.")
|
|
||||||
|
|
||||||
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
||||||
|
|
||||||
# Post-build model configuration
|
# Post-build model configuration
|
||||||
@@ -225,27 +216,15 @@ class ModelLoader:
|
|||||||
|
|
||||||
def _resize_token_embeddings(self):
|
def _resize_token_embeddings(self):
|
||||||
"""Resize token embeddings if needed."""
|
"""Resize token embeddings if needed."""
|
||||||
# Skip if model doesn't have the necessary methods
|
|
||||||
if not hasattr(self.model, "get_input_embeddings"):
|
|
||||||
LOG.warning("Model does not have get_input_embeddings method, skipping token embedding resize")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check if get_input_embeddings returns None
|
|
||||||
input_embeddings = self.model.get_input_embeddings()
|
|
||||||
if input_embeddings is None:
|
|
||||||
LOG.warning("Model's get_input_embeddings returned None, skipping token embedding resize")
|
|
||||||
return
|
|
||||||
|
|
||||||
embeddings_len = (
|
embeddings_len = (
|
||||||
math.ceil(len(self.tokenizer) / 32) * 32
|
math.ceil(len(self.tokenizer) / 32) * 32
|
||||||
if self.cfg.resize_token_embeddings_to_32x
|
if self.cfg.resize_token_embeddings_to_32x
|
||||||
else len(self.tokenizer)
|
else len(self.tokenizer)
|
||||||
)
|
)
|
||||||
|
if hasattr(self.model, "get_input_embeddings") and (
|
||||||
if hasattr(input_embeddings, "num_embeddings") and (
|
self.model.get_input_embeddings().num_embeddings < embeddings_len
|
||||||
input_embeddings.num_embeddings < embeddings_len
|
|
||||||
or (
|
or (
|
||||||
input_embeddings.num_embeddings > embeddings_len
|
self.model.get_input_embeddings().num_embeddings > embeddings_len
|
||||||
and self.cfg.shrink_embeddings
|
and self.cfg.shrink_embeddings
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
@@ -254,24 +233,14 @@ class ModelLoader:
|
|||||||
self.model_config.model_type != "llava"
|
self.model_config.model_type != "llava"
|
||||||
):
|
):
|
||||||
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
|
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
|
||||||
|
|
||||||
if hasattr(self.model, "resize_token_embeddings"):
|
|
||||||
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
|
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
|
||||||
else:
|
else:
|
||||||
LOG.warning("Model does not have resize_token_embeddings method, skipping resize")
|
|
||||||
else:
|
|
||||||
if hasattr(self.model, "tie_weights"):
|
|
||||||
self.model.tie_weights()
|
self.model.tie_weights()
|
||||||
|
|
||||||
def _adjust_model_config(self):
|
def _adjust_model_config(self):
|
||||||
# Skip if model doesn't have config attribute
|
|
||||||
if not hasattr(self.model, "config"):
|
|
||||||
LOG.warning("Model does not have config attribute, skipping model config adjustments")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Handle max_position_embeddings
|
|
||||||
if (
|
if (
|
||||||
hasattr(self.model.config, "max_position_embeddings")
|
hasattr(self.model, "config")
|
||||||
|
and hasattr(self.model.config, "max_position_embeddings")
|
||||||
and self.model.config.max_position_embeddings
|
and self.model.config.max_position_embeddings
|
||||||
and self.cfg.sequence_len > self.model.config.max_position_embeddings
|
and self.cfg.sequence_len > self.model.config.max_position_embeddings
|
||||||
):
|
):
|
||||||
@@ -281,17 +250,17 @@ class ModelLoader:
|
|||||||
)
|
)
|
||||||
self.model.config.max_position_embeddings = self.cfg.sequence_len
|
self.model.config.max_position_embeddings = self.cfg.sequence_len
|
||||||
|
|
||||||
# Handle bos_token_id
|
|
||||||
if (
|
if (
|
||||||
hasattr(self.model.config, "bos_token_id")
|
hasattr(self.model, "config")
|
||||||
|
and hasattr(self.model.config, "bos_token_id")
|
||||||
and self.model.config.bos_token_id
|
and self.model.config.bos_token_id
|
||||||
and self.model.config.bos_token_id != self.tokenizer.bos_token_id
|
and self.model.config.bos_token_id != self.tokenizer.bos_token_id
|
||||||
):
|
):
|
||||||
self.model.config.bos_token_id = self.tokenizer.bos_token_id
|
self.model.config.bos_token_id = self.tokenizer.bos_token_id
|
||||||
|
|
||||||
# Handle eos_token_id
|
|
||||||
if (
|
if (
|
||||||
hasattr(self.model.config, "eos_token_id")
|
hasattr(self.model, "config")
|
||||||
|
and hasattr(self.model.config, "eos_token_id")
|
||||||
and self.model.config.eos_token_id
|
and self.model.config.eos_token_id
|
||||||
and self.model.config.eos_token_id != self.tokenizer.eos_token_id
|
and self.model.config.eos_token_id != self.tokenizer.eos_token_id
|
||||||
):
|
):
|
||||||
@@ -323,12 +292,9 @@ class ModelLoader:
|
|||||||
if self.cfg.adapter in ["lora", "qlora"]:
|
if self.cfg.adapter in ["lora", "qlora"]:
|
||||||
needs_fa2_dtype = True
|
needs_fa2_dtype = True
|
||||||
if self.cfg.gradient_checkpointing:
|
if self.cfg.gradient_checkpointing:
|
||||||
if hasattr(self.model, "gradient_checkpointing_enable"):
|
|
||||||
self.model.gradient_checkpointing_enable(
|
self.model.gradient_checkpointing_enable(
|
||||||
gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
|
gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
LOG.warning("Model does not have gradient_checkpointing_enable method, skipping gradient checkpointing")
|
|
||||||
|
|
||||||
self._prepare_model_for_quantization()
|
self._prepare_model_for_quantization()
|
||||||
|
|
||||||
@@ -405,14 +371,11 @@ class ModelLoader:
|
|||||||
self.model.is_parallelizable = True
|
self.model.is_parallelizable = True
|
||||||
self.model.model_parallel = True
|
self.model.model_parallel = True
|
||||||
|
|
||||||
if hasattr(self.model, "named_parameters"):
|
|
||||||
if not any(
|
if not any(
|
||||||
param.requires_grad
|
param.requires_grad
|
||||||
for _, param in self.model.named_parameters(recurse=True)
|
for _, param in self.model.named_parameters(recurse=True)
|
||||||
):
|
):
|
||||||
LOG.warning("There are no parameters that require gradient updates")
|
LOG.warning("There are no parameters that require gradient updates")
|
||||||
else:
|
|
||||||
LOG.warning("Model does not have named_parameters attribute, skipping gradient check")
|
|
||||||
|
|
||||||
if self.cfg.flash_optimum:
|
if self.cfg.flash_optimum:
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
@@ -420,10 +383,7 @@ class ModelLoader:
|
|||||||
self.model = BetterTransformer.transform(self.model)
|
self.model = BetterTransformer.transform(self.model)
|
||||||
|
|
||||||
if self.cfg.adapter is not None:
|
if self.cfg.adapter is not None:
|
||||||
if hasattr(self.model, "device"):
|
|
||||||
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
||||||
else:
|
|
||||||
LOG.warning("Model does not have device attribute, skipping memory usage logging")
|
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@@ -740,10 +700,6 @@ class ModelLoader:
|
|||||||
and self.model_type != "AutoModelForCausalLM"
|
and self.model_type != "AutoModelForCausalLM"
|
||||||
and not self.cfg.trust_remote_code
|
and not self.cfg.trust_remote_code
|
||||||
):
|
):
|
||||||
if self.model_type == "GraniteSpeechConfig" and not hasattr(self.model_config, 'vocab_size'):
|
|
||||||
# Set vocab_size from tokenizer or use a reasonable default
|
|
||||||
self.model_config.vocab_size = getattr(self.model_config, 'vocab_size', 50257)
|
|
||||||
|
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
@@ -751,21 +707,7 @@ class ModelLoader:
|
|||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
**self.model_kwargs,
|
**self.model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif self.model_type == "GraniteSpeechConfig":
|
|
||||||
# Use the actual model class for Granite Speech
|
|
||||||
self.model = transformers.GraniteSpeechForCausalLM.from_pretrained(
|
|
||||||
self.base_model,
|
|
||||||
config=self.model_config,
|
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
|
||||||
**self.model_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if not hasattr(self.model_config, 'vocab_size'):
|
|
||||||
LOG.warning("Model config does not have vocab_size attribute, setting to 50257")
|
|
||||||
self.model_config.vocab_size = 50257
|
|
||||||
|
|
||||||
self.model = getattr(transformers, self.model_type).from_pretrained(
|
self.model = getattr(transformers, self.model_type).from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
@@ -849,19 +791,13 @@ class ModelLoader:
|
|||||||
dest = {"dtype": dist_dtype}
|
dest = {"dtype": dist_dtype}
|
||||||
if self.cfg.lora_on_cpu:
|
if self.cfg.lora_on_cpu:
|
||||||
dest["device"] = "cpu"
|
dest["device"] = "cpu"
|
||||||
|
|
||||||
# Check if the model has named_modules attribute
|
|
||||||
if not hasattr(self.model, "named_modules"):
|
|
||||||
LOG.warning("Model does not have named_modules attribute, skipping embedding dtype conversion")
|
|
||||||
return
|
|
||||||
|
|
||||||
for name, module in self.model.named_modules():
|
for name, module in self.model.named_modules():
|
||||||
if "norm" in name:
|
if "norm" in name:
|
||||||
module.to(dist_dtype)
|
module.to(dist_dtype)
|
||||||
if before_kbit_train_or_finetune:
|
if before_kbit_train_or_finetune:
|
||||||
if name.endswith(".gate"):
|
if name.endswith(".gate"):
|
||||||
module.to(dist_dtype)
|
module.to(dist_dtype)
|
||||||
if self.model_config.model_type == "btlm" and "lm_head" in name:
|
if self.model_config.model_type == "btlm":
|
||||||
# don't upcast lm_head for btlm
|
# don't upcast lm_head for btlm
|
||||||
continue
|
continue
|
||||||
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
||||||
|
|||||||
@@ -188,7 +188,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
tokenizer.padding_side = "left"
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
# Qwen base only has single token, so we need to set the special tokens
|
# Qwen base only has single token, so we need to set the special tokens
|
||||||
if cfg.is_qwen_derived_model:
|
# the following check is for Qwen1 base models
|
||||||
|
if cfg.is_qwen_derived_model and hasattr(tokenizer, "eod_id"):
|
||||||
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
|
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
|
||||||
for attr_name in token_ids:
|
for attr_name in token_ids:
|
||||||
if getattr(tokenizer, attr_name) is None:
|
if getattr(tokenizer, attr_name) is None:
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
|||||||
"loggers": {
|
"loggers": {
|
||||||
"axolotl": {
|
"axolotl": {
|
||||||
"handlers": ["color_console"],
|
"handlers": ["color_console"],
|
||||||
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL),
|
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL).upper(),
|
||||||
"propagate": False,
|
"propagate": False,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -151,6 +151,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
|||||||
|
|
||||||
return MllamaTextSelfAttention
|
return MllamaTextSelfAttention
|
||||||
|
|
||||||
|
if model_type == "llama4":
|
||||||
|
from transformers.models.llama4.modeling_llama4 import Llama4TextAttention
|
||||||
|
|
||||||
|
return Llama4TextAttention
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Dynamically import the module and attention class
|
# Dynamically import the module and attention class
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
|
|||||||
@@ -80,15 +80,7 @@ def setup_model_and_tokenizer(
|
|||||||
|
|
||||||
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
|
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
|
||||||
model, peft_config = model_loader.load()
|
model, peft_config = model_loader.load()
|
||||||
|
if model.generation_config is not None:
|
||||||
# Check if model is actually a GraniteConfig object
|
|
||||||
if model.__class__.__name__ == "GraniteConfig":
|
|
||||||
LOG.error("The model loaded is a GraniteConfig object, not a proper model.")
|
|
||||||
LOG.error("This is likely because the model type 'GraniteConfig' is not supported.")
|
|
||||||
LOG.error("Please use a different model type or ensure the model is properly configured.")
|
|
||||||
raise ValueError("Model loaded is a GraniteConfig object, not a proper model. Use a supported model type.")
|
|
||||||
|
|
||||||
if hasattr(model, "generation_config") and model.generation_config is not None:
|
|
||||||
model.generation_config.do_sample = True
|
model.generation_config.do_sample = True
|
||||||
|
|
||||||
# Apply freezing if specified
|
# Apply freezing if specified
|
||||||
@@ -98,10 +90,7 @@ def setup_model_and_tokenizer(
|
|||||||
any(embed in param for embed in ["lm_head", "embed_tokens"])
|
any(embed in param for embed in ["lm_head", "embed_tokens"])
|
||||||
for param in cfg.unfrozen_parameters
|
for param in cfg.unfrozen_parameters
|
||||||
):
|
):
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
|
||||||
model.enable_input_require_grads()
|
model.enable_input_require_grads()
|
||||||
else:
|
|
||||||
LOG.warning("Model does not have enable_input_require_grads method, skipping")
|
|
||||||
|
|
||||||
return model, tokenizer, peft_config, processor
|
return model, tokenizer, peft_config, processor
|
||||||
|
|
||||||
@@ -257,12 +246,9 @@ def save_trained_model(
|
|||||||
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
|
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
|
||||||
|
|
||||||
# Post training module hooks
|
# Post training module hooks
|
||||||
if hasattr(model, "named_modules"):
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if hasattr(module, "_post_training"):
|
if hasattr(module, "_post_training"):
|
||||||
module._post_training(model, name) # pylint: disable=protected-access
|
module._post_training(model, name) # pylint: disable=protected-access
|
||||||
else:
|
|
||||||
LOG.warning("Model does not have named_modules attribute, skipping post training hooks")
|
|
||||||
|
|
||||||
# handle QAT
|
# handle QAT
|
||||||
if cfg.qat:
|
if cfg.qat:
|
||||||
@@ -322,17 +308,11 @@ def save_trained_model(
|
|||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
|
|
||||||
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||||
if hasattr(trainer.model, "save_pretrained"):
|
|
||||||
trainer.model.save_pretrained(
|
trainer.model.save_pretrained(
|
||||||
cfg.output_dir, safe_serialization=safe_serialization
|
cfg.output_dir, safe_serialization=safe_serialization
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
LOG.warning("Trainer model does not have save_pretrained method, skipping save")
|
|
||||||
|
|
||||||
if hasattr(model, "save_pretrained"):
|
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
else:
|
|
||||||
LOG.warning("Model does not have save_pretrained method, skipping save")
|
|
||||||
|
|
||||||
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
||||||
# TODO: add integration support so this can be implemented completely within the plugin
|
# TODO: add integration support so this can be implemented completely within the plugin
|
||||||
@@ -418,10 +398,7 @@ def save_initial_configs(
|
|||||||
tokenizer.save_pretrained(str(output_dir))
|
tokenizer.save_pretrained(str(output_dir))
|
||||||
if hasattr(model, "config"):
|
if hasattr(model, "config"):
|
||||||
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
||||||
if hasattr(model.config, "save_pretrained"):
|
|
||||||
model.config.save_pretrained(str(output_dir))
|
model.config.save_pretrained(str(output_dir))
|
||||||
else:
|
|
||||||
LOG.warning("Model config does not have save_pretrained method, skipping config save")
|
|
||||||
|
|
||||||
if processor:
|
if processor:
|
||||||
LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
|
LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
|
||||||
@@ -484,12 +461,9 @@ def handle_untrained_tokens_fix(
|
|||||||
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
|
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
if hasattr(model, "save_pretrained"):
|
|
||||||
model.save_pretrained(
|
model.save_pretrained(
|
||||||
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
LOG.warning("Model does not have save_pretrained method, skipping save")
|
|
||||||
|
|
||||||
|
|
||||||
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
||||||
|
|||||||
@@ -798,7 +798,7 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
if is_main_process():
|
if state.is_world_process_zero:
|
||||||
try:
|
try:
|
||||||
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
|
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
|
||||||
with NamedTemporaryFile(
|
with NamedTemporaryFile(
|
||||||
|
|||||||
@@ -148,8 +148,6 @@ def normalize_config(cfg):
|
|||||||
f"Invalid value for eval_steps ({eval_steps}) from evals_per_epoch and/or num_epochs. Skipping evaluations."
|
f"Invalid value for eval_steps ({eval_steps}) from evals_per_epoch and/or num_epochs. Skipping evaluations."
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
|
||||||
|
|
||||||
if not cfg.base_model_config:
|
if not cfg.base_model_config:
|
||||||
cfg.base_model_config = cfg.base_model
|
cfg.base_model_config = cfg.base_model
|
||||||
|
|
||||||
|
|||||||
@@ -410,9 +410,8 @@ def save_preprocessed_dataset(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
|
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
|
||||||
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
|
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
|
||||||
if isinstance(dataset, IterableDataset):
|
|
||||||
num_workers = cfg.dataset_processes
|
num_workers = cfg.dataset_processes
|
||||||
|
if isinstance(dataset, IterableDataset):
|
||||||
ds_from_iter = Dataset.from_generator(
|
ds_from_iter = Dataset.from_generator(
|
||||||
functools.partial(_generate_from_iterable_dataset, dataset),
|
functools.partial(_generate_from_iterable_dataset, dataset),
|
||||||
features=dataset.features,
|
features=dataset.features,
|
||||||
@@ -423,10 +422,20 @@ def save_preprocessed_dataset(
|
|||||||
"num_workers": [num_workers] * num_workers,
|
"num_workers": [num_workers] * num_workers,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ds_from_iter.save_to_disk(str(prepared_ds_path))
|
ds_from_iter.save_to_disk(
|
||||||
|
str(prepared_ds_path),
|
||||||
|
num_proc=num_workers,
|
||||||
|
max_shard_size=None,
|
||||||
|
num_shards=cfg.num_dataset_shards_to_save,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
os.makedirs(prepared_ds_path, exist_ok=True)
|
os.makedirs(prepared_ds_path, exist_ok=True)
|
||||||
dataset.save_to_disk(str(prepared_ds_path))
|
dataset.save_to_disk(
|
||||||
|
str(prepared_ds_path),
|
||||||
|
num_proc=num_workers,
|
||||||
|
max_shard_size=None,
|
||||||
|
num_shards=cfg.num_dataset_shards_to_save,
|
||||||
|
)
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Pushing merged prepared dataset to Huggingface hub at "
|
"Pushing merged prepared dataset to Huggingface hub at "
|
||||||
@@ -460,13 +469,13 @@ def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset |
|
|||||||
):
|
):
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Loading prepared dataset from disk at {prepared_ds_path}...",
|
f"Loading prepared dataset from disk at {prepared_ds_path}...",
|
||||||
main_process_only=False,
|
main_process_only=True,
|
||||||
)
|
)
|
||||||
return load_from_disk(str(prepared_ds_path))
|
return load_from_disk(str(prepared_ds_path))
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Unable to find prepared dataset in {prepared_ds_path}",
|
f"Unable to find prepared dataset in {prepared_ds_path}",
|
||||||
main_process_only=False,
|
main_process_only=True,
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from torchao.quantization.quant_api import (
|
|||||||
UIntXWeightOnlyConfig,
|
UIntXWeightOnlyConfig,
|
||||||
_is_linear,
|
_is_linear,
|
||||||
)
|
)
|
||||||
|
from transformers import TorchAoConfig
|
||||||
|
|
||||||
from axolotl.utils.schemas.enums import TorchIntDType
|
from axolotl.utils.schemas.enums import TorchIntDType
|
||||||
|
|
||||||
@@ -149,7 +150,9 @@ def quantize_model_for_ptq(
|
|||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
)
|
)
|
||||||
quantize_(model, linear_ptq_config)
|
quantize_(model, linear_ptq_config)
|
||||||
|
quantization_config = TorchAoConfig(linear_ptq_config)
|
||||||
if quantize_embedding:
|
if quantize_embedding:
|
||||||
|
quantization_config.include_input_output_embeddings = True
|
||||||
embedding_quantize_config = get_ptq_config(
|
embedding_quantize_config = get_ptq_config(
|
||||||
weight_dtype=weight_dtype,
|
weight_dtype=weight_dtype,
|
||||||
activation_dtype=None,
|
activation_dtype=None,
|
||||||
@@ -160,6 +163,7 @@ def quantize_model_for_ptq(
|
|||||||
embedding_quantize_config,
|
embedding_quantize_config,
|
||||||
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
||||||
)
|
)
|
||||||
|
model.config.quantization_config = quantization_config
|
||||||
|
|
||||||
|
|
||||||
def convert_qat_model_for_ptq(
|
def convert_qat_model_for_ptq(
|
||||||
|
|||||||
@@ -193,6 +193,12 @@ class AxolotlInputConfig(
|
|||||||
json_schema_extra={"description": "Index of shard to use for whole dataset"},
|
json_schema_extra={"description": "Index of shard to use for whole dataset"},
|
||||||
)
|
)
|
||||||
skip_prepare_dataset: bool | None = False
|
skip_prepare_dataset: bool | None = False
|
||||||
|
num_dataset_shards_to_save: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of shards to save the prepared dataset"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
pretraining_dataset: (
|
pretraining_dataset: (
|
||||||
Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None
|
Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None
|
||||||
@@ -203,11 +209,12 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
dataset_processes: int | None = Field(
|
dataset_processes: int | None = Field(
|
||||||
default=min(
|
default=None,
|
||||||
int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count()
|
|
||||||
), # type: ignore[type-var]
|
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set."
|
"description": (
|
||||||
|
"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
|
||||||
|
"For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
dataset_exact_deduplication: bool | None = Field(
|
dataset_exact_deduplication: bool | None = Field(
|
||||||
@@ -1199,3 +1206,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
data["dataloader_prefetch_factor"] = 256
|
data["dataloader_prefetch_factor"] = 256
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def default_dataset_processes(cls, data):
|
||||||
|
if data.get("dataset_processes") is None:
|
||||||
|
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
|
||||||
|
data["dataset_processes"] = int(axolotl_dataset_processes)
|
||||||
|
elif runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
|
||||||
|
data["dataset_processes"] = int(runpod_cpu_count)
|
||||||
|
else:
|
||||||
|
data["dataset_processes"] = os.cpu_count()
|
||||||
|
|
||||||
|
return data
|
||||||
|
|||||||
Reference in New Issue
Block a user