Compare commits
6 Commits
llmcompres
...
lora-quant
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a22d16842 | ||
|
|
fee3c13bb5 | ||
|
|
996fc124e5 | ||
|
|
e963990ad7 | ||
|
|
c3f2b1c5c2 | ||
|
|
6ba5c0ed2c |
2
.github/workflows/main.yml
vendored
2
.github/workflows/main.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
|||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.0
|
pytorch: 2.7.0
|
||||||
axolotl_extras: vllm
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
90
.runpod/tests.json
Normal file
90
.runpod/tests.json
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
{
|
||||||
|
"tests": [
|
||||||
|
{
|
||||||
|
"name": "quick_smoke_test_sft",
|
||||||
|
"input": {
|
||||||
|
"user_id": "user",
|
||||||
|
"model_id": "llama-test",
|
||||||
|
"run_id": "llama-test",
|
||||||
|
"credentials": {
|
||||||
|
"wandb_api_key": "",
|
||||||
|
"hf_token": ""
|
||||||
|
},
|
||||||
|
"args": {
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"model_type": "AutoModelForCausalLM",
|
||||||
|
"tokenizer_type": "AutoTokenizer",
|
||||||
|
"load_in_4bit": true,
|
||||||
|
"strict": false,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
"split": "train[:10%]"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"val_set_size": 0.02,
|
||||||
|
"output_dir": "./outputs/lora-out",
|
||||||
|
"sequence_len": 4096,
|
||||||
|
"sample_packing": true,
|
||||||
|
"eval_sample_packing": false,
|
||||||
|
"pad_to_sequence_len": true,
|
||||||
|
"adapter": "qlora",
|
||||||
|
"lora_r": 32,
|
||||||
|
"lora_alpha": 64,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": true,
|
||||||
|
"lora_modules_to_save": [
|
||||||
|
"embed_tokens",
|
||||||
|
"lm_head"
|
||||||
|
],
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"learning_rate": 0.0002,
|
||||||
|
"train_on_inputs": false,
|
||||||
|
"group_by_length": false,
|
||||||
|
"bf16": "auto",
|
||||||
|
"tf32": true,
|
||||||
|
"gradient_checkpointing": true,
|
||||||
|
"logging_steps": 1,
|
||||||
|
"flash_attention": true,
|
||||||
|
"warmup_steps": 1,
|
||||||
|
"evals_per_epoch": 1,
|
||||||
|
"eval_max_new_tokens": 128,
|
||||||
|
"saves_per_epoch": 1,
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>"
|
||||||
|
},
|
||||||
|
"max_steps": 20
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"timeout": 100000
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"config": {
|
||||||
|
"gpuTypeId": "NVIDIA GeForce RTX 4090",
|
||||||
|
"gpuCount": 1,
|
||||||
|
"containerDiskInGb": 200,
|
||||||
|
"env": [
|
||||||
|
{
|
||||||
|
"key": "TOKENIZER",
|
||||||
|
"value": ""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"key": "DISABLE_LOG_STATS",
|
||||||
|
"value": "true"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"allowedCudaVersions": [
|
||||||
|
"12.8",
|
||||||
|
"12.7",
|
||||||
|
"12.6",
|
||||||
|
"12.5",
|
||||||
|
"12.4"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -18,7 +18,7 @@ accelerate==1.6.0
|
|||||||
datasets==3.5.0
|
datasets==3.5.0
|
||||||
deepspeed>=0.15.4
|
deepspeed>=0.15.4
|
||||||
trl==0.17.0
|
trl==0.17.0
|
||||||
hf_xet==1.0.0
|
hf_xet==1.1.0
|
||||||
hqq==0.2.5
|
hqq==0.2.5
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
|
|||||||
@@ -2,4 +2,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from axolotl.logging_config import configure_logging
|
||||||
|
|
||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
|
configure_logging()
|
||||||
|
|||||||
@@ -8,9 +8,6 @@ from accelerate.commands.config import config_args
|
|||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@@ -158,7 +159,9 @@ def plugin_set_cfg(cfg: DictDefault):
|
|||||||
plugin_manager.cfg = cfg
|
plugin_manager.cfg = cfg
|
||||||
|
|
||||||
|
|
||||||
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
|
def load_cfg(
|
||||||
|
config: str | Path | DictDefault = Path("examples/"), **kwargs
|
||||||
|
) -> DictDefault:
|
||||||
"""
|
"""
|
||||||
Loads the `axolotl` configuration stored at `config`, validates it, and performs
|
Loads the `axolotl` configuration stored at `config`, validates it, and performs
|
||||||
various setup.
|
various setup.
|
||||||
@@ -170,13 +173,24 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
|
|||||||
Returns:
|
Returns:
|
||||||
`DictDefault` mapping configuration keys to values.
|
`DictDefault` mapping configuration keys to values.
|
||||||
"""
|
"""
|
||||||
config = check_remote_config(config)
|
if isinstance(config, (str, Path)):
|
||||||
if Path(config).is_dir():
|
config = check_remote_config(config)
|
||||||
config = choose_config(Path(config))
|
if Path(config).is_dir():
|
||||||
|
config = choose_config(Path(config))
|
||||||
|
|
||||||
# Load the config from the yaml file
|
# Load the config from the yaml file
|
||||||
with open(config, encoding="utf-8") as file:
|
with open(config, encoding="utf-8") as file:
|
||||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||||
|
|
||||||
|
cfg.axolotl_config_path = config
|
||||||
|
else:
|
||||||
|
cfg = config
|
||||||
|
with NamedTemporaryFile(
|
||||||
|
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||||
|
) as temp_file:
|
||||||
|
temp_file.write(yaml.dump(config.to_dict()))
|
||||||
|
temp_file.close()
|
||||||
|
cfg.axolotl_config_path = temp_file.name
|
||||||
|
|
||||||
# If there are any options passed in the cli, if it is something that seems valid
|
# If there are any options passed in the cli, if it is something that seems valid
|
||||||
# from the yaml, then overwrite the value
|
# from the yaml, then overwrite the value
|
||||||
@@ -190,8 +204,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
|
|||||||
else:
|
else:
|
||||||
cfg[k] = kwargs[k]
|
cfg[k] = kwargs[k]
|
||||||
|
|
||||||
cfg.axolotl_config_path = config
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
device_props = torch.cuda.get_device_properties("cuda")
|
device_props = torch.cuda.get_device_properties("cuda")
|
||||||
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
||||||
|
|||||||
@@ -20,11 +20,9 @@ from transformers import (
|
|||||||
ProcessorMixin,
|
ProcessorMixin,
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
|||||||
def load_datasets(
|
def load_datasets(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
|
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
|
||||||
) -> TrainDatasetMeta:
|
) -> TrainDatasetMeta:
|
||||||
"""
|
"""
|
||||||
Loads one or more training or evaluation datasets, calling
|
Loads one or more training or evaluation datasets, calling
|
||||||
@@ -64,7 +64,8 @@ def load_datasets(
|
|||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
||||||
preprocess_iterable = (
|
preprocess_iterable = (
|
||||||
hasattr(cli_args, "iterable")
|
cli_args
|
||||||
|
and hasattr(cli_args, "iterable")
|
||||||
and cli_args.iterable is not None
|
and cli_args.iterable is not None
|
||||||
and cli_args.iterable
|
and cli_args.iterable
|
||||||
)
|
)
|
||||||
@@ -76,7 +77,7 @@ def load_datasets(
|
|||||||
preprocess_iterable=preprocess_iterable,
|
preprocess_iterable=preprocess_iterable,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if cli_args and (
|
||||||
cli_args.debug
|
cli_args.debug
|
||||||
or cfg.debug
|
or cfg.debug
|
||||||
or cli_args.debug_text_only
|
or cli_args.debug_text_only
|
||||||
|
|||||||
@@ -488,7 +488,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
# these are all the "standard" kwargs that are def used
|
# these are all the "standard" kwargs that are def used
|
||||||
training_arguments_kwargs["max_steps"] = (
|
training_arguments_kwargs["max_steps"] = (
|
||||||
total_num_steps if self.cfg.max_steps else -1
|
self.cfg.max_steps if self.cfg.max_steps else -1
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
||||||
training_arguments_kwargs["per_device_train_batch_size"] = (
|
training_arguments_kwargs["per_device_train_batch_size"] = (
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ class GRPOStrategy:
|
|||||||
|
|
||||||
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
|
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
|
||||||
grpo_args_kwargs["log_completions"] = trl.log_completions
|
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||||
|
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
||||||
|
|
||||||
if trl.reward_weights:
|
if trl.reward_weights:
|
||||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from accelerate.logging import get_logger
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
from axolotl.train import (
|
from axolotl.train import (
|
||||||
TrainDatasetMeta,
|
TrainDatasetMeta,
|
||||||
setup_model_and_tokenizer,
|
setup_model_and_tokenizer,
|
||||||
@@ -24,7 +23,6 @@ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|||||||
src_dir = os.path.join(project_root, "src")
|
src_dir = os.path.join(project_root, "src")
|
||||||
sys.path.insert(0, src_dir)
|
sys.path.insert(0, src_dir)
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -55,13 +55,16 @@ def dequantize(
|
|||||||
target_device = W.device
|
target_device = W.device
|
||||||
|
|
||||||
# Extract quantization state
|
# Extract quantization state
|
||||||
|
nested = False
|
||||||
if not isinstance(quant_state, list):
|
if not isinstance(quant_state, list):
|
||||||
# New style quant_state class
|
# New style quant_state class
|
||||||
absmax = quant_state.absmax.to(target_device)
|
absmax = quant_state.absmax.to(target_device)
|
||||||
shape = quant_state.shape
|
shape = quant_state.shape
|
||||||
dtype = quant_state.dtype
|
dtype = quant_state.dtype
|
||||||
blocksize = quant_state.blocksize
|
blocksize = quant_state.blocksize
|
||||||
offset = quant_state.offset.to(target_device)
|
if quant_state.nested:
|
||||||
|
nested = True
|
||||||
|
offset = quant_state.offset.to(target_device)
|
||||||
state2 = quant_state.state2
|
state2 = quant_state.state2
|
||||||
absmax2 = state2.absmax.to(target_device)
|
absmax2 = state2.absmax.to(target_device)
|
||||||
code2 = state2.code.to(target_device)
|
code2 = state2.code.to(target_device)
|
||||||
@@ -115,7 +118,8 @@ def dequantize(
|
|||||||
ctypes.c_int(n_elements_absmax),
|
ctypes.c_int(n_elements_absmax),
|
||||||
)
|
)
|
||||||
|
|
||||||
out_absmax += offset
|
if nested:
|
||||||
|
out_absmax += offset
|
||||||
|
|
||||||
# Choose appropriate dequantization function
|
# Choose appropriate dequantization function
|
||||||
fx = (
|
fx = (
|
||||||
|
|||||||
@@ -12,10 +12,8 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
0
src/axolotl/monkeypatch/trainer/__init__.py
Normal file
0
src/axolotl/monkeypatch/trainer/__init__.py
Normal file
@@ -30,7 +30,6 @@ from axolotl.core.trainers.mixins.sequence_parallel import (
|
|||||||
SequenceParallelContextManager,
|
SequenceParallelContextManager,
|
||||||
)
|
)
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import cleanup_distributed
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
@@ -42,7 +41,6 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
BetterTransformer = None
|
BetterTransformer = None
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ def resolve_dtype(cfg):
|
|||||||
else:
|
else:
|
||||||
LOG.debug("bf16 support not detected, disabling for this configuration.")
|
LOG.debug("bf16 support not detected, disabling for this configuration.")
|
||||||
cfg.bf16 = False
|
cfg.bf16 = False
|
||||||
if cfg.fp16 is None:
|
if cfg.fp16 is None and not cfg.float16:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
|
|
||||||
if cfg.device == "mps":
|
if cfg.device == "mps":
|
||||||
|
|||||||
@@ -67,6 +67,12 @@ class TRLConfig(BaseModel):
|
|||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to log completions"},
|
json_schema_extra={"description": "Whether to log completions"},
|
||||||
)
|
)
|
||||||
|
num_completions_to_print: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged."
|
||||||
|
},
|
||||||
|
)
|
||||||
sync_ref_model: bool | None = Field(
|
sync_ref_model: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
@@ -597,6 +597,8 @@ def prepare_optim_env(cfg):
|
|||||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
|
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
|
||||||
elif cfg.fp16:
|
elif cfg.fp16:
|
||||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
|
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
|
||||||
|
else:
|
||||||
|
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
|
||||||
|
|
||||||
|
|
||||||
def prepare_opinionated_env(cfg):
|
def prepare_opinionated_env(cfg):
|
||||||
|
|||||||
Reference in New Issue
Block a user