Compare commits

...

4 Commits

Author SHA1 Message Date
Dan Saunders
156fede4f7 Update .pre-commit-config.yaml
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-03-21 10:36:18 -04:00
Dan Saunders
dcbbd7af79 sorry to revert, but pylint complained 2025-03-21 10:36:18 -04:00
Dan Saunders
21bac7ce1a running updated pre-commit plugins 2025-03-21 10:36:18 -04:00
Dan Saunders
aaa4571826 adding pre-commit auto-update GH action and bumping plugin versions 2025-03-21 10:36:17 -04:00
132 changed files with 479 additions and 301 deletions

View File

@@ -0,0 +1,49 @@
# Bumps the hook revisions pinned in .pre-commit-config.yaml and, when any of
# them changed, opens a pull request containing the update.
name: Pre-commit auto-update

on:
  schedule:
    - cron: '0 0 * * 0'  # Run weekly
  workflow_dispatch:  # Manual kickoff

jobs:
  auto-update:
    runs-on: ubuntu-latest
    # Needed by the "Create Pull Request" step to push the branch and open the PR.
    permissions:
      contents: write
      pull-requests: write
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Update pre-commit hooks
        id: update
        run: |
          pip install pre-commit
          pre-commit autoupdate
          if [[ -n "$(git status --porcelain)" ]]; then
            echo "changes=true" >> "$GITHUB_OUTPUT"
            # Expose the diff as a multi-line step output so the PR body below
            # can interpolate it. Writing it to a file in the workspace instead
            # would leave `outputs.diff` empty and would also get the file
            # committed into the pull request.
            {
              echo "diff<<PRE_COMMIT_DIFF_EOF"
              git diff .pre-commit-config.yaml
              echo "PRE_COMMIT_DIFF_EOF"
            } >> "$GITHUB_OUTPUT"
          fi

      - name: Create Pull Request
        # Skip entirely when autoupdate left the working tree untouched.
        if: steps.update.outputs.changes == 'true'
        uses: peter-evans/create-pull-request@v6
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          branch: update/pre-commit-hooks
          delete-branch: true
          title: "chore: update pre-commit hooks"
          commit-message: "chore: update pre-commit hooks"
          body: |
            Automated PR to update pre-commit hooks to their latest versions.

            <details>
            <summary>Changes:</summary>

            ```diff
            ${{ steps.update.outputs.diff }}
            ```

            </details>

View File

