Rank 0-only logging (#2608)
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -5,7 +5,7 @@ base_model: NousResearch/Llama-3.2-1B
|
|||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
@@ -38,6 +38,7 @@ wandb_log_model:
|
|||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|
||||||
optimizer: adamw_8bit
|
optimizer: adamw_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Various checks for Axolotl CLI."""
|
"""Various checks for Axolotl CLI."""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -8,7 +7,9 @@ from accelerate.commands.config import config_args
|
|||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def check_accelerate_default_config() -> None:
|
def check_accelerate_default_config() -> None:
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Configuration loading and processing."""
|
"""Configuration loading and processing."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -22,11 +21,12 @@ from axolotl.utils.config import (
|
|||||||
validate_config,
|
validate_config,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__, use_environ=True)
|
||||||
|
|
||||||
|
|
||||||
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||||
@@ -119,12 +119,12 @@ def choose_config(path: Path) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if len(yaml_files) == 1:
|
if len(yaml_files) == 1:
|
||||||
print(f"Using default YAML file '{yaml_files[0]}'")
|
LOG.info(f"Using default YAML file '{yaml_files[0]}'")
|
||||||
return str(yaml_files[0])
|
return str(yaml_files[0])
|
||||||
|
|
||||||
print("Choose a YAML file:")
|
LOG.info("Choose a YAML file:")
|
||||||
for idx, file in enumerate(yaml_files):
|
for idx, file in enumerate(yaml_files):
|
||||||
print(f"{idx + 1}. {file}")
|
LOG.info(f"{idx + 1}. {file}")
|
||||||
|
|
||||||
chosen_file = None
|
chosen_file = None
|
||||||
while chosen_file is None:
|
while chosen_file is None:
|
||||||
@@ -133,9 +133,9 @@ def choose_config(path: Path) -> str:
|
|||||||
if 1 <= choice <= len(yaml_files):
|
if 1 <= choice <= len(yaml_files):
|
||||||
chosen_file = str(yaml_files[choice - 1])
|
chosen_file = str(yaml_files[choice - 1])
|
||||||
else:
|
else:
|
||||||
print("Invalid choice. Please choose a number from the list.")
|
LOG.info("Invalid choice. Please choose a number from the list.")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print("Invalid input. Please enter a number.")
|
LOG.info("Invalid input. Please enter a number.")
|
||||||
|
|
||||||
return chosen_file
|
return chosen_file
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""CLI to run evaluation on a model."""
|
"""CLI to run evaluation on a model."""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@@ -17,8 +16,9 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets
|
|||||||
from axolotl.evaluate import evaluate
|
from axolotl.evaluate import evaluate
|
||||||
from axolotl.utils import patch_optimized_env
|
from axolotl.utils import patch_optimized_env
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""CLI to run inference on a trained model."""
|
"""CLI to run inference on a trained model."""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
@@ -22,8 +21,9 @@ from axolotl.utils.chat_templates import (
|
|||||||
get_chat_template_from_config,
|
get_chat_template_from_config,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_multi_line_input() -> str:
|
def get_multi_line_input() -> str:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import subprocess # nosec B404
|
import subprocess # nosec B404
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -31,8 +30,11 @@ from axolotl.cli.utils import (
|
|||||||
)
|
)
|
||||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||||
from axolotl.utils import patch_optimized_env
|
from axolotl.utils import patch_optimized_env
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||||
@@ -177,7 +179,7 @@ def train(
|
|||||||
|
|
||||||
do_cli(config=cfg_file, **kwargs)
|
do_cli(config=cfg_file, **kwargs)
|
||||||
except subprocess.CalledProcessError as exc:
|
except subprocess.CalledProcessError as exc:
|
||||||
logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
||||||
if not sweep:
|
if not sweep:
|
||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""CLI to merge a trained LoRA into a base model."""
|
"""CLI to merge a trained LoRA into a base model."""
|
||||||
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -13,8 +12,9 @@ from axolotl.cli.art import print_axolotl_text_art
|
|||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_merge_lora(*, cfg: DictDefault) -> None:
|
def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""
|
"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -27,8 +26,9 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
|
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""CLI to run preprocessing of a dataset."""
|
"""CLI to run preprocessing of a dataset."""
|
||||||
|
|
||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@@ -20,9 +19,10 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
|||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.trainer import disable_datasets_caching
|
from axolotl.utils.trainer import disable_datasets_caching
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
CLI to post-training quantize a model using torchao
|
CLI to post-training quantize a model using torchao
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -11,9 +10,10 @@ from transformers import AutoModelForCausalLM
|
|||||||
from axolotl.cli.art import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.loaders import load_tokenizer
|
from axolotl.loaders import load_tokenizer
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
|
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_quantize(
|
def do_quantize(
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""CLI to run training on a model."""
|
"""CLI to run training on a model."""
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@@ -22,8 +21,6 @@ from axolotl.utils import patch_optimized_env
|
|||||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import concurrent.futures
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import NoneType
|
from types import NoneType
|
||||||
@@ -23,8 +22,9 @@ from transformers import (
|
|||||||
from axolotl.loaders import load_processor, load_tokenizer
|
from axolotl.loaders import load_processor, load_tokenizer
|
||||||
from axolotl.loaders.model import ModelLoader
|
from axolotl.loaders.model import ModelLoader
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def strip_optional_type(field_type: type | str | None):
|
def strip_optional_type(field_type: type | str | None):
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Dataset loading utilities."""
|
"""Dataset loading utilities."""
|
||||||
|
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -14,10 +13,11 @@ from axolotl.loaders import load_processor, load_tokenizer
|
|||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.enums import RLType
|
from axolotl.utils.schemas.enums import RLType
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -156,7 +156,6 @@ class Messages(BaseModel):
|
|||||||
len(input_ids) : len(input_ids) + len(pending_input_ids)
|
len(input_ids) : len(input_ids) + len(pending_input_ids)
|
||||||
]
|
]
|
||||||
if new_pending_inputs != pending_input_ids:
|
if new_pending_inputs != pending_input_ids:
|
||||||
# logging.warning("tokenization mismatch from concatenation.")
|
|
||||||
pending_input_ids = new_pending_inputs
|
pending_input_ids = new_pending_inputs
|
||||||
input_ids.extend(pending_input_ids)
|
input_ids.extend(pending_input_ids)
|
||||||
if pending_weight:
|
if pending_weight:
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import abc
|
|||||||
import importlib
|
import importlib
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -88,6 +87,7 @@ from axolotl.utils.collators import (
|
|||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -95,7 +95,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
@@ -34,9 +33,10 @@ from axolotl.core.trainers.utils import (
|
|||||||
sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from trl.trainer.grpo_trainer import RewardFunc
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
@@ -13,9 +12,10 @@ from axolotl.core.trainers.grpo.trainer import (
|
|||||||
AxolotlGRPOTrainer,
|
AxolotlGRPOTrainer,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.trl import TRLConfig
|
from axolotl.utils.schemas.trl import TRLConfig
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GRPOStrategy:
|
class GRPOStrategy:
|
||||||
|
|||||||
@@ -1,18 +1,17 @@
|
|||||||
"""Module for Axolotl trainer optimizer mixin"""
|
"""Module for Axolotl trainer optimizer mixin"""
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
|
|
||||||
from axolotl.integrations.base import BaseOptimizerFactory
|
from axolotl.integrations.base import BaseOptimizerFactory
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
import smdistributed.modelparallel.torch as smp
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OptimizerMixin(Trainer):
|
class OptimizerMixin(Trainer):
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ See https://github.com/huggingface/transformers/pull/37162
|
|||||||
TODO: Remove when upstream added PR to release
|
TODO: Remove when upstream added PR to release
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
|
||||||
@@ -17,7 +16,9 @@ from transformers.trainer import safe_globals
|
|||||||
from transformers.trainer_pt_utils import set_rng_state_for_device
|
from transformers.trainer_pt_utils import set_rng_state_for_device
|
||||||
from transformers.training_args import ParallelMode
|
from transformers.training_args import ParallelMode
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RngLoaderMixin(Trainer):
|
class RngLoaderMixin(Trainer):
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
"""Module for Axolotl trainer scheduler mixin"""
|
"""Module for Axolotl trainer scheduler mixin"""
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
|
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
|
||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.utils.schedulers import (
|
||||||
RexLR,
|
RexLR,
|
||||||
get_cosine_schedule_with_min_lr,
|
get_cosine_schedule_with_min_lr,
|
||||||
@@ -14,7 +13,7 @@ from axolotl.utils.schedulers import (
|
|||||||
get_cosine_schedule_with_warmup_decay_constant,
|
get_cosine_schedule_with_warmup_decay_constant,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SchedulerMixin(Trainer):
|
class SchedulerMixin(Trainer):
|
||||||
@@ -80,13 +79,15 @@ class SchedulerMixin(Trainer):
|
|||||||
self.lr_scheduler = RexLR(
|
self.lr_scheduler = RexLR(
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
max_lr=self.args.learning_rate,
|
max_lr=self.args.learning_rate,
|
||||||
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
min_lr=0 if not use_cosine_min_lr else (
|
||||||
|
self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||||
total_steps=num_training_steps,
|
total_steps=num_training_steps,
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
)
|
)
|
||||||
elif use_cosine_quadratic:
|
elif use_cosine_quadratic:
|
||||||
if use_cosine_min_lr:
|
if use_cosine_min_lr:
|
||||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
LOG.warning(
|
||||||
|
"Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||||
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||||
optimizer,
|
optimizer,
|
||||||
@@ -115,9 +116,11 @@ class SchedulerMixin(Trainer):
|
|||||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||||
else:
|
else:
|
||||||
if use_cosine_quadratic:
|
if use_cosine_quadratic:
|
||||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
LOG.warning(
|
||||||
|
"axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
if use_cosine_min_lr:
|
if use_cosine_min_lr:
|
||||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
LOG.warning(
|
||||||
|
"axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
return self.lr_scheduler # type: ignore
|
return self.lr_scheduler # type: ignore
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
"""Module containing Dataset functionality"""
|
"""Module containing Dataset functionality"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||||
|
|
||||||
# We want this to be a wrapper for an existing dataset that we have loaded
|
# We want this to be a wrapper for an existing dataset that we have loaded
|
||||||
@@ -15,7 +16,7 @@ from .prompt_tokenizers import PromptTokenizingStrategy
|
|||||||
# let's check to ensure we don't truncate an item in the middle, we'll use
|
# let's check to ensure we don't truncate an item in the middle, we'll use
|
||||||
# the collators later on to pad the datasets
|
# the collators later on to pad the datasets
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TokenizedPromptDataset(Dataset):
|
class TokenizedPromptDataset(Dataset):
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
|
||||||
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
||||||
|
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
@@ -31,6 +30,9 @@ from torch.optim.lr_scheduler import LRScheduler
|
|||||||
from transformers import PreTrainedModel, Trainer
|
from transformers import PreTrainedModel, Trainer
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.common.datasets import TrainDatasetMeta
|
from axolotl.common.datasets import TrainDatasetMeta
|
||||||
@@ -331,12 +333,12 @@ class PluginManager:
|
|||||||
ImportError: If the plugin module cannot be imported.
|
ImportError: If the plugin module cannot be imported.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logging.info(f"Attempting to load plugin: {plugin_name}")
|
LOG.info(f"Attempting to load plugin: {plugin_name}")
|
||||||
plugin = load_plugin(plugin_name)
|
plugin = load_plugin(plugin_name)
|
||||||
self.plugins[plugin_name] = plugin
|
self.plugins[plugin_name] = plugin
|
||||||
logging.info(f"Plugin loaded successfully: {plugin_name}")
|
LOG.info(f"Plugin loaded successfully: {plugin_name}")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.error(f"Failed to load plugin: {plugin_name}")
|
LOG.error(f"Failed to load plugin: {plugin_name}")
|
||||||
|
|
||||||
def get_input_args(self) -> list[str]:
|
def get_input_args(self) -> list[str]:
|
||||||
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
|
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
|
||||||
|
|||||||
@@ -19,17 +19,16 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
|
|||||||
from Apple's ML team.
|
from Apple's ML team.
|
||||||
"""
|
"""
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.utils import get_pytorch_version
|
from axolotl.utils import get_pytorch_version
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
LOG = get_logger(__name__, use_environ=True)
|
||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install cut_cross_entropy with transformers support using "
|
"Please install cut_cross_entropy with transformers support using "
|
||||||
@@ -76,10 +75,9 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
cce_patch,
|
cce_patch,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_main_process(use_environ=True):
|
LOG.info(
|
||||||
LOG.info(
|
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
|
||||||
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# The patch checks model_type internally
|
# The patch checks model_type internally
|
||||||
cce_patch(cfg.model_config_type)
|
cce_patch(cfg.model_config_type)
|
||||||
|
|||||||
@@ -15,12 +15,13 @@
|
|||||||
"""
|
"""
|
||||||
Module for handling Cut Cross Entropy input arguments.
|
Module for handling Cut Cross Entropy input arguments.
|
||||||
"""
|
"""
|
||||||
import logging
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args")
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CutCrossEntropyArgs(BaseModel):
|
class CutCrossEntropyArgs(BaseModel):
|
||||||
|
|||||||
@@ -2,15 +2,15 @@
|
|||||||
Grokfast plugin for Axolotl
|
Grokfast plugin for Axolotl
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from transformers.trainer_callback import TrainerCallback
|
from transformers.trainer_callback import TrainerCallback
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from ..base import BasePlugin
|
from ..base import BasePlugin
|
||||||
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
from .optimizer import gradfilter_ema
|
from .optimizer import gradfilter_ema
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.integrations.grokfast")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GrokfastCallbackHandler(TrainerCallback):
|
class GrokfastCallbackHandler(TrainerCallback):
|
||||||
|
|||||||
@@ -19,16 +19,15 @@ Liger Kernel is the collection of Triton-native kernels for LLM Training.
|
|||||||
It is designed to be performant, correct, and light-weight.
|
It is designed to be performant, correct, and light-weight.
|
||||||
"""
|
"""
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
from .utils import patch_with_compile_disable
|
from .utils import patch_with_compile_disable
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.integrations.liger")
|
LOG = get_logger(__name__, use_environ=True)
|
||||||
|
|
||||||
|
|
||||||
class LigerPlugin(BasePlugin):
|
class LigerPlugin(BasePlugin):
|
||||||
@@ -85,10 +84,7 @@ class LigerPlugin(BasePlugin):
|
|||||||
kwargs["geglu"] = cfg.liger_glu_activation
|
kwargs["geglu"] = cfg.liger_glu_activation
|
||||||
elif "swiglu" in liger_fn_sig.parameters:
|
elif "swiglu" in liger_fn_sig.parameters:
|
||||||
kwargs["swiglu"] = cfg.liger_glu_activation
|
kwargs["swiglu"] = cfg.liger_glu_activation
|
||||||
if is_main_process(use_environ=True):
|
LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
|
||||||
LOG.info(
|
|
||||||
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
|
|
||||||
)
|
|
||||||
apply_liger_fn(**kwargs)
|
apply_liger_fn(**kwargs)
|
||||||
elif cfg.model_config_type == "jamba":
|
elif cfg.model_config_type == "jamba":
|
||||||
from transformers.models.jamba import modeling_jamba
|
from transformers.models.jamba import modeling_jamba
|
||||||
@@ -124,9 +120,9 @@ class LigerPlugin(BasePlugin):
|
|||||||
if cfg.liger_rope:
|
if cfg.liger_rope:
|
||||||
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
|
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
|
||||||
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
|
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
|
||||||
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
|
LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
|
||||||
if cfg.liger_glu_activation:
|
if cfg.liger_glu_activation:
|
||||||
logging.warning("liger_glu_activation is not supported for DeepseekV2.")
|
LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
|
||||||
if cfg.liger_rms_norm:
|
if cfg.liger_rms_norm:
|
||||||
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
||||||
if cfg.liger_glu_activation:
|
if cfg.liger_glu_activation:
|
||||||
@@ -186,6 +182,6 @@ class LigerPlugin(BasePlugin):
|
|||||||
swiglu=cfg.liger_glu_activation,
|
swiglu=cfg.liger_glu_activation,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.warning(
|
LOG.warning(
|
||||||
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
|
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,12 +15,13 @@
|
|||||||
"""
|
"""
|
||||||
Module for handling LIGER input arguments.
|
Module for handling LIGER input arguments.
|
||||||
"""
|
"""
|
||||||
import logging
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.integrations.liger.args")
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LigerArgs(BaseModel):
|
class LigerArgs(BaseModel):
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ Sparse Finetuning plugin for Axolotl — enables handling of sparse neural netwo
|
|||||||
by maintaining masks for zero weights during training.
|
by maintaining masks for zero weights during training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
|
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
|
||||||
|
|
||||||
@@ -16,11 +15,12 @@ from transformers.trainer_callback import TrainerCallback, TrainerControl, Train
|
|||||||
from transformers.training_args import TrainingArguments
|
from transformers.training_args import TrainingArguments
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
P = ParamSpec("P") # Params for generic function signatures
|
P = ParamSpec("P") # Params for generic function signatures
|
||||||
R = TypeVar("R") # Return type for generic function signatures
|
R = TypeVar("R") # Return type for generic function signatures
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.integrations.llm_compressor")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LLMCompressorCallbackHandler(TrainerCallback):
|
class LLMCompressorCallbackHandler(TrainerCallback):
|
||||||
|
|||||||
@@ -17,14 +17,16 @@ Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5):
|
def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5):
|
||||||
unfrozen_parameters = {}
|
unfrozen_parameters = {}
|
||||||
@@ -83,17 +85,17 @@ class SpectrumPlugin(BasePlugin):
|
|||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||||
logging.warning(f"Failed to read SNR data from {snr_path}: {exc}")
|
LOG.warning(f"Failed to read SNR data from {snr_path}: {exc}")
|
||||||
|
|
||||||
if not snr_data:
|
if not snr_data:
|
||||||
try:
|
try:
|
||||||
snr_data = requests.get(snr_url, timeout=60).json()
|
snr_data = requests.get(snr_url, timeout=60).json()
|
||||||
except requests.exceptions.RequestException as exc:
|
except requests.exceptions.RequestException as exc:
|
||||||
logging.warning(f"Failed to fetch SNR data from {snr_url}: {exc}")
|
LOG.warning(f"Failed to fetch SNR data from {snr_url}: {exc}")
|
||||||
return
|
return
|
||||||
# also catch json parsing errors
|
# also catch json parsing errors
|
||||||
except json.JSONDecodeError as exc:
|
except json.JSONDecodeError as exc:
|
||||||
logging.warning(f"Failed to parse SNR data from {snr_url}: {exc}")
|
LOG.warning(f"Failed to parse SNR data from {snr_url}: {exc}")
|
||||||
return
|
return
|
||||||
|
|
||||||
unfrozen_parameters = _generate_unfrozen_params_yaml(
|
unfrozen_parameters = _generate_unfrozen_params_yaml(
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Adapter loading functionality, including LoRA / QLoRA and associated utils"""
|
"""Adapter loading functionality, including LoRA / QLoRA and associated utils"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import types
|
import types
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -21,8 +20,9 @@ from transformers import PreTrainedModel
|
|||||||
|
|
||||||
from axolotl.loaders.utils import get_linear_embedding_layers
|
from axolotl.loaders.utils import get_linear_embedding_layers
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def setup_quantized_meta_for_peft(model: torch.nn.Module):
|
def setup_quantized_meta_for_peft(model: torch.nn.Module):
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ models.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
@@ -47,10 +46,11 @@ from axolotl.utils.distributed import (
|
|||||||
get_device_count,
|
get_device_count,
|
||||||
get_device_type,
|
get_device_type,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
||||||
from axolotl.utils.schemas.enums import RLType
|
from axolotl.utils.schemas.enums import RLType
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
PLUGIN_MANAGER = PluginManager.get_instance()
|
PLUGIN_MANAGER = PluginManager.get_instance()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ Applies pre- and post-model load patches for various fixes and optimizations.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import logging
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
import addict
|
import addict
|
||||||
@@ -17,8 +16,9 @@ from axolotl.monkeypatch.multipack import (
|
|||||||
patch_for_multipack,
|
patch_for_multipack,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
PLUGIN_MANAGER = PluginManager.get_instance()
|
PLUGIN_MANAGER = PluginManager.get_instance()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Processor loading functionality for multi-modal models"""
|
"""Processor loading functionality for multi-modal models"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
@@ -10,8 +9,9 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Tokenizer loading functionality and associated utils"""
|
"""Tokenizer loading functionality and associated utils"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
@@ -19,8 +18,9 @@ from axolotl.utils.distributed import (
|
|||||||
is_local_main_process,
|
is_local_main_process,
|
||||||
is_main_process,
|
is_main_process,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
PLUGIN_MANAGER = PluginManager.get_instance()
|
PLUGIN_MANAGER = PluginManager.get_instance()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Utilities for axolotl.loaders module"""
|
"""Utilities for axolotl.loaders module"""
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
import addict
|
import addict
|
||||||
@@ -9,8 +8,9 @@ import torch
|
|||||||
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
|
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_module_class_from_name(
|
def get_module_class_from_name(
|
||||||
|
|||||||
@@ -2,12 +2,13 @@
|
|||||||
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration, and saving full state dicts
|
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration, and saving full state dicts
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
|
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ Flash attention monkey patch for cerebras btlm model
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -11,7 +10,9 @@ from accelerate import init_empty_weights
|
|||||||
from flash_attn.flash_attn_interface import flash_attn_func
|
from flash_attn.flash_attn_interface import flash_attn_func
|
||||||
from transformers import AutoConfig, AutoModelForCausalLM
|
from transformers import AutoConfig, AutoModelForCausalLM
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
|
def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
|
|||||||
|
|
||||||
import atexit
|
import atexit
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import queue
|
import queue
|
||||||
import shutil
|
import shutil
|
||||||
@@ -32,11 +31,13 @@ from typing import Dict
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||||
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||||
|
|
||||||
# Setup logger
|
# Setup logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DiskOffloadManager:
|
class DiskOffloadManager:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
||||||
|
|
||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -25,6 +24,7 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||||
@@ -41,7 +41,7 @@ except ImportError:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def is_xformers_available() -> bool:
|
def is_xformers_available() -> bool:
|
||||||
@@ -612,9 +612,10 @@ def generate_qkv(
|
|||||||
q, query_padding_mask
|
q, query_padding_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
|
def output_pad_fn(output_unpad):
|
||||||
output_unpad, indices_q, batch_size, seqlen_q
|
return pad_input( # noqa: E731
|
||||||
)
|
output_unpad, indices_q, batch_size, seqlen_q
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||||
@@ -627,9 +628,10 @@ def generate_qkv(
|
|||||||
)
|
)
|
||||||
max_seqlen_q = seqlen_q
|
max_seqlen_q = seqlen_q
|
||||||
|
|
||||||
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
|
def output_pad_fn(output_unpad):
|
||||||
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
return rearrange( # noqa: E731
|
||||||
)
|
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
if key_padding_mask is not None:
|
if key_padding_mask is not None:
|
||||||
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
|
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
@@ -11,10 +10,14 @@ import torch.nn.functional as F
|
|||||||
import transformers.models.llama.modeling_llama
|
import transformers.models.llama.modeling_llama
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.error("xformers not found! Please install it before trying to use it.")
|
LOG.error("xformers not found! Please install it before trying to use it.")
|
||||||
|
|
||||||
|
|
||||||
def hijack_llama_attention():
|
def hijack_llama_attention():
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import types
|
|||||||
from typing import Generator, Tuple, Type
|
from typing import Generator, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
|
||||||
from peft import PeftModelForCausalLM
|
from peft import PeftModelForCausalLM
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
@@ -20,6 +19,7 @@ from axolotl.kernels.lora import (
|
|||||||
)
|
)
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
import logging
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -28,8 +27,9 @@ from transformers.models.mistral.modeling_mistral import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def replace_mistral_attn_with_flash_attn(
|
def replace_mistral_attn_with_flash_attn(
|
||||||
@@ -359,9 +359,10 @@ def generate_qkv(
|
|||||||
q, query_padding_mask
|
q, query_padding_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
|
def output_pad_fn(output_unpad):
|
||||||
output_unpad, indices_q, batch_size, seqlen_q
|
return pad_input( # noqa: E731
|
||||||
)
|
output_unpad, indices_q, batch_size, seqlen_q
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||||
@@ -374,9 +375,10 @@ def generate_qkv(
|
|||||||
)
|
)
|
||||||
max_seqlen_q = seqlen_q
|
max_seqlen_q = seqlen_q
|
||||||
|
|
||||||
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
|
def output_pad_fn(output_unpad):
|
||||||
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
return rearrange( # noqa: E731
|
||||||
)
|
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
if key_padding_mask is not None:
|
if key_padding_mask is not None:
|
||||||
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||||
|
|||||||
@@ -3,14 +3,14 @@ Patch prepare_model_for_kbit_training to not upcast everything
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
|
|
||||||
import peft
|
import peft
|
||||||
|
|
||||||
import axolotl
|
import axolotl
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
ORIGINAL_PREPARE_CODE = """
|
ORIGINAL_PREPARE_CODE = """
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os.path
|
import os.path
|
||||||
import shutil
|
import shutil
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@@ -27,8 +26,9 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
|||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import barrier, is_main_process
|
from axolotl.utils.distributed import barrier, is_main_process
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.relora")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|||||||
@@ -32,11 +32,11 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import AutoConfig, AutoModelForCausalLM
|
from transformers import AutoConfig, AutoModelForCausalLM
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):
|
def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):
|
||||||
|
|||||||
@@ -2,11 +2,11 @@
|
|||||||
monkeypatch for Trainer _get_learning_rate method
|
monkeypatch for Trainer _get_learning_rate method
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release
|
# TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release
|
||||||
|
|||||||
@@ -3,13 +3,13 @@ allow adding additional kwargs to Accelerator init
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
|
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
ORIGINAL_TRAINER_CODE = """
|
ORIGINAL_TRAINER_CODE = """
|
||||||
# create accelerator object
|
# create accelerator object
|
||||||
|
|||||||
@@ -3,13 +3,13 @@ fix for FSDP2 evals when using torch.compile
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
|
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
ORIGINAL_TRAINER_CODE = """
|
ORIGINAL_TRAINER_CODE = """
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|||||||
@@ -3,13 +3,13 @@ fix for FSDP optimizer save in trainer w 4.47.0
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
|
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
ORIGINAL_TRAINER_CODE = """
|
ORIGINAL_TRAINER_CODE = """
|
||||||
|
|
||||||
|
|||||||
@@ -2,13 +2,14 @@
|
|||||||
see https://github.com/huggingface/transformers/pull/35834
|
see https://github.com/huggingface/transformers/pull/35834
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def fixed_fa_peft_integration_check(
|
def fixed_fa_peft_integration_check(
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
|||||||
|
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
|
||||||
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
ORIGINAL_QKV_CODE = """
|
ORIGINAL_QKV_CODE = """
|
||||||
query_states = self.q_proj(hidden_states)
|
query_states = self.q_proj(hidden_states)
|
||||||
@@ -133,7 +133,7 @@ def patch_self_attn_lora():
|
|||||||
)
|
)
|
||||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
self_attn_lora_patched = True
|
self_attn_lora_patched = True
|
||||||
LOG.info("patching unsloth attn lora", main_process_only=True)
|
LOG.info("patching unsloth attn lora")
|
||||||
LlamaFlashAttention2.forward = (
|
LlamaFlashAttention2.forward = (
|
||||||
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
)
|
)
|
||||||
@@ -153,7 +153,7 @@ def integrate_rope_embeddings():
|
|||||||
):
|
):
|
||||||
return fast_rope_embedding(q, k, cos, sin)
|
return fast_rope_embedding(q, k, cos, sin)
|
||||||
|
|
||||||
LOG.info("patching unsloth RoPE embeddings", main_process_only=True)
|
LOG.info("patching unsloth RoPE embeddings")
|
||||||
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||||
|
|
||||||
|
|
||||||
@@ -189,7 +189,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
|||||||
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
|
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
|
||||||
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
|
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
|
||||||
else:
|
else:
|
||||||
LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
|
LOG.warning(f"unable to apply unsloth lora mlp patch to layer {idx}")
|
||||||
|
|
||||||
|
|
||||||
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
||||||
@@ -215,7 +215,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
|||||||
layer.self_attn.apply_qkv = apply_lora_qkv
|
layer.self_attn.apply_qkv = apply_lora_qkv
|
||||||
else:
|
else:
|
||||||
layer.self_attn.apply_qkv = original_apply_qkv
|
layer.self_attn.apply_qkv = original_apply_qkv
|
||||||
LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx)
|
LOG.warning(f"unable to apply unsloth lora qkv patch to layer {idx}")
|
||||||
if cfg.unsloth_lora_o:
|
if cfg.unsloth_lora_o:
|
||||||
layer_modules = [
|
layer_modules = [
|
||||||
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
|
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
|
||||||
@@ -234,9 +234,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
|||||||
layer.self_attn.apply_o = apply_lora_o
|
layer.self_attn.apply_o = apply_lora_o
|
||||||
else:
|
else:
|
||||||
layer.self_attn.apply_o = original_apply_o
|
layer.self_attn.apply_o = original_apply_o
|
||||||
LOG.warning(
|
LOG.warning(f"unable to apply unsloth lora o_proj patch to layer {idx}")
|
||||||
"unable to apply unsloth lora o_proj patch to layer %d", idx
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_unsloth_layernorm():
|
def patch_unsloth_layernorm():
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
|
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -10,7 +9,9 @@ from torch import Tensor
|
|||||||
from transformers import ProcessorMixin
|
from transformers import ProcessorMixin
|
||||||
from transformers.image_utils import load_image
|
from transformers.image_utils import load_image
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProcessingStrategy:
|
class ProcessingStrategy:
|
||||||
|
|||||||
@@ -2,11 +2,11 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
|
|
||||||
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.prompt_strategies")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
||||||
|
|||||||
@@ -3,9 +3,10 @@ module for base dataset transform strategies
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load(strategy, cfg, module_base=None, **kwargs):
|
def load(strategy, cfg, module_base=None, **kwargs):
|
||||||
|
|||||||
@@ -2,11 +2,11 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
|
|
||||||
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load(strategy, tokenizer, cfg, ds_cfg):
|
def load(strategy, tokenizer, cfg, ds_cfg):
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
Bradley-Terry model with chat template prompt strategy.
|
Bradley-Terry model with chat template prompt strategy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from axolotl.prompt_strategies.chat_template import (
|
from axolotl.prompt_strategies.chat_template import (
|
||||||
@@ -10,10 +9,11 @@ from axolotl.prompt_strategies.chat_template import (
|
|||||||
ChatTemplateStrategy,
|
ChatTemplateStrategy,
|
||||||
)
|
)
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
# Configure the logger
|
# Configure the logger
|
||||||
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
|
LOG = get_logger(__name__)
|
||||||
LOG.setLevel(logging.INFO)
|
LOG.setLevel("INFO")
|
||||||
|
|
||||||
|
|
||||||
class BTChatTemplateStrategy(ChatTemplateStrategy):
|
class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||||
@@ -44,7 +44,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
|
|
||||||
if len(chosen_tokenized["input_ids"]) > max_length:
|
if len(chosen_tokenized["input_ids"]) > max_length:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
|
f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}"
|
||||||
)
|
)
|
||||||
|
|
||||||
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
|
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
|
||||||
@@ -62,7 +62,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
|
|
||||||
if len(rejected_tokenized["input_ids"]) > max_length:
|
if len(rejected_tokenized["input_ids"]) > max_length:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
|
f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}"
|
||||||
)
|
)
|
||||||
|
|
||||||
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
|
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
HF Chat Templates prompt strategy
|
HF Chat Templates prompt strategy
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, List, Set, Union
|
from typing import Any, Dict, List, Set, Union
|
||||||
|
|
||||||
@@ -13,11 +12,12 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
|
|||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.datasets import DatasetConfig
|
from axolotl.utils.schemas.datasets import DatasetConfig
|
||||||
|
|
||||||
# Configure the logger
|
# Configure the logger
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = get_logger(__name__)
|
||||||
LOG.setLevel(logging.INFO)
|
LOG.setLevel("INFO")
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplatePrompter(Prompter):
|
class ChatTemplatePrompter(Prompter):
|
||||||
@@ -378,7 +378,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
images=images,
|
images=images,
|
||||||
)
|
)
|
||||||
tokenized_res = self.prompter.build_prompt(turns, images=images) # type: ignore
|
tokenized_res = self.prompter.build_prompt(
|
||||||
|
turns, images=images
|
||||||
|
) # type: ignore
|
||||||
tokenized_prompt = {}
|
tokenized_prompt = {}
|
||||||
if isinstance(tokenized_res, list):
|
if isinstance(tokenized_res, list):
|
||||||
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
|
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
|
||||||
@@ -555,8 +557,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
and turns[0].get("role") == "system"
|
and turns[0].get("role") == "system"
|
||||||
and (
|
and (
|
||||||
"mistral" in self.tokenizer.name_or_path.lower()
|
"mistral" in self.tokenizer.name_or_path.lower()
|
||||||
# gemma3 uses gemma tokenizer
|
or "gemma"
|
||||||
or "gemma" in self.tokenizer.name_or_path.lower()
|
in self.tokenizer.name_or_path.lower() # gemma3 uses gemma tokenizer
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
return -1, -1
|
return -1, -1
|
||||||
|
|||||||
@@ -24,12 +24,14 @@ For a custom system message, the first "from" can be "system" (followed by alter
|
|||||||
Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
|
Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Generator, List, Sequence
|
from typing import Generator, List, Sequence
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID
|
from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -129,7 +131,7 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
if cur_len < self.sequence_len:
|
if cur_len < self.sequence_len:
|
||||||
if cur_len != total_len:
|
if cur_len != total_len:
|
||||||
target[:] = IGNORE_TOKEN_ID
|
target[:] = IGNORE_TOKEN_ID
|
||||||
logging.warning(
|
LOG.warning(
|
||||||
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
||||||
f" (ignored)"
|
f" (ignored)"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,9 +2,10 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.prompt_strategies.messages")
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg, processor=None):
|
def load(tokenizer, cfg, ds_cfg, processor=None):
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
"""Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class"""
|
"""Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
||||||
from axolotl.prompters import AlpacaPrompter
|
from axolotl.prompters import AlpacaPrompter
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""
|
"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import logging
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Generator, List, Tuple
|
from typing import Generator, List, Tuple
|
||||||
|
|
||||||
@@ -10,8 +9,9 @@ from axolotl.prompt_tokenizers import (
|
|||||||
parse_tokenized_to_result,
|
parse_tokenized_to_result,
|
||||||
tokenize_prompt_default,
|
tokenize_prompt_default,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
"""Module containing PromptTokenizingStrategy and Prompter classes"""
|
"""Module containing PromptTokenizingStrategy and Prompter classes"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import logging
|
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from transformers import BatchEncoding, PreTrainedTokenizer
|
from transformers import BatchEncoding, PreTrainedTokenizer
|
||||||
|
|
||||||
from axolotl.prompters import Prompter
|
from axolotl.prompters import Prompter
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
LLAMA_DEFAULT_PAD_TOKEN = "<pad>" # nosec
|
LLAMA_DEFAULT_PAD_TOKEN = "<pad>" # nosec
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
"""Module containing prompters"""
|
"""Module containing prompters"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Generator, Optional, Union
|
from typing import Generator, Optional, Union
|
||||||
|
|
||||||
from colorama import Fore
|
from colorama import Fore
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
|
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
@@ -37,6 +36,7 @@ from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContext
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import cleanup_distributed
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.enums import RLType
|
from axolotl.utils.schemas.enums import RLType
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
@@ -45,7 +45,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
BetterTransformer = None
|
BetterTransformer = None
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def setup_model_and_tokenizer(
|
def setup_model_and_tokenizer(
|
||||||
@@ -64,9 +64,7 @@ def setup_model_and_tokenizer(
|
|||||||
`None`), and processor (if multimodal, else `None`).
|
`None`), and processor (if multimodal, else `None`).
|
||||||
"""
|
"""
|
||||||
# Load tokenizer
|
# Load tokenizer
|
||||||
LOG.debug(
|
LOG.debug(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
|
||||||
)
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
# Load processor for multimodal models if needed
|
# Load processor for multimodal models if needed
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
@@ -43,6 +42,7 @@ from axolotl.utils.distributed import (
|
|||||||
is_main_process,
|
is_main_process,
|
||||||
zero_first,
|
zero_first,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -50,7 +50,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
LOG = logging.getLogger("axolotl.callbacks")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EvalFirstStepCallback(
|
class EvalFirstStepCallback(
|
||||||
@@ -753,7 +753,14 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
|
|||||||
].append(pred_step_text)
|
].append(pred_step_text)
|
||||||
row_index += 1
|
row_index += 1
|
||||||
if logger == "wandb":
|
if logger == "wandb":
|
||||||
wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined]
|
# type: ignore[attr-defined]
|
||||||
|
wandb.run.log(
|
||||||
|
{
|
||||||
|
f"{name} - Predictions vs Ground Truth": pd.DataFrame(
|
||||||
|
table_data
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
elif logger == "mlflow" and is_mlflow_available():
|
elif logger == "mlflow" and is_mlflow_available():
|
||||||
import mlflow
|
import mlflow
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
"""Comet module for trainer callbacks"""
|
"""Comet module for trainer callbacks"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import comet_ml
|
import comet_ml
|
||||||
from transformers import TrainerCallback, TrainerControl, TrainerState
|
from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||||
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.callbacks")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SaveAxolotlConfigtoCometCallback(TrainerCallback):
|
class SaveAxolotlConfigtoCometCallback(TrainerCallback):
|
||||||
|
|||||||
@@ -6,17 +6,18 @@ Arxiv: https://arxiv.org/abs/2403.17919
|
|||||||
License: Apache 2.0
|
License: Apache 2.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainer
|
from axolotl.core.trainer_builder import AxolotlTrainer
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.callbacks.lisa")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""MLFlow module for trainer callbacks"""
|
"""MLFlow module for trainer callbacks"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
@@ -10,11 +9,12 @@ import mlflow
|
|||||||
from transformers import TrainerCallback, TrainerControl, TrainerState
|
from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||||
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.callbacks")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def should_log_artifacts() -> bool:
|
def should_log_artifacts() -> bool:
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""QAT Callback for HF Causal Trainer"""
|
"""QAT Callback for HF Causal Trainer"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -8,9 +7,10 @@ from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
|
|||||||
from torchao.quantization.qat.linear import FakeQuantizedLinear
|
from torchao.quantization.qat.linear import FakeQuantizedLinear
|
||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.quantization import QATConfig
|
from axolotl.utils.schemas.quantization import QATConfig
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def toggle_fake_quant(mod: nn.Module, enable: bool):
|
def toggle_fake_quant(mod: nn.Module, enable: bool):
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,11 +1,11 @@
|
|||||||
"""Module for wandb utilities"""
|
"""Module for wandb utilities"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.comet_")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
COMET_ENV_MAPPING_OVERRIDE = {
|
COMET_ENV_MAPPING_OVERRIDE = {
|
||||||
"comet_mode": "COMET_START_MODE",
|
"comet_mode": "COMET_START_MODE",
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Module for working with config dicts"""
|
"""Module for working with config dicts"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -15,13 +14,14 @@ from axolotl.loaders import MULTIMODAL_AUTO_MODEL_MAPPING
|
|||||||
from axolotl.loaders.utils import load_model_config
|
from axolotl.loaders.utils import load_model_config
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.config import (
|
from axolotl.utils.schemas.config import (
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
||||||
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = get_logger(__name__, use_environ=True)
|
||||||
|
|
||||||
|
|
||||||
def choose_device(cfg):
|
def choose_device(cfg):
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""data handling specific to pretraining"""
|
"""data handling specific to pretraining"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Callable, Dict, List, Optional
|
from typing import Callable, Dict, List, Optional
|
||||||
|
|
||||||
@@ -11,10 +10,11 @@ from torch.utils.data import RandomSampler
|
|||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
|
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def encode_pretraining(
|
def encode_pretraining(
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""data handling specific to DPO"""
|
"""data handling specific to DPO"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List, Union
|
from typing import Any, List, Union
|
||||||
@@ -18,9 +17,10 @@ from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_
|
|||||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.enums import RLType
|
from axolotl.utils.schemas.enums import RLType
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _get_path(ds_hash, cfg):
|
def _get_path(ds_hash, cfg):
|
||||||
@@ -217,7 +217,7 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
+ "|"
|
+ "|"
|
||||||
+ "train"
|
+ "train"
|
||||||
+ "|"
|
+ "|"
|
||||||
+ str(seed)
|
+ str(cfg.seed or 42)
|
||||||
)
|
)
|
||||||
to_hash_test = (
|
to_hash_test = (
|
||||||
train_dataset._fingerprint # pylint: disable=protected-access
|
train_dataset._fingerprint # pylint: disable=protected-access
|
||||||
@@ -226,7 +226,7 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
+ "|"
|
+ "|"
|
||||||
+ "test"
|
+ "test"
|
||||||
+ "|"
|
+ "|"
|
||||||
+ str(seed)
|
+ str(cfg.seed or 42)
|
||||||
)
|
)
|
||||||
train_fingerprint = md5(to_hash_train)
|
train_fingerprint = md5(to_hash_train)
|
||||||
test_fingerprint = md5(to_hash_test)
|
test_fingerprint = md5(to_hash_test)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""data handling specific to SFT"""
|
"""data handling specific to SFT"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -54,12 +53,13 @@ from axolotl.utils.data.utils import (
|
|||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_local_main_process, zero_first
|
from axolotl.utils.distributed import is_local_main_process, zero_first
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.trainer import (
|
from axolotl.utils.trainer import (
|
||||||
calculate_total_num_steps,
|
calculate_total_num_steps,
|
||||||
process_datasets_for_packing,
|
process_datasets_for_packing,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||||
@@ -182,10 +182,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
|||||||
total_num_steps = min(
|
total_num_steps = min(
|
||||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||||
)
|
)
|
||||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
|
||||||
else:
|
else:
|
||||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
||||||
|
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||||
return train_dataset, eval_dataset, total_num_steps, prompters
|
return train_dataset, eval_dataset, total_num_steps, prompters
|
||||||
|
|
||||||
|
|
||||||
@@ -331,12 +330,12 @@ def load_tokenized_prepared_datasets(
|
|||||||
if len(datasets) == 1:
|
if len(datasets) == 1:
|
||||||
dataset = datasets[0]
|
dataset = datasets[0]
|
||||||
else:
|
else:
|
||||||
LOG.info("merging datasets")
|
LOG.info("Merging datasets...")
|
||||||
dataset = concatenate_datasets(datasets)
|
dataset = concatenate_datasets(datasets)
|
||||||
|
|
||||||
if len(datasets) > 1:
|
if len(datasets) > 1:
|
||||||
if cfg.shuffle_merged_datasets:
|
if cfg.shuffle_merged_datasets:
|
||||||
LOG.debug("shuffle merged datasets")
|
LOG.debug("Shuffling merged datasets...")
|
||||||
dataset = dataset.shuffle(seed=seed)
|
dataset = dataset.shuffle(seed=seed)
|
||||||
else:
|
else:
|
||||||
LOG.debug("NOT shuffling merged datasets")
|
LOG.debug("NOT shuffling merged datasets")
|
||||||
@@ -426,7 +425,7 @@ def load_prepare_datasets(
|
|||||||
+ "|"
|
+ "|"
|
||||||
+ "train"
|
+ "train"
|
||||||
+ "|"
|
+ "|"
|
||||||
+ str(seed)
|
+ str(cfg.seed or 42)
|
||||||
)
|
)
|
||||||
to_hash_test = (
|
to_hash_test = (
|
||||||
dataset._fingerprint # pylint: disable=protected-access
|
dataset._fingerprint # pylint: disable=protected-access
|
||||||
@@ -435,7 +434,7 @@ def load_prepare_datasets(
|
|||||||
+ "|"
|
+ "|"
|
||||||
+ "test"
|
+ "test"
|
||||||
+ "|"
|
+ "|"
|
||||||
+ str(seed)
|
+ str(cfg.seed or 42)
|
||||||
)
|
)
|
||||||
train_fingerprint = md5(to_hash_train)
|
train_fingerprint = md5(to_hash_train)
|
||||||
test_fingerprint = md5(to_hash_test)
|
test_fingerprint = md5(to_hash_test)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
|
||||||
import time
|
import time
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
@@ -12,10 +11,11 @@ import requests
|
|||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.samplers.utils import get_dataset_lengths
|
from axolotl.utils.samplers.utils import get_dataset_lengths
|
||||||
from axolotl.utils.trainer import drop_long_seq
|
from axolotl.utils.trainer import drop_long_seq
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RetryStrategy(Enum):
|
class RetryStrategy(Enum):
|
||||||
|
|||||||
62
src/axolotl/utils/logging.py
Normal file
62
src/axolotl/utils/logging.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""
|
||||||
|
logging helpers to only log on main process
|
||||||
|
"""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
|
# Adapted from Accelerate
|
||||||
|
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py
|
||||||
|
|
||||||
|
|
||||||
|
class MultiProcessAdapter(logging.LoggerAdapter):
    """
    Logger adapter for distributed runs that, by default, only logs on the main process.

    Every logging call accepts two extra keyword arguments. Both are consumed by `log`
    and are never forwarded to the wrapped logger:

    - `main_process_only` (default `True`): when falsy, the record is emitted without
      consulting `is_main_process`.
    - `use_environ` (default: the value passed to the constructor): forwarded as-is to
      `is_main_process`.
    """

    def __init__(self, logger, use_environ=False, extra=None):
        """
        Args:
            logger: the `logging.Logger` to wrap.
            use_environ: default value of the `use_environ` flag handed to
                `is_main_process` (presumably selects an environment-variable based
                rank lookup -- confirm against `axolotl.utils.distributed`).
            extra: contextual dict stored by `logging.LoggerAdapter`.
        """
        super().__init__(logger, extra)
        self.use_environ = use_environ
        # Keys of the warnings this adapter has already emitted via `warning_once`.
        # Kept per instance instead of in a class-wide `functools.lru_cache`, which
        # would pin every adapter in memory and could evict (and thus repeat) entries.
        self._warned_once = set()

    @staticmethod
    def _should_log(main_process_only, use_environ=False):
        """Return a truthy value when the current process is allowed to emit a record."""
        # `is_main_process` is only consulted when the call is restricted to the main
        # process; unrestricted calls short-circuit to True.
        return not main_process_only or is_main_process(use_environ=use_environ)

    def log(self, level, msg, *args, **kwargs):
        """
        Forward a record to the wrapped logger if this process is allowed to log.

        The `LoggerAdapter` convenience methods (`debug`, `info`, `warning`, ...) all
        delegate to this method, so the filtering applies to them as well.

        Args:
            level: numeric logging level of the record.
            msg: message or %-style format string.
            *args: positional arguments for the format string.
            **kwargs: standard logging keyword arguments, plus the adapter-specific
                `main_process_only` and `use_environ` flags described on the class.
        """
        # Strip the adapter-only flags so the wrapped logger never sees them.
        use_environ = kwargs.pop("use_environ", self.use_environ)
        main_process_only = kwargs.pop("main_process_only", True)
        # Skip this wrapper's frame when the logging module resolves the call site,
        # unless the caller asked for a specific stacklevel.
        kwargs.setdefault("stacklevel", 2)

        if not self.isEnabledFor(level):
            return
        if not self._should_log(main_process_only, use_environ=use_environ):
            return

        msg, kwargs = self.process(msg, kwargs)
        self.logger.log(level, msg, *args, **kwargs)

    def warning_once(self, *args, **kwargs):
        """
        This method is identical to `logger.warning()`, but will emit the warning with the same message only once

        Note: Deduplication is keyed on the call arguments and is tracked per adapter, so two different callers
        using the same adapter and the same arguments share one entry. The assumption here is that all warning
        messages are unique across the code. All arguments must be hashable. The set of seen keys is unbounded,
        so pass stable messages rather than ones built from ever-changing values.
        """
        key = (args, frozenset(kwargs.items()))
        if key in self._warned_once:
            return
        self.warning(*args, **kwargs)
        # Recorded only after the call returns, so a call that raised can be retried.
        self._warned_once.add(key)
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(
    name: str, log_level: str | None = None, use_environ: bool = False
) -> MultiProcessAdapter:
    """
    Return a `MultiProcessAdapter` wrapping the standard logger called `name`.

    Args:
        name: logger name, passed to `logging.getLogger`.
        log_level: optional level name (case-insensitive). When `None`, the
            `AXOLOTL_LOG_LEVEL` environment variable is used instead. When a
            level is configured it is applied to both the named logger and the
            root logger.
        use_environ: stored on the adapter and forwarded to `is_main_process`
            for calls that do not override it.

    Returns:
        The adapter around the named logger.

    Raises:
        ValueError: if the configured level is not a level name known to the
            `logging` module.
    """
    if log_level is None:
        log_level = os.environ.get("AXOLOTL_LOG_LEVEL")

    logger = logging.getLogger(name)

    # An empty value (e.g. `AXOLOTL_LOG_LEVEL=` in a shell or .env file) is
    # treated as "not configured": passing "" to `setLevel` raises ValueError,
    # and this function is called at module level, i.e. during import.
    if log_level:
        level_name = log_level.upper()
        logger.setLevel(level_name)
        logger.root.setLevel(level_name)

    return MultiProcessAdapter(logger, use_environ=use_environ, extra={})
|
||||||
@@ -2,8 +2,6 @@
|
|||||||
Utilities for quantization including QAT and PTQ using torchao.
|
Utilities for quantization including QAT and PTQ using torchao.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torchao.core.config import AOBaseConfig
|
from torchao.core.config import AOBaseConfig
|
||||||
@@ -25,8 +23,6 @@ from torchao.quantization.quant_api import (
|
|||||||
|
|
||||||
from axolotl.utils.schemas.enums import TorchIntDType
|
from axolotl.utils.schemas.enums import TorchIntDType
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def get_ptq_config(
|
def get_ptq_config(
|
||||||
weight_dtype: TorchIntDType,
|
weight_dtype: TorchIntDType,
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ Multipack Batch Sampler - An efficient batch sampler for packing variable-length
|
|||||||
into fixed-capacity batches to optimize memory usage and training throughput.
|
into fixed-capacity batches to optimize memory usage and training throughput.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
from concurrent.futures import ProcessPoolExecutor
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
from multiprocessing import cpu_count, get_context
|
from multiprocessing import cpu_count, get_context
|
||||||
@@ -14,9 +13,9 @@ import numpy as np
|
|||||||
from torch.utils.data import BatchSampler, Sampler, SequentialSampler
|
from torch.utils.data import BatchSampler, Sampler, SequentialSampler
|
||||||
|
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
LOG.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
|
|
||||||
@numba.njit
|
@numba.njit
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
@@ -18,6 +17,7 @@ from pydantic import (
|
|||||||
)
|
)
|
||||||
from transformers.utils.import_utils import is_torch_npu_available
|
from transformers.utils.import_utils import is_torch_npu_available
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.datasets import (
|
from axolotl.utils.schemas.datasets import (
|
||||||
DatasetConfig,
|
DatasetConfig,
|
||||||
DPODataset,
|
DPODataset,
|
||||||
@@ -49,7 +49,7 @@ from axolotl.utils.schemas.training import HyperparametersConfig
|
|||||||
from axolotl.utils.schemas.trl import TRLConfig
|
from axolotl.utils.schemas.trl import TRLConfig
|
||||||
from axolotl.utils.schemas.vllm import VllmConfig
|
from axolotl.utils.schemas.vllm import VllmConfig
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__, use_environ=True)
|
||||||
|
|
||||||
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
"""Pydantic models for deprecated and remapped configuration parameters"""
|
"""Pydantic models for deprecated and remapped configuration parameters"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DeprecatedParameters(BaseModel):
|
class DeprecatedParameters(BaseModel):
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ class ChatTemplate(str, Enum):
|
|||||||
command_a_rag = "command_a_rag" # pylint: disable=invalid-name
|
command_a_rag = "command_a_rag" # pylint: disable=invalid-name
|
||||||
aya = "aya" # pylint: disable=invalid-name
|
aya = "aya" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class CustomSupportedOptimizers(str, Enum):
|
class CustomSupportedOptimizers(str, Enum):
|
||||||
"""Custom supported optimizers"""
|
"""Custom supported optimizers"""
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
"""Pydantic models for Axolotl integrations"""
|
"""Pydantic models for Axolotl integrations"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MLFlowConfig(BaseModel):
|
class MLFlowConfig(BaseModel):
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
"""Pydantic models for model input / output, etc. configuration"""
|
"""Pydantic models for model input / output, etc. configuration"""
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__, use_environ=True)
|
||||||
|
|
||||||
|
|
||||||
class ModelInputConfig(BaseModel):
|
class ModelInputConfig(BaseModel):
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
"""Pydantic models for training hyperparameters"""
|
"""Pydantic models for training hyperparameters"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from transformers import SchedulerType
|
from transformers import SchedulerType
|
||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LrGroup(BaseModel):
|
class LrGroup(BaseModel):
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Utilities for Axolotl Pydantic models"""
|
"""Utilities for Axolotl Pydantic models"""
|
||||||
|
|
||||||
import logging
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
"""Module for tokenization utilities"""
|
"""Module for tokenization utilities"""
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def check_dataset_labels(
|
def check_dataset_labels(
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from axolotl.utils.distributed import reduce_and_broadcast
|
|||||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
LOG = get_logger("axolotl")
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
@@ -402,7 +402,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
.apply(len)
|
.apply(len)
|
||||||
.values
|
.values
|
||||||
)
|
)
|
||||||
LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True)
|
LOG.debug(f"total_num_tokens: {total_num_tokens:_}")
|
||||||
if update:
|
if update:
|
||||||
cfg.total_num_tokens = total_num_tokens
|
cfg.total_num_tokens = total_num_tokens
|
||||||
|
|
||||||
@@ -420,10 +420,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
.apply(lambda x: np.sum(np.array(x) != -100))
|
.apply(lambda x: np.sum(np.array(x) != -100))
|
||||||
.sum()
|
.sum()
|
||||||
)
|
)
|
||||||
LOG.debug(
|
LOG.debug(f"`total_supervised_tokens: {total_supervised_tokens:_}`")
|
||||||
f"`total_supervised_tokens: {total_supervised_tokens:_}`",
|
|
||||||
main_process_only=True,
|
|
||||||
)
|
|
||||||
if update:
|
if update:
|
||||||
cfg.total_supervised_tokens = total_supervised_tokens
|
cfg.total_supervised_tokens = total_supervised_tokens
|
||||||
|
|
||||||
@@ -448,8 +445,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
* cfg.sequence_parallel_degree
|
* cfg.sequence_parallel_degree
|
||||||
)
|
)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
|
||||||
main_process_only=True,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if cfg.flash_attention and not cfg.multipack_real_batches:
|
if cfg.flash_attention and not cfg.multipack_real_batches:
|
||||||
@@ -478,7 +474,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
batch_sampler=sampler,
|
batch_sampler=sampler,
|
||||||
)
|
)
|
||||||
data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
|
data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
|
||||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
LOG.debug(f"data_loader_len: {data_loader_len}")
|
||||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||||
# on the agreed on value for sample_packing_eff_est
|
# on the agreed on value for sample_packing_eff_est
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
@@ -500,10 +496,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
)
|
)
|
||||||
if update:
|
if update:
|
||||||
cfg.sample_packing_eff_est = sample_packing_eff_est
|
cfg.sample_packing_eff_est = sample_packing_eff_est
|
||||||
LOG.debug(
|
LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
|
||||||
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
|
|
||||||
main_process_only=True,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.ceil(
|
math.ceil(
|
||||||
@@ -513,7 +506,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
/ cfg.batch_size
|
/ cfg.batch_size
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
LOG.debug(f"total_num_steps: {total_num_steps}")
|
||||||
return total_num_steps
|
return total_num_steps
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for multigpu lora tinyllama
|
E2E tests for multigpu lora tinyllama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -14,10 +13,11 @@ from transformers.testing_utils import get_torch_dist_unique_port
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
|
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = get_logger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for multigpu eval
|
E2E tests for multigpu eval
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -11,10 +10,11 @@ from accelerate.test_utils import execute_subprocess_async
|
|||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from ..utils import check_tensorboard
|
from ..utils import check_tensorboard
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = get_logger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for multigpu lora tinyllama
|
E2E tests for multigpu lora tinyllama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -13,10 +12,11 @@ from huggingface_hub import snapshot_download
|
|||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from tests.e2e.utils import check_tensorboard
|
from tests.e2e.utils import check_tensorboard
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = get_logger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for multigpu lora tinyllama
|
E2E tests for multigpu lora tinyllama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -15,10 +14,11 @@ from packaging import version
|
|||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
|
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = get_logger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for multigpu qwen2
|
E2E tests for multigpu qwen2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -12,8 +11,9 @@ from accelerate.test_utils import execute_subprocess_async
|
|||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = get_logger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for multigpu post-training use Ray Train
|
E2E tests for multigpu post-training use Ray Train
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -11,10 +10,11 @@ import yaml
|
|||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
|
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = get_logger(__name__)
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for multipack fft llama using 4d attention masks
|
E2E tests for multipack fft llama using 4d attention masks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = get_logger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for lora llama
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, check_tensorboard
|
from ..utils import check_model_output_exists, check_tensorboard
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = get_logger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for falcon
|
E2E tests for falcon
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = get_logger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for lora llama
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -14,10 +13,11 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = get_logger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for llama w/ S2 attn
|
E2E tests for llama w/ S2 attn
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = get_logger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for lora llama
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -14,10 +13,11 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = get_logger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user