Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
0b61cd445a v0.10.1 release
Some checks failed
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 126, 12.6.3, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl (vllm, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 126, 12.6.3, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 124, 12.4.1, 3.11, 2.6.0) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
2025-06-19 11:29:39 -04:00
14 changed files with 27 additions and 33 deletions

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package __path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.11.0.dev" __version__ = "0.10.1"

View File

@@ -7,6 +7,7 @@ from typing import Union
import yaml import yaml
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.cloud.modal_ import ModalCloud from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -23,6 +24,7 @@ def do_cli_preprocess(
cloud_config: Union[Path, str], cloud_config: Union[Path, str],
config: Union[Path, str], config: Union[Path, str],
) -> None: ) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config) cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg) cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file: with open(config, "r", encoding="utf-8") as file:
@@ -37,6 +39,7 @@ def do_cli_train(
cwd=None, cwd=None,
**kwargs, **kwargs,
) -> None: ) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config) cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg) cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file: with open(config, "r", encoding="utf-8") as file:
@@ -51,6 +54,7 @@ def do_cli_lm_eval(
cloud_config: Union[Path, str], cloud_config: Union[Path, str],
config: Union[Path, str], config: Union[Path, str],
) -> None: ) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config) cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg) cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file: with open(config, "r", encoding="utf-8") as file:

View File

@@ -28,8 +28,6 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__) LOG = get_logger(__name__)
API_KEY_FIELDS = {"comet_api_key"}
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
""" """
@@ -235,15 +233,4 @@ def load_cfg(
setup_comet_env_vars(cfg) setup_comet_env_vars(cfg)
plugin_set_cfg(cfg) plugin_set_cfg(cfg)
cfg_to_log = {
k: "[REDACTED]" if k in API_KEY_FIELDS else v
for k, v in cfg.items()
if v is not None
}
LOG.info(
"config:\n%s",
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
)
return cfg return cfg

View File

@@ -9,6 +9,7 @@ from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.common.datasets import load_datasets, load_preference_datasets
@@ -34,6 +35,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
patch_optimized_env() patch_optimized_env()
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art()
check_accelerate_default_config() check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0: if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token() check_user_token()

View File

@@ -13,6 +13,7 @@ from dotenv import load_dotenv
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli.args import InferenceCliArgs from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.chat_templates import ( from axolotl.utils.chat_templates import (
@@ -254,6 +255,7 @@ def do_cli(
kwargs: Additional keyword arguments to override config file values. kwargs: Additional keyword arguments to override config file values.
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs) parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
parsed_cfg.sample_packing = False parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser(InferenceCliArgs) parser = transformers.HfArgumentParser(InferenceCliArgs)

View File

@@ -20,7 +20,6 @@ from axolotl.cli.args import (
TrainerCliArgs, TrainerCliArgs,
VllmServeCliArgs, VllmServeCliArgs,
) )
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.sweeps import generate_sweep_configs from axolotl.cli.sweeps import generate_sweep_configs
from axolotl.cli.utils import ( from axolotl.cli.utils import (
add_options_from_config, add_options_from_config,
@@ -41,7 +40,6 @@ LOG = get_logger(__name__)
@click.version_option(version=axolotl.__version__, prog_name="axolotl") @click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli(): def cli():
"""Axolotl CLI - Train and fine-tune large language models""" """Axolotl CLI - Train and fine-tune large language models"""
print_axolotl_text_art()
@cli.command() @cli.command()

View File

@@ -6,6 +6,7 @@ from typing import Union
import fire import fire
from dotenv import load_dotenv from dotenv import load_dotenv
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -22,6 +23,8 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
""" """
print_axolotl_text_art()
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg) model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True

View File

@@ -22,6 +22,7 @@ from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
@@ -193,6 +194,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
kwargs: Additional keyword arguments to override config file values. kwargs: Additional keyword arguments to override config file values.
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0" fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"

View File

@@ -12,6 +12,7 @@ from dotenv import load_dotenv
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
from axolotl.cli.args import PreprocessCliArgs from axolotl.cli.args import PreprocessCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -32,6 +33,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Preprocessing-specific CLI arguments. cli_args: Preprocessing-specific CLI arguments.
""" """
print_axolotl_text_art()
check_accelerate_default_config() check_accelerate_default_config()
check_user_token() check_user_token()

View File

@@ -7,6 +7,7 @@ from typing import Union
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
@@ -26,6 +27,7 @@ def do_quantize(
config (Union[Path, str]): The path to the config file config (Union[Path, str]): The path to the config file
cli_args (dict): Additional command-line arguments cli_args (dict): Additional command-line arguments
""" """
print_axolotl_text_art()
cfg = load_cfg(config) cfg = load_cfg(config)

View File

@@ -11,6 +11,7 @@ from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.common.datasets import load_datasets, load_preference_datasets
@@ -34,6 +35,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
# Enable expandable segments for cuda allocation to improve VRAM usage # Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env() patch_optimized_env()
print_axolotl_text_art()
check_accelerate_default_config() check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0: if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token() check_user_token()

View File

@@ -23,6 +23,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer from transformers.trainer import Trainer
from axolotl.cli.art import print_axolotl_text_art
from axolotl.common.datasets import TrainDatasetMeta from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens, fix_untrained_tokens,
@@ -544,6 +545,8 @@ def train(
Returns: Returns:
Tuple of (model, tokenizer) after training Tuple of (model, tokenizer) after training
""" """
print_axolotl_text_art()
# Setup model, tokenizer, (causal or RLHF) trainer, etc. # Setup model, tokenizer, (causal or RLHF) trainer, etc.
( (
trainer, trainer,

View File

@@ -46,23 +46,16 @@ def get_current_device() -> int:
return 0 return 0
def init_distributed_state():
global distributed_state # pylint: disable=global-statement
if distributed_state is None:
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
def get_distributed_state() -> PartialState | None: def get_distributed_state() -> PartialState | None:
return distributed_state return distributed_state
def is_distributed() -> bool: def is_distributed() -> bool:
"""Check if distributed training is initialized.""" """Check if distributed training is initialized."""
init_distributed_state() global distributed_state # pylint: disable=global-statement
if distributed_state is None: if distributed_state is None:
return False timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
return distributed_state.use_distributed and distributed_state.initialized return distributed_state.use_distributed and distributed_state.initialized

View File

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2 from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -537,12 +537,6 @@ def setup_deepspeed_env(cfg, stage=None):
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage) os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3: if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true" os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
# NOTE(djsaunde): The distributed state cannot be initialized prior to the
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
# to model load.
init_distributed_state()
# If we don't assign this, it doesn't actually get set in the accelerate weakref # If we don't assign this, it doesn't actually get set in the accelerate weakref
_ = HfTrainerDeepSpeedConfig(cfg.deepspeed) _ = HfTrainerDeepSpeedConfig(cfg.deepspeed)