@@ -3,7 +3,7 @@ default_language_version:
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0 rev: v5.0.0
hooks: hooks:
- id: check-yaml - id: check-yaml
- id: end-of-file-fixer - id: end-of-file-fixer
@@ -11,23 +11,23 @@ repos:
- id: no-commit-to-branch - id: no-commit-to-branch
args: ['--branch', 'main'] args: ['--branch', 'main']
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 23.3.0 rev: 25.1.0
hooks: hooks:
- id: black - id: black
- repo: https://github.com/pycqa/isort - repo: https://github.com/pycqa/isort
rev: 5.12.0 rev: 6.0.1
hooks: hooks:
- id: isort - id: isort
- repo: https://github.com/PyCQA/flake8 - repo: https://github.com/PyCQA/flake8
rev: 6.1.0 rev: 7.1.2
hooks: hooks:
- id: flake8 - id: flake8
- repo: https://github.com/pylint-dev/pylint - repo: https://github.com/pylint-dev/pylint
rev: c8c96d20cde3552a79858c7456bb1483bf83d633 rev: v3.3.6
hooks: hooks:
- id: pylint - id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0 rev: v1.15.0
hooks: hooks:
- id: mypy - id: mypy
additional_dependencies: additional_dependencies:
@@ -36,7 +36,7 @@ repos:
'pydantic>=2.5.3', 'pydantic>=2.5.3',
] ]
- repo: https://github.com/PyCQA/bandit - repo: https://github.com/PyCQA/bandit
rev: 1.7.5 rev: 1.8.3
hooks: hooks:
- id: bandit - id: bandit
args: [ args: [

View File

@@ -1,6 +1,7 @@
""" """
modal application to run axolotl gpu tests in Modal modal application to run axolotl gpu tests in Modal
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import os import os

View File

@@ -1,4 +1,5 @@
"""Modal app to run axolotl GPU tests""" """Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import os import os

View File

@@ -1,6 +1,7 @@
""" """
helper script to parse chat datasets into a usable yaml helper script to parse chat datasets into a usable yaml
""" """
import click import click
import yaml import yaml
from datasets import load_dataset from datasets import load_dataset

View File

@@ -1,4 +1,5 @@
"""Script to output the correct installation command for cut-cross-entropy.""" """Script to output the correct installation command for cut-cross-entropy."""
import importlib.util import importlib.util
import sys import sys
@@ -17,12 +18,12 @@ if v < V("2.4.0"):
cce_spec = importlib.util.find_spec("cut_cross_entropy") cce_spec = importlib.util.find_spec("cut_cross_entropy")
uninstall_prefix = "" UNINSTALL_PREFIX = ""
if cce_spec: if cce_spec:
if not importlib.util.find_spec("cut_cross_entropy.transformers"): if not importlib.util.find_spec("cut_cross_entropy.transformers"):
uninstall_prefix = "pip uninstall -y cut-cross-entropy && " UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
print( print(
uninstall_prefix UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"' + 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
) )

View File

@@ -1,6 +1,7 @@
""" """
launch axolotl in supported cloud platforms launch axolotl in supported cloud platforms
""" """
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union

View File

@@ -1,6 +1,7 @@
""" """
base class for cloud platforms from cli base class for cloud platforms from cli
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod

View File

@@ -1,6 +1,7 @@
""" """
Modal Cloud support from CLI Modal Cloud support from CLI
""" """
import copy import copy
import json import json
import os import os

View File

@@ -1,4 +1,5 @@
"""Click CLI definitions for various axolotl commands.""" """Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
import logging import logging

View File

@@ -5,7 +5,6 @@ import dataclasses
import hashlib import hashlib
import json import json
import logging import logging
import typing
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from types import NoneType from types import NoneType
@@ -24,7 +23,7 @@ configure_logging()
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def strip_optional_type(field_type: type | typing._SpecialForm | None): def strip_optional_type(field_type: type | str | None):
""" """
Extracts the non-`None` type from an `Optional` / `Union` type. Extracts the non-`None` type from an `Optional` / `Union` type.

View File

@@ -1,6 +1,5 @@
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes""" """Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
import json import json
import sys import sys

View File

@@ -1,6 +1,7 @@
""" """
ChatML transformation functions for MessageContents ChatML transformation functions for MessageContents
""" """
from typing import Optional from typing import Optional
from ..messages import MessageContents, Messages from ..messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
""" """
Llama 3.x chat formatting functions for MessageContents Llama 3.x chat formatting functions for MessageContents
""" """
from typing import Optional from typing import Optional
from ..messages import MessageContents, Messages from ..messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
""" """
shared functions for format transforms shared functions for format transforms
""" """
from axolotl.core.chat.messages import MessageContents, Messages from axolotl.core.chat.messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
""" """
internal message representations of chat messages internal message representations of chat messages
""" """
import json import json
from enum import Enum from enum import Enum
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Optional, Union

View File

@@ -1,6 +1,7 @@
""" """
chat dataset module chat dataset module
""" """
import os import os
from typing import Callable, Optional, Union from typing import Callable, Optional, Union

View File

@@ -1,6 +1,7 @@
""" """
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat. This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
""" """
from typing import Any, Mapping, Union from typing import Any, Mapping, Union

View File

@@ -332,9 +332,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs = {} training_arguments_kwargs = {}
if self.cfg.include_tokens_per_second is not None: if self.cfg.include_tokens_per_second is not None:
training_arguments_kwargs[ training_arguments_kwargs["include_tokens_per_second"] = (
"include_tokens_per_second" self.cfg.include_tokens_per_second
] = self.cfg.include_tokens_per_second )
if self.cfg.bf16 == "full": if self.cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True training_arguments_kwargs["bf16_full_eval"] = True
@@ -351,13 +351,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["seed"] = self.cfg.seed training_arguments_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing: if self.cfg.gradient_checkpointing:
training_arguments_kwargs[ training_arguments_kwargs["gradient_checkpointing"] = (
"gradient_checkpointing" self.cfg.gradient_checkpointing
] = self.cfg.gradient_checkpointing )
if self.cfg.gradient_checkpointing_kwargs is not None: if self.cfg.gradient_checkpointing_kwargs is not None:
training_arguments_kwargs[ training_arguments_kwargs["gradient_checkpointing_kwargs"] = (
"gradient_checkpointing_kwargs" self.cfg.gradient_checkpointing_kwargs
] = self.cfg.gradient_checkpointing_kwargs )
if self.cfg.fsdp: if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config: if self.cfg.fsdp_config:
@@ -373,9 +373,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None: if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs[ training_arguments_kwargs["lr_quadratic_warmup"] = (
"lr_quadratic_warmup" self.cfg.lr_quadratic_warmup
] = self.cfg.lr_quadratic_warmup )
if self.cfg.adam_beta1: if self.cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1 training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
@@ -399,28 +399,28 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.dataloader_pin_memory is not None: if self.cfg.dataloader_pin_memory is not None:
training_arguments_kwargs[ training_arguments_kwargs["dataloader_pin_memory"] = (
"dataloader_pin_memory" self.cfg.dataloader_pin_memory
] = self.cfg.dataloader_pin_memory )
if self.cfg.dataloader_num_workers is not None: if self.cfg.dataloader_num_workers is not None:
training_arguments_kwargs[ training_arguments_kwargs["dataloader_num_workers"] = (
"dataloader_num_workers" self.cfg.dataloader_num_workers
] = self.cfg.dataloader_num_workers )
if self.cfg.dataloader_prefetch_factor is not None: if self.cfg.dataloader_prefetch_factor is not None:
training_arguments_kwargs[ training_arguments_kwargs["dataloader_prefetch_factor"] = (
"dataloader_prefetch_factor" self.cfg.dataloader_prefetch_factor
] = self.cfg.dataloader_prefetch_factor )
if self.cfg.dataloader_drop_last is not None: if self.cfg.dataloader_drop_last is not None:
training_arguments_kwargs[ training_arguments_kwargs["dataloader_drop_last"] = (
"dataloader_drop_last" self.cfg.dataloader_drop_last
] = self.cfg.dataloader_drop_last )
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False: elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
training_arguments_kwargs["dataloader_drop_last"] = True training_arguments_kwargs["dataloader_drop_last"] = True
if self.cfg.remove_unused_columns is not None: if self.cfg.remove_unused_columns is not None:
training_arguments_kwargs[ training_arguments_kwargs["remove_unused_columns"] = (
"remove_unused_columns" self.cfg.remove_unused_columns
] = self.cfg.remove_unused_columns )
if not self.cfg.test_datasets and self.cfg.val_set_size == 0: if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval # no eval set, so don't eval
@@ -452,9 +452,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.do_causal_lm_eval: if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model: if self.cfg.metric_for_best_model:
training_arguments_kwargs[ training_arguments_kwargs["metric_for_best_model"] = (
"metric_for_best_model" self.cfg.metric_for_best_model
] = self.cfg.metric_for_best_model )
if self.cfg.greater_is_better: if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
@@ -467,13 +467,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
) )
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend: if self.cfg.torch_compile_backend:
training_arguments_kwargs[ training_arguments_kwargs["torch_compile_backend"] = (
"torch_compile_backend" self.cfg.torch_compile_backend
] = self.cfg.torch_compile_backend )
if self.cfg.torch_compile_mode: if self.cfg.torch_compile_mode:
training_arguments_kwargs[ training_arguments_kwargs["torch_compile_mode"] = (
"torch_compile_mode" self.cfg.torch_compile_mode
] = self.cfg.torch_compile_mode )
# DDP Config # DDP Config
if self.cfg.ddp_timeout: if self.cfg.ddp_timeout:
@@ -482,32 +482,32 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.ddp_bucket_cap_mb: if self.cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
if self.cfg.ddp_broadcast_buffers is not None: if self.cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs[ training_arguments_kwargs["ddp_broadcast_buffers"] = (
"ddp_broadcast_buffers" self.cfg.ddp_broadcast_buffers
] = self.cfg.ddp_broadcast_buffers )
# these are all the "standard" kwargs that are def used # these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = ( training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1 total_num_steps if self.cfg.max_steps else -1
) )
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs[ training_arguments_kwargs["per_device_train_batch_size"] = (
"per_device_train_batch_size" self.cfg.micro_batch_size
] = self.cfg.micro_batch_size )
if self.cfg.eval_batch_size: if self.cfg.eval_batch_size:
training_arguments_kwargs[ training_arguments_kwargs["per_device_eval_batch_size"] = (
"per_device_eval_batch_size" self.cfg.eval_batch_size
] = self.cfg.eval_batch_size )
if self.cfg.auto_find_batch_size is not None: if self.cfg.auto_find_batch_size is not None:
training_arguments_kwargs[ training_arguments_kwargs["auto_find_batch_size"] = (
"auto_find_batch_size" self.cfg.auto_find_batch_size
] = self.cfg.auto_find_batch_size )
training_arguments_kwargs[ training_arguments_kwargs["gradient_accumulation_steps"] = (
"gradient_accumulation_steps" self.cfg.gradient_accumulation_steps
] = self.cfg.gradient_accumulation_steps )
training_arguments_kwargs[ training_arguments_kwargs["eval_accumulation_steps"] = (
"eval_accumulation_steps" self.cfg.gradient_accumulation_steps
] = self.cfg.gradient_accumulation_steps )
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
training_arguments_kwargs["output_dir"] = self.cfg.output_dir training_arguments_kwargs["output_dir"] = self.cfg.output_dir
@@ -554,9 +554,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]: if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine" training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[ training_arguments_kwargs["alternate_lr_scheduler_type"] = (
"alternate_lr_scheduler_type" self.cfg.lr_scheduler
] = self.cfg.lr_scheduler )
else: else:
training_arguments_kwargs["lr_scheduler_type"] = ( training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine" self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
@@ -565,9 +565,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
) )
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
training_arguments_kwargs[ training_arguments_kwargs["cosine_constant_lr_ratio"] = (
"cosine_constant_lr_ratio" self.cfg.cosine_constant_lr_ratio
] = self.cfg.cosine_constant_lr_ratio )
training_arguments_kwargs["weight_decay"] = ( training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
) )
@@ -580,40 +580,40 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.eval_sample_packing self.cfg.eval_sample_packing
) )
if self.cfg.sample_packing_bin_size is not None: if self.cfg.sample_packing_bin_size is not None:
training_arguments_kwargs[ training_arguments_kwargs["sample_packing_bin_size"] = (
"sample_packing_bin_size" self.cfg.sample_packing_bin_size
] = self.cfg.sample_packing_bin_size )
if self.cfg.sample_packing_group_size is not None: if self.cfg.sample_packing_group_size is not None:
training_arguments_kwargs[ training_arguments_kwargs["sample_packing_group_size"] = (
"sample_packing_group_size" self.cfg.sample_packing_group_size
] = self.cfg.sample_packing_group_size )
if self.cfg.sample_packing_eff_est: if self.cfg.sample_packing_eff_est:
training_arguments_kwargs[ training_arguments_kwargs["sample_packing_efficiency"] = (
"sample_packing_efficiency" self.cfg.sample_packing_eff_est
] = self.cfg.sample_packing_eff_est )
if self.cfg.relora_steps: if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs[ training_arguments_kwargs["relora_warmup_steps"] = (
"relora_warmup_steps" self.cfg.relora_warmup_steps
] = self.cfg.relora_warmup_steps )
if self.cfg.relora_anneal_steps: if self.cfg.relora_anneal_steps:
training_arguments_kwargs[ training_arguments_kwargs["relora_anneal_steps"] = (
"relora_anneal_steps" self.cfg.relora_anneal_steps
] = self.cfg.relora_anneal_steps )
if self.cfg.relora_prune_ratio: if self.cfg.relora_prune_ratio:
training_arguments_kwargs[ training_arguments_kwargs["relora_prune_ratio"] = (
"relora_prune_ratio" self.cfg.relora_prune_ratio
] = self.cfg.relora_prune_ratio )
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers: if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
training_arguments_kwargs[ training_arguments_kwargs["lisa_step_interval"] = (
"lisa_step_interval" self.cfg.lisa_step_interval
] = self.cfg.lisa_step_interval )
training_arguments_kwargs[ training_arguments_kwargs["lisa_layers_attribute"] = (
"lisa_layers_attribute" self.cfg.lisa_layers_attribute
] = self.cfg.lisa_layers_attribute )
training_arguments_kwargs = self.hook_pre_create_training_args( training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs training_arguments_kwargs
@@ -627,9 +627,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
) )
if self.cfg.neftune_noise_alpha is not None: if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[ training_arguments_kwargs["neftune_noise_alpha"] = (
"neftune_noise_alpha" self.cfg.neftune_noise_alpha
] = self.cfg.neftune_noise_alpha )
trainer_kwargs = {} trainer_kwargs = {}
@@ -731,23 +731,23 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
importlib.import_module("torchdistx") importlib.import_module("torchdistx")
if self.cfg.optim_target_modules: if self.cfg.optim_target_modules:
training_arguments_kwargs[ training_arguments_kwargs["optim_target_modules"] = (
"optim_target_modules" self.cfg.optim_target_modules
] = self.cfg.optim_target_modules )
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[ training_arguments_kwargs["loraplus_lr_embedding"] = (
"loraplus_lr_embedding" self.cfg.loraplus_lr_embedding
] = self.cfg.loraplus_lr_embedding )
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.accelerator_config: if self.cfg.accelerator_config:
training_arguments_kwargs[ training_arguments_kwargs["accelerator_config"] = (
"accelerator_config" self.cfg.accelerator_config
] = self.cfg.accelerator_config )
if self.cfg.kd_ce_alpha is not None: if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
@@ -756,13 +756,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.kd_temperature is not None: if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None: if self.cfg.kd_zscore_base_temp is not None:
training_arguments_kwargs[ training_arguments_kwargs["kd_zscore_base_temp"] = (
"kd_zscore_base_temp" self.cfg.kd_zscore_base_temp
] = self.cfg.kd_zscore_base_temp )
if self.cfg.kd_top_k_before_softmax is not None: if self.cfg.kd_top_k_before_softmax is not None:
training_arguments_kwargs[ training_arguments_kwargs["kd_top_k_before_softmax"] = (
"kd_top_k_before_softmax" self.cfg.kd_top_k_before_softmax
] = self.cfg.kd_top_k_before_softmax )
if self.cfg.reward_model: if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig training_args_cls = AxolotlRewardConfig
@@ -972,32 +972,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
) )
if self.cfg.remove_unused_columns is not None: if self.cfg.remove_unused_columns is not None:
training_args_kwargs[ training_args_kwargs["remove_unused_columns"] = (
"remove_unused_columns" self.cfg.remove_unused_columns
] = self.cfg.remove_unused_columns )
else: else:
training_args_kwargs["remove_unused_columns"] = False training_args_kwargs["remove_unused_columns"] = False
if self.cfg.dataloader_pin_memory is not None: if self.cfg.dataloader_pin_memory is not None:
training_args_kwargs[ training_args_kwargs["dataloader_pin_memory"] = (
"dataloader_pin_memory" self.cfg.dataloader_pin_memory
] = self.cfg.dataloader_pin_memory )
if self.cfg.dataloader_num_workers is not None: if self.cfg.dataloader_num_workers is not None:
training_args_kwargs[ training_args_kwargs["dataloader_num_workers"] = (
"dataloader_num_workers" self.cfg.dataloader_num_workers
] = self.cfg.dataloader_num_workers )
if self.cfg.dataloader_prefetch_factor is not None: if self.cfg.dataloader_prefetch_factor is not None:
training_args_kwargs[ training_args_kwargs["dataloader_prefetch_factor"] = (
"dataloader_prefetch_factor" self.cfg.dataloader_prefetch_factor
] = self.cfg.dataloader_prefetch_factor )
if self.cfg.gradient_checkpointing: if self.cfg.gradient_checkpointing:
training_args_kwargs[ training_args_kwargs["gradient_checkpointing"] = (
"gradient_checkpointing" self.cfg.gradient_checkpointing
] = self.cfg.gradient_checkpointing )
if self.cfg.gradient_checkpointing_kwargs is not None: if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs[ training_args_kwargs["gradient_checkpointing_kwargs"] = (
"gradient_checkpointing_kwargs" self.cfg.gradient_checkpointing_kwargs
] = self.cfg.gradient_checkpointing_kwargs )
else: else:
training_args_kwargs["gradient_checkpointing_kwargs"] = { training_args_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False "use_reentrant": False
@@ -1071,9 +1071,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dpo_use_weighting is not None: if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.dpo_use_logits_to_keep is not None: if self.cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs[ training_args_kwargs["use_logits_to_keep"] = (
"use_logits_to_keep" self.cfg.dpo_use_logits_to_keep
] = self.cfg.dpo_use_logits_to_keep )
for blocklist_key in blocklist_args_kwargs: for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs: if blocklist_key in training_args_kwargs:
@@ -1108,9 +1108,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.adapter and self.peft_config: if self.cfg.adapter and self.peft_config:
dpo_trainer_kwargs["peft_config"] = self.peft_config dpo_trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None: if self.cfg.precompute_ref_log_probs is not None:
dpo_trainer_kwargs[ dpo_trainer_kwargs["precompute_ref_log_probs"] = (
"precompute_ref_log_probs" self.cfg.precompute_ref_log_probs
] = self.cfg.precompute_ref_log_probs )
if self.cfg.rl == "grpo": if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class() trainer_cls = GRPOStrategy.get_trainer_class()
trainer_cls_args = [self.model] trainer_cls_args = [self.model]

View File

@@ -462,9 +462,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"pin_memory": self.args.dataloader_pin_memory, "pin_memory": self.args.dataloader_pin_memory,
} }
if self.args.dataloader_prefetch_factor: if self.args.dataloader_prefetch_factor:
dataloader_params[ dataloader_params["prefetch_factor"] = (
"prefetch_factor" self.args.dataloader_prefetch_factor
] = self.args.dataloader_prefetch_factor )
sampler = self._get_train_sampler() sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler): if isinstance(sampler, BatchSampler):
@@ -509,9 +509,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"pin_memory": self.args.dataloader_pin_memory, "pin_memory": self.args.dataloader_pin_memory,
} }
if self.args.dataloader_prefetch_factor: if self.args.dataloader_prefetch_factor:
dataloader_params[ dataloader_params["prefetch_factor"] = (
"prefetch_factor" self.args.dataloader_prefetch_factor
] = self.args.dataloader_prefetch_factor )
if isinstance(eval_sampler, BatchSampler): if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler dataloader_params["batch_sampler"] = eval_sampler

