Compare commits
6 Commits
diffusion-
...
dump-confi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b594f18f6e | ||
|
|
700791deb9 | ||
|
|
d6d2cc673b | ||
|
|
83525f14a0 | ||
|
|
68c0e31fd1 | ||
|
|
22f930c658 |
@@ -7,7 +7,6 @@ from typing import Union
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.cloud.modal_ import ModalCloud
|
from axolotl.cli.cloud.modal_ import ModalCloud
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -24,7 +23,6 @@ def do_cli_preprocess(
|
|||||||
cloud_config: Union[Path, str],
|
cloud_config: Union[Path, str],
|
||||||
config: Union[Path, str],
|
config: Union[Path, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
@@ -39,7 +37,6 @@ def do_cli_train(
|
|||||||
cwd=None,
|
cwd=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
@@ -54,7 +51,6 @@ def do_cli_lm_eval(
|
|||||||
cloud_config: Union[Path, str],
|
cloud_config: Union[Path, str],
|
||||||
config: Union[Path, str],
|
config: Union[Path, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
|
|||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
API_KEY_FIELDS = {"comet_api_key"}
|
||||||
|
|
||||||
|
|
||||||
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||||
"""
|
"""
|
||||||
@@ -233,4 +235,15 @@ def load_cfg(
|
|||||||
setup_comet_env_vars(cfg)
|
setup_comet_env_vars(cfg)
|
||||||
plugin_set_cfg(cfg)
|
plugin_set_cfg(cfg)
|
||||||
|
|
||||||
|
cfg_to_log = {
|
||||||
|
k: "[REDACTED]" if k in API_KEY_FIELDS else v
|
||||||
|
for k, v in cfg.items()
|
||||||
|
if v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"config:\n%s",
|
||||||
|
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
|
||||||
|
)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
@@ -35,7 +34,6 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
patch_optimized_env()
|
patch_optimized_env()
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||||
|
|
||||||
from axolotl.cli.args import InferenceCliArgs
|
from axolotl.cli.args import InferenceCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.utils.chat_templates import (
|
from axolotl.utils.chat_templates import (
|
||||||
@@ -255,7 +254,6 @@ def do_cli(
|
|||||||
kwargs: Additional keyword arguments to override config file values.
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
||||||
parsed_cfg.sample_packing = False
|
parsed_cfg.sample_packing = False
|
||||||
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from axolotl.cli.args import (
|
|||||||
TrainerCliArgs,
|
TrainerCliArgs,
|
||||||
VllmServeCliArgs,
|
VllmServeCliArgs,
|
||||||
)
|
)
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.sweeps import generate_sweep_configs
|
from axolotl.cli.sweeps import generate_sweep_configs
|
||||||
from axolotl.cli.utils import (
|
from axolotl.cli.utils import (
|
||||||
add_options_from_config,
|
add_options_from_config,
|
||||||
@@ -40,6 +41,7 @@ LOG = get_logger(__name__)
|
|||||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||||
def cli():
|
def cli():
|
||||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||||
|
print_axolotl_text_art()
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from typing import Union
|
|||||||
import fire
|
import fire
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -23,8 +22,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
|
|
||||||
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ from huggingface_hub import split_torch_state_dict_into_shards
|
|||||||
from safetensors.torch import save_file as safe_save_file
|
from safetensors.torch import save_file as safe_save_file
|
||||||
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
@@ -194,7 +193,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
kwargs: Additional keyword arguments to override config file values.
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
|
||||||
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from axolotl.cli.args import PreprocessCliArgs
|
from axolotl.cli.args import PreprocessCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
@@ -33,7 +32,6 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Preprocessing-specific CLI arguments.
|
cli_args: Preprocessing-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Union
|
|||||||
|
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.loaders import load_tokenizer
|
from axolotl.loaders import load_tokenizer
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -27,7 +26,6 @@ def do_quantize(
|
|||||||
config (Union[Path, str]): The path to the config file
|
config (Union[Path, str]): The path to the config file
|
||||||
cli_args (dict): Additional command-line arguments
|
cli_args (dict): Additional command-line arguments
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
|
|
||||||
cfg = load_cfg(config)
|
cfg = load_cfg(config)
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
@@ -35,7 +34,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
|||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
patch_optimized_env()
|
patch_optimized_env()
|
||||||
|
|
||||||
print_axolotl_text_art()
|
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
|||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.common.datasets import TrainDatasetMeta
|
from axolotl.common.datasets import TrainDatasetMeta
|
||||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
@@ -545,8 +544,6 @@ def train(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (model, tokenizer) after training
|
Tuple of (model, tokenizer) after training
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
|
|
||||||
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
||||||
(
|
(
|
||||||
trainer,
|
trainer,
|
||||||
|
|||||||
Reference in New Issue
Block a user