View File

@@ -1,6 +1,7 @@
""" """
DPO Specific Strategy for training DPO Specific Strategy for training
""" """
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer

View File

@@ -1,6 +1,7 @@
""" """
Axolotl specific DPO args Axolotl specific DPO args
""" """
from dataclasses import dataclass from dataclasses import dataclass
from trl import DPOConfig from trl import DPOConfig

View File

@@ -1,6 +1,7 @@
""" """
DPO trainer for axolotl DPO trainer for axolotl
""" """
import gc import gc
from functools import wraps from functools import wraps
from typing import Any, Dict, Union from typing import Any, Dict, Union

View File

@@ -45,9 +45,9 @@ class GRPOStrategy:
) )
if trl.vllm_gpu_memory_utilization: if trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[ grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
"vllm_gpu_memory_utilization" trl.vllm_gpu_memory_utilization
] = trl.vllm_gpu_memory_utilization )
if trl.vllm_max_model_len: if trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
@@ -86,9 +86,9 @@ class GRPOStrategy:
def set_trainer_kwargs(cls, cfg): def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {} trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_processing_classes: if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs[ trainer_kwargs["reward_processing_classes"] = (
"reward_processing_classes" cfg.trl.reward_processing_classes
] = cfg.trl.reward_processing_classes )
return trainer_kwargs return trainer_kwargs
@classmethod @classmethod

View File

@@ -1,6 +1,7 @@
""" """
Axolotl Specific Training Args Axolotl Specific Training Args
""" """
from dataclasses import dataclass from dataclasses import dataclass
from trl import GRPOConfig from trl import GRPOConfig

View File

@@ -1,6 +1,7 @@
""" """
Axolotl GRPO trainer Axolotl GRPO trainer
""" """
from accelerate.utils import is_peft_model from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel from transformers import PreTrainedModel

View File

@@ -1,6 +1,7 @@
""" """
module for TRL PPO training module for TRL PPO training
""" """
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from trl import PPOTrainer from trl import PPOTrainer

View File

@@ -1,6 +1,7 @@
""" """
extra axolotl specific training args extra axolotl specific training args
""" """
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional

View File

@@ -1,6 +1,7 @@
""" """
Grokfast plugin for Axolotl Grokfast plugin for Axolotl
""" """
import logging import logging
from transformers.trainer_callback import TrainerCallback from transformers.trainer_callback import TrainerCallback

View File

@@ -1,6 +1,7 @@
""" """
config args for grokfast plugin config args for grokfast plugin
""" """
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -26,12 +26,12 @@ class KDArgs(BaseModel):
""" """
kd_trainer: Optional[bool] = None # whether to use KD trainer kd_trainer: Optional[bool] = None # whether to use KD trainer
kd_ce_alpha: Optional[ kd_ce_alpha: Optional[float] = (
float None # loss coefficient for cross-entropy loss during KD
] = None # loss coefficient for cross-entropy loss during KD )
kd_alpha: Optional[float] = None # loss coefficient for KD loss kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[ kd_top_k_before_softmax: Optional[bool] = (
bool None # whether to sample top k before softmax during KD
] = None # whether to sample top k before softmax during KD )

View File

@@ -55,9 +55,9 @@ class LigerPlugin(BasePlugin):
if "cross_entropy" in liger_fn_sig.parameters: if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters: if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs[ kwargs["fused_linear_cross_entropy"] = (
"fused_linear_cross_entropy" cfg.liger_fused_linear_cross_entropy
] = cfg.liger_fused_linear_cross_entropy )
if "rms_norm" in liger_fn_sig.parameters: if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters: if "layer_norm" in liger_fn_sig.parameters:

View File

@@ -1,6 +1,7 @@
""" """
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
""" """
Jamba model with LigerFusedLinearCrossEntropyLoss Jamba model with LigerFusedLinearCrossEntropyLoss
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
""" """
Module for the Plugin for LM Eval Harness Module for the Plugin for LM Eval Harness
""" """
import subprocess # nosec import subprocess # nosec
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin

View File

@@ -1,6 +1,7 @@
""" """
Module for handling lm eval harness input arguments. Module for handling lm eval harness input arguments.
""" """
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -1,6 +1,7 @@
""" """
axolotl CLI for running lm_eval tasks axolotl CLI for running lm_eval tasks
""" """
import subprocess # nosec import subprocess # nosec
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime

View File

@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
""" """
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code # pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
import torch import torch

View File

@@ -6,6 +6,7 @@ See "LoRA: Low-Rank Adaptation of Large Language Models"
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
""" """
# pylint: disable=invalid-name # pylint: disable=invalid-name
from typing import Callable from typing import Callable

View File

@@ -1,4 +1,5 @@
"""Dequantization utilities for `bitsandbytes` integration.""" """Dequantization utilities for `bitsandbytes` integration."""
# pylint: disable=invalid-name,global-statement # pylint: disable=invalid-name,global-statement
import ctypes import ctypes

View File

@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
""" """
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl

View File

@@ -1,6 +1,7 @@
""" """
HF Transformers MambaConfig HF Transformers MambaConfig
""" """
from transformers import PretrainedConfig from transformers import PretrainedConfig

View File

@@ -1,6 +1,7 @@
""" """
Monkeypatch for Vision Llama for FA2 support Monkeypatch for Vision Llama for FA2 support
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -220,10 +221,10 @@ def patch_mllama():
True True
) )
MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2 MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[ MLLAMA_TEXT_CROSS_ATTENTION_CLASSES["flash_attention_2"] = (
"flash_attention_2" MllamaTextCrossFlashAttention2
] = MllamaTextCrossFlashAttention2 )
# fallback to SDPA # fallback to SDPA
MLLAMA_VISION_ATTENTION_CLASSES[ MLLAMA_VISION_ATTENTION_CLASSES["flash_attention_2"] = (
"flash_attention_2" MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"] )

View File

@@ -1,4 +1,5 @@
"""monkey patches for the dataset fetcher to handle batches of packed indexes""" """monkey patches for the dataset fetcher to handle batches of packed indexes"""
# pylint: disable=protected-access # pylint: disable=protected-access
import torch import torch

View File

@@ -12,7 +12,9 @@ import transformers
from einops import rearrange from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention from transformers.models.llama.modeling_llama import (
LlamaAttention,
)
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer, LlamaDecoderLayer as OriginalLlamaDecoderLayer,
) )
@@ -490,9 +492,11 @@ def flashattn_forward(
# We have disabled _prepare_decoder_attention_mask in LlamaModel # We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask # the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask, key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :] query_padding_mask=(
if attention_mask is not None attention_mask[:, -query_states.size(1) :]
else None, if attention_mask is not None
else None
),
) )
output_unpad = flash_attn_varlen_qkvpacked_func( output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad, qkv_unpad,
@@ -531,9 +535,11 @@ def flashattn_forward(
value_states, value_states,
kvpacked=True, kvpacked=True,
key_padding_mask=attention_mask, key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :] query_padding_mask=(
if attention_mask is not None attention_mask[:, -query_states.size(1) :]
else None, if attention_mask is not None
else None
),
) )
if q_unpad.dtype != kv_unpad.dtype: if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype) kv_unpad = kv_unpad.to(q_unpad.dtype)

View File

@@ -1,6 +1,7 @@
""" """
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
""" """
from typing import Optional from typing import Optional
import torch import torch

View File

@@ -1,4 +1,5 @@
"""Flash attention monkey patch for mistral model""" """Flash attention monkey patch for mistral model"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import logging import logging
@@ -21,7 +22,10 @@ from transformers.models.mistral.modeling_mistral import (
from transformers.models.mistral.modeling_mistral import ( from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer, MistralDecoderLayer as OriginalMistralDecoderLayer,
) )
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv from transformers.models.mistral.modeling_mistral import (
apply_rotary_pos_emb,
repeat_kv,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
@@ -243,9 +247,11 @@ def flashattn_forward(
# We have disabled _prepare_decoder_attention_mask in LlamaModel # We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask # the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask, key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :] query_padding_mask=(
if attention_mask is not None attention_mask[:, -query_states.size(1) :]
else None, if attention_mask is not None
else None
),
) )
output_unpad = flash_attn_varlen_qkvpacked_func( output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad, qkv_unpad,
@@ -286,9 +292,11 @@ def flashattn_forward(
value_states, value_states,
kvpacked=True, kvpacked=True,
key_padding_mask=attention_mask, key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :] query_padding_mask=(
if attention_mask is not None attention_mask[:, -query_states.size(1) :]
else None, if attention_mask is not None
else None
),
) )
if q_unpad.dtype != kv_unpad.dtype: if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype) kv_unpad = kv_unpad.to(q_unpad.dtype)

View File

@@ -1,6 +1,7 @@
""" """
Patches to support multipack for mixtral Patches to support multipack for mixtral
""" """
import torch import torch

View File

@@ -1,4 +1,5 @@
"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune.""" """Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
import glob import glob
import json import json
import logging import logging
@@ -411,7 +412,10 @@ def merge_and_save(
if shard_path.endswith(".safetensors"): if shard_path.endswith(".safetensors"):
in_tensors = st.load_file(str(Path(model_src) / shard_path)) in_tensors = st.load_file(str(Path(model_src) / shard_path))
else: else:
in_tensors = torch.load(Path(model_src) / shard_path) in_tensors = torch.load(
Path(model_src) / shard_path,
weights_only=True, # to prevent arbitrary code execution
)
if "state_dict" in in_tensors: if "state_dict" in in_tensors:
in_tensors = in_tensors["state_dict"] in_tensors = in_tensors["state_dict"]

View File

@@ -17,7 +17,7 @@
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
""" PyTorch StableLM Epoch model. """ """PyTorch StableLM Epoch model."""
import importlib import importlib
import math import math
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
""" """
fix for FSDP optimizer save in trainer w 4.47.0 fix for FSDP optimizer save in trainer w 4.47.0
""" """
import inspect import inspect
import logging import logging

View File

@@ -1,6 +1,7 @@
""" """
Shared utils for the monkeypatches Shared utils for the monkeypatches
""" """
import re import re
from typing import Optional, Tuple from typing import Optional, Tuple

View File

@@ -1,6 +1,7 @@
""" """
Fused MLP layer for incrementally improved training efficiency Fused MLP layer for incrementally improved training efficiency
""" """
import torch import torch
from transformers.models.llama.modeling_llama import LlamaMLP from transformers.models.llama.modeling_llama import LlamaMLP
from xformers.ops import SwiGLU from xformers.ops import SwiGLU

View File

@@ -1,6 +1,7 @@
""" """
Prompt strategies loader for alpaca instruction datasets with system prompts Prompt strategies loader for alpaca instruction datasets with system prompts
""" """
from typing import Generator, Tuple, Union from typing import Generator, Tuple, Union
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy

View File

@@ -1,6 +1,7 @@
""" """
Basic completion text Basic completion text
""" """
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, Generator, Optional, Tuple from typing import Any, Dict, Generator, Optional, Tuple

View File

@@ -1,4 +1,5 @@
"""Module containing the classes for Context QA Prompt Tokenization Strategies""" """Module containing the classes for Context QA Prompt Tokenization Strategies"""
from typing import Tuple from typing import Tuple
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy

View File

@@ -1,6 +1,7 @@
""" """
module for DPO style dataset transform strategies module for DPO style dataset transform strategies
""" """
from functools import partial from functools import partial
from ..base import load as load_base from ..base import load as load_base

View File

@@ -33,9 +33,9 @@ def default(
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>" sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>" sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
return sample return sample
@@ -52,9 +52,9 @@ def argilla_chat(
""" """
def transform_fn(sample): def transform_fn(sample):
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample return sample
@@ -78,9 +78,9 @@ def icr(
f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen']}<|im_end|>" sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>" sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample return sample
@@ -100,9 +100,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen']}<|im_end|>" sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>" sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample return sample
@@ -120,9 +120,9 @@ def prompt_pairs(
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen']}<|im_end|>" sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>" sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample return sample
@@ -142,9 +142,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample return sample

View File

@@ -34,9 +34,9 @@ def default(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>" sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>" sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
return sample return sample
@@ -53,9 +53,9 @@ def argilla_chat(
""" """
def transform_fn(sample): def transform_fn(sample):
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>" sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample return sample
@@ -79,9 +79,9 @@ def icr(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen']}<|eot_id|>" sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>" sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample return sample
@@ -101,9 +101,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen']}<|eot_id|>" sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>" sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample return sample
@@ -121,9 +121,9 @@ def prompt_pairs(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen']}<|eot_id|>" sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>" sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample return sample
@@ -143,9 +143,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>" sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample return sample

View File

@@ -1,4 +1,5 @@
"""Module for plain input/output prompt pairs""" """Module for plain input/output prompt pairs"""
from typing import Generator, Tuple from typing import Generator, Tuple
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy

View File

@@ -1,4 +1,5 @@
"""Module for inspect jinja templates for the variables they use""" """Module for inspect jinja templates for the variables they use"""
from typing import Dict, Optional, Set, TypedDict, Union from typing import Dict, Optional, Set, TypedDict, Union
from jinja2 import Environment, meta, nodes from jinja2 import Environment, meta, nodes

View File

@@ -1,6 +1,7 @@
""" """
KTO strategies for chatml KTO strategies for chatml
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -15,9 +16,9 @@ def argilla(
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" )
sample["completion"] = f"{sample['completion']}<|im_end|>" sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample return sample
@@ -33,9 +34,9 @@ def argilla_chat(
""" """
def transform_fn(sample): def transform_fn(sample):
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n" )
sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>" sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
return sample return sample
@@ -55,9 +56,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" )
sample["completion"] = f"{sample['completion']}<|im_end|>" sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample return sample
@@ -74,9 +75,9 @@ def prompt_pairs(
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" )
sample["completion"] = f"{sample['completion']}<|im_end|>" sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample return sample
@@ -96,9 +97,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" )
sample["completion"] = f"{sample['completion']}<|im_end|>" sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample return sample

View File

@@ -1,6 +1,7 @@
""" """
KTO strategies for llama-3 chat template KTO strategies for llama-3 chat template
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -15,9 +16,9 @@ def argilla(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["completion"] = f"{sample['completion']}<|eot_id|>" sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample return sample
@@ -33,9 +34,9 @@ def argilla_chat(
""" """
def transform_fn(sample): def transform_fn(sample):
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>" sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>"
return sample return sample
@@ -55,9 +56,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["completion"] = f"{sample['completion']}<|eot_id|>" sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample return sample
@@ -74,9 +75,9 @@ def prompt_pairs(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["completion"] = f"{sample['completion']}<|eot_id|>" sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample return sample
@@ -96,9 +97,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["completion"] = f"{sample['completion']}<|eot_id|>" sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample return sample

View File

@@ -1,6 +1,7 @@
""" """
User-defined KTO strategies User-defined KTO strategies
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code

View File

@@ -1,6 +1,7 @@
""" """
Chat dataset wrapping strategy for new internal messages representations Chat dataset wrapping strategy for new internal messages representations
""" """
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Dict, Optional
from axolotl.core.datasets.chat import TokenizedChatDataset from axolotl.core.datasets.chat import TokenizedChatDataset

View File

@@ -9,6 +9,7 @@ this one specifies the system prompt with "### System:".
Not suited/tested for multiple-turn conversations without further adjustments. Not suited/tested for multiple-turn conversations without further adjustments.
""" """
from typing import Generator, Union from typing import Generator, Union
from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy

View File

@@ -1,4 +1,5 @@
"""chatml prompt tokenization strategy for ORPO""" """chatml prompt tokenization strategy for ORPO"""
from typing import Any, Dict, Generator, List, Optional, Tuple from typing import Any, Dict, Generator, List, Optional, Tuple
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -1,4 +1,5 @@
"""pretraining prompt strategies""" """pretraining prompt strategies"""
from typing import Generator from typing import Generator
from transformers import BatchEncoding from transformers import BatchEncoding

View File

@@ -406,9 +406,7 @@ def handle_untrained_tokens_fix(
) )
def setup_model_and_trainer( def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[
HFRLTrainerBuilder | HFCausalTrainerBuilder, HFRLTrainerBuilder | HFCausalTrainerBuilder,
PeftModel | PreTrainedModel, PeftModel | PreTrainedModel,
PreTrainedTokenizer, PreTrainedTokenizer,

View File

@@ -40,6 +40,6 @@ def set_pytorch_cuda_alloc_conf():
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2: if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ[ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
"PYTORCH_CUDA_ALLOC_CONF" "expandable_segments:True,roundup_power2_divisions:16"
] = "expandable_segments:True,roundup_power2_divisions:16" )

View File

@@ -1,4 +1,5 @@
"""Benchmarking and measurement utilities""" """Benchmarking and measurement utilities"""
import functools import functools
import torch import torch

View File

@@ -343,9 +343,9 @@ def bench_eval_callback_factory(trainer, tokenizer):
bench_refs.extend(combined_bench_names[bench_name]["refs"]) bench_refs.extend(combined_bench_names[bench_name]["refs"])
bench_preds.extend(combined_bench_names[bench_name]["preds"]) bench_preds.extend(combined_bench_names[bench_name]["preds"])
if not pd.isna(bench_score): if not pd.isna(bench_score):
results[ results[f"{bench_split}_bench_accuracy_{bench_name}"] = (
f"{bench_split}_bench_accuracy_{bench_name}" bench_score
] = bench_score )
bench_scores.append(bench_score) bench_scores.append(bench_score)
else: else:
results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0 results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0

View File

@@ -1,4 +1,5 @@
"""MLFlow module for trainer callbacks""" """MLFlow module for trainer callbacks"""
import logging import logging
from shutil import copyfile from shutil import copyfile
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile

View File

@@ -1,4 +1,5 @@
"""callback to calculate perplexity as an evaluation metric.""" """callback to calculate perplexity as an evaluation metric."""
from typing import Dict, List, Optional from typing import Dict, List, Optional
import torch import torch

View File

@@ -1,6 +1,7 @@
""" """
HF Trainer callback for creating pytorch profiling snapshots HF Trainer callback for creating pytorch profiling snapshots
""" """
from pathlib import Path from pathlib import Path
from pickle import dump # nosec B403 from pickle import dump # nosec B403

View File

@@ -2,6 +2,7 @@
This module provides functionality for selecting chat templates based on user choices. This module provides functionality for selecting chat templates based on user choices.
These templates are used for formatting messages in a conversation. These templates are used for formatting messages in a conversation.
""" """
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional

View File

@@ -1,6 +1,7 @@
""" """
shared axolotl collators for multipack, mamba, multimodal shared axolotl collators for multipack, mamba, multimodal
""" """
from .batching import ( # noqa: F401 from .batching import ( # noqa: F401
BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq, DataCollatorForSeq2Seq,

View File

@@ -1,4 +1,5 @@
""" """
basic shared collator constants basic shared collator constants
""" """
IGNORE_INDEX = -100 IGNORE_INDEX = -100

View File

@@ -1,6 +1,7 @@
""" """
collators for Mamba collators for Mamba
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Sequence from typing import Dict, Sequence

View File

@@ -18,7 +18,11 @@ from axolotl.utils.config.models.input.v0_4_1 import (
from axolotl.utils.config.models.input.v0_4_1 import ( from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlInputConfig as AxolotlInputConfigBase, AxolotlInputConfig as AxolotlInputConfigBase,
) )
from axolotl.utils.config.models.input.v0_4_1 import DPODataset, KTODataset, SFTDataset from axolotl.utils.config.models.input.v0_4_1 import (
DPODataset,
KTODataset,
SFTDataset,
)
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config from axolotl.utils.models import load_model_config

View File

@@ -200,12 +200,12 @@ class SFTDataset(BaseModel):
field_human: Optional[str] = None field_human: Optional[str] = None
field_model: Optional[str] = None field_model: Optional[str] = None
field_messages: Optional[str] = None field_messages: Optional[str] = None
message_field_role: Optional[ message_field_role: Optional[str] = (
str None # deprecated, use message_property_mappings
] = None # deprecated, use message_property_mappings )
message_field_content: Optional[ message_field_content: Optional[str] = (
str None # deprecated, use message_property_mappings
] = None # deprecated, use message_property_mappings )
message_property_mappings: Optional[Dict[str, str]] = None message_property_mappings: Optional[Dict[str, str]] = None
message_field_training: Optional[str] = None message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None message_field_training_detail: Optional[str] = None
@@ -505,9 +505,9 @@ class HyperparametersConfig(BaseModel):
embedding_lr: Optional[float] = None embedding_lr: Optional[float] = None
embedding_lr_scale: Optional[float] = None embedding_lr_scale: Optional[float] = None
weight_decay: Optional[float] = 0.0 weight_decay: Optional[float] = 0.0
optimizer: Optional[ optimizer: Optional[Union[OptimizerNames, CustomSupportedOptimizers]] = (
Union[OptimizerNames, CustomSupportedOptimizers] OptimizerNames.ADAMW_TORCH_FUSED
] = OptimizerNames.ADAMW_TORCH_FUSED )
optim_args: Optional[Union[str, Dict[str, Any]]] = Field( optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, default=None,
json_schema_extra={"description": "Optional arguments to supply to optimizer."}, json_schema_extra={"description": "Optional arguments to supply to optimizer."},
@@ -699,9 +699,9 @@ class AxolotlInputConfig(
reward_model: Optional[bool] = None reward_model: Optional[bool] = None
process_reward_model: Optional[bool] = None process_reward_model: Optional[bool] = None
num_labels: Optional[int] = None num_labels: Optional[int] = None
dpo_use_weighting: Optional[ dpo_use_weighting: Optional[bool] = (
bool None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer. )
dpo_use_logits_to_keep: Optional[bool] = None dpo_use_logits_to_keep: Optional[bool] = None
datasets: Optional[ datasets: Optional[
@@ -780,9 +780,9 @@ class AxolotlInputConfig(
# torch_dtype: Optional[torch.dtype] # torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[ gradient_checkpointing: Optional[Union[Literal["unsloth", "offload"], bool]] = (
Union[Literal["unsloth", "offload"], bool] Field(default=False)
] = Field(default=False) )
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None unfrozen_parameters: Optional[List[str]] = None
@@ -894,9 +894,9 @@ class AxolotlInputConfig(
kto_undesirable_weight: Optional[float] = None kto_undesirable_weight: Optional[float] = None
rl_beta: Optional[float] = None rl_beta: Optional[float] = None
max_memory: Optional[ max_memory: Optional[Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]] = (
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]] None
] = None )
gpu_memory_limit: Optional[Union[int, str]] = None gpu_memory_limit: Optional[Union[int, str]] = None
low_cpu_mem_usage: Optional[bool] = None low_cpu_mem_usage: Optional[bool] = None

View File

@@ -1,4 +1,5 @@
"""module for gpu capabilities""" """module for gpu capabilities"""
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field

View File

@@ -1,6 +1,7 @@
""" """
Data processing modules Data processing modules
""" """
from axolotl.utils.data.pretraining import ( # noqa: F401 from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining, encode_pretraining,
wrap_pretraining_dataset, wrap_pretraining_dataset,

View File

@@ -1,6 +1,7 @@
""" """
utility helpers for distributed checks utility helpers for distributed checks
""" """
import os import os
import pickle # nosec import pickle # nosec
from contextlib import contextmanager from contextlib import contextmanager

View File

@@ -1,10 +1,13 @@
""" """
utils to get GPU info for the current environment utils to get GPU info for the current environment
""" """
from accelerate.utils.environment import ( from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support, check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
) )
from accelerate.utils.environment import get_gpu_info from accelerate.utils.environment import (
get_gpu_info,
)
def check_cuda_p2p_ib_support(): def check_cuda_p2p_ib_support():

View File

@@ -1,6 +1,7 @@
""" """
module to freeze/unfreeze parameters by name module to freeze/unfreeze parameters by name
""" """
import logging import logging
import re import re
from typing import Callable, List, Tuple, Union from typing import Callable, List, Tuple, Union

View File

@@ -1,4 +1,5 @@
"""custom checkpointing utils""" """custom checkpointing utils"""
from axolotl.utils.gradient_checkpointing.unsloth import ( from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer, Unsloth_Offloaded_Gradient_Checkpointer,
) )

View File

@@ -1,6 +1,7 @@
""" """
module to handle loading model on cpu/meta device for FSDP module to handle loading model on cpu/meta device for FSDP
""" """
import os import os
import time import time
from typing import List, Optional, Type, Union from typing import List, Optional, Type, Union
@@ -45,13 +46,13 @@ def _replace_linear(
if isinstance(module, torch.nn.Linear) and name not in skip_modules: if isinstance(module, torch.nn.Linear) and name not in skip_modules:
if issubclass(linear_replacement, Linear4bit): if issubclass(linear_replacement, Linear4bit):
model._modules[ # pylint: disable=protected-access model._modules[name] = ( # pylint: disable=protected-access
name linear_replacement(
] = linear_replacement( module.in_features,
module.in_features, module.out_features,
module.out_features, module.bias is not None,
module.bias is not None, **kwargs,
**kwargs, )
) )
else: else:
raise ValueError( raise ValueError(

View File

@@ -741,9 +741,9 @@ class ModelLoader:
) )
else: else:
if self.cfg.gptq_disable_exllama is not None: if self.cfg.gptq_disable_exllama is not None:
self.model_config.quantization_config[ self.model_config.quantization_config["disable_exllama"] = (
"disable_exllama" self.cfg.gptq_disable_exllama
] = self.cfg.gptq_disable_exllama )
self.model_kwargs["quantization_config"] = GPTQConfig( self.model_kwargs["quantization_config"] = GPTQConfig(
**self.model_config.quantization_config **self.model_config.quantization_config
) )

View File

@@ -4,6 +4,7 @@ Copied from https://github.com/iShohei220/adopt
ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024) ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024)
Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka
""" """
# mypy: ignore-errors # mypy: ignore-errors
# pylint: skip-file # pylint: skip-file
# flake8: noqa # flake8: noqa

View File

@@ -1,5 +1,6 @@
""" """
axolotl samplers module axolotl samplers module
""" """
from .multipack import MultipackBatchSampler # noqa: F401 from .multipack import MultipackBatchSampler # noqa: F401
from .utils import get_dataset_lengths # noqa: F401 from .utils import get_dataset_lengths # noqa: F401

View File

@@ -1,6 +1,7 @@
""" """
helper util to calculate dataset lengths helper util to calculate dataset lengths
""" """
import numpy as np import numpy as np

View File

@@ -1,4 +1,5 @@
"""Module for custom LRScheduler class""" """Module for custom LRScheduler class"""
import math import math
from functools import partial from functools import partial

View File

@@ -538,9 +538,9 @@ def setup_fsdp_envs(cfg):
if cfg.fsdp_config.fsdp_auto_wrap_policy: if cfg.fsdp_config.fsdp_auto_wrap_policy:
os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.fsdp_auto_wrap_policy os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.fsdp_auto_wrap_policy
if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap: if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:
os.environ[ os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = (
"FSDP_TRANSFORMER_CLS_TO_WRAP" cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap )
def prepare_optim_env(cfg): def prepare_optim_env(cfg):

View File

@@ -1,6 +1,7 @@
""" """
dynamic requirements for axolotl dynamic requirements for axolotl
""" """
import platform import platform
import re import re
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command.""" """pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from unittest.mock import patch from unittest.mock import patch

View File

@@ -1,6 +1,7 @@
""" """
unit tests for generating sweep configurations unit tests for generating sweep configurations
""" """
from axolotl.cli.main import generate_sweep_configs from axolotl.cli.main import generate_sweep_configs

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI utils.""" """pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
import json import json

View File

@@ -1,6 +1,7 @@
""" """
shared pytest fixtures shared pytest fixtures
""" """
import functools import functools
import importlib import importlib
import shutil import shutil

View File

@@ -1,6 +1,7 @@
""" """
Tests for the chat messages module Tests for the chat messages module
""" """
import unittest import unittest
import pytest import pytest

Some files were not shown because too many files have changed in this diff Show More