moving text art; adding sensitive value redaction + sorting
This commit is contained in:
@@ -19,7 +19,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
- repo: https://github.com/PyCQA/flake8
|
- repo: https://github.com/PyCQA/flake8
|
||||||
rev: 7.2.0
|
rev: 7.3.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
- repo: https://github.com/pylint-dev/pylint
|
- repo: https://github.com/pylint-dev/pylint
|
||||||
@@ -27,7 +27,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: pylint
|
- id: pylint
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v1.16.0
|
rev: v1.16.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
@@ -36,7 +36,7 @@ repos:
|
|||||||
'pydantic>=2.5.3',
|
'pydantic>=2.5.3',
|
||||||
]
|
]
|
||||||
- repo: https://github.com/PyCQA/bandit
|
- repo: https://github.com/PyCQA/bandit
|
||||||
rev: 1.8.3
|
rev: 1.8.5
|
||||||
hooks:
|
hooks:
|
||||||
- id: bandit
|
- id: bandit
|
||||||
args: [
|
args: [
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Union
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.cloud.modal_ import ModalCloud
|
from axolotl.cli.cloud.modal_ import ModalCloud
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -24,7 +23,6 @@ def do_cli_preprocess(
|
|||||||
cloud_config: Union[Path, str],
|
cloud_config: Union[Path, str],
|
||||||
config: Union[Path, str],
|
config: Union[Path, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
@@ -39,7 +37,6 @@ def do_cli_train(
|
|||||||
cwd=None,
|
cwd=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
@@ -54,7 +51,6 @@ def do_cli_lm_eval(
|
|||||||
cloud_config: Union[Path, str],
|
cloud_config: Union[Path, str],
|
||||||
config: Union[Path, str],
|
config: Union[Path, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import torch
|
|||||||
import yaml
|
import yaml
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
|
from axolotl.cli.redaction import redact_sensitive_info
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
@@ -233,6 +234,10 @@ def load_cfg(
|
|||||||
setup_comet_env_vars(cfg)
|
setup_comet_env_vars(cfg)
|
||||||
plugin_set_cfg(cfg)
|
plugin_set_cfg(cfg)
|
||||||
|
|
||||||
LOG.info(f"cfg:\n{json.dumps(cfg, indent=2, default=str)}")
|
redacted_cfg = redact_sensitive_info(cfg)
|
||||||
|
LOG.info(
|
||||||
|
"config:\n%s",
|
||||||
|
json.dumps(redacted_cfg, indent=2, default=str, sort_keys=True),
|
||||||
|
)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
@@ -35,7 +34,6 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
patch_optimized_env()
|
patch_optimized_env()
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||||
|
|
||||||
from axolotl.cli.args import InferenceCliArgs
|
from axolotl.cli.args import InferenceCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.utils.chat_templates import (
|
from axolotl.utils.chat_templates import (
|
||||||
@@ -255,7 +254,6 @@ def do_cli(
|
|||||||
kwargs: Additional keyword arguments to override config file values.
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
||||||
parsed_cfg.sample_packing = False
|
parsed_cfg.sample_packing = False
|
||||||
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from axolotl.cli.args import (
|
|||||||
TrainerCliArgs,
|
TrainerCliArgs,
|
||||||
VllmServeCliArgs,
|
VllmServeCliArgs,
|
||||||
)
|
)
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.sweeps import generate_sweep_configs
|
from axolotl.cli.sweeps import generate_sweep_configs
|
||||||
from axolotl.cli.utils import (
|
from axolotl.cli.utils import (
|
||||||
add_options_from_config,
|
add_options_from_config,
|
||||||
@@ -40,6 +41,7 @@ LOG = get_logger(__name__)
|
|||||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||||
def cli():
|
def cli():
|
||||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||||
|
print_axolotl_text_art()
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from typing import Union
|
|||||||
import fire
|
import fire
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -23,8 +22,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
|
|
||||||
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ from huggingface_hub import split_torch_state_dict_into_shards
|
|||||||
from safetensors.torch import save_file as safe_save_file
|
from safetensors.torch import save_file as safe_save_file
|
||||||
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
@@ -194,7 +193,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
kwargs: Additional keyword arguments to override config file values.
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
|
||||||
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from axolotl.cli.args import PreprocessCliArgs
|
from axolotl.cli.args import PreprocessCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
@@ -33,7 +32,6 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Preprocessing-specific CLI arguments.
|
cli_args: Preprocessing-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Union
|
|||||||
|
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.loaders import load_tokenizer
|
from axolotl.loaders import load_tokenizer
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -27,7 +26,6 @@ def do_quantize(
|
|||||||
config (Union[Path, str]): The path to the config file
|
config (Union[Path, str]): The path to the config file
|
||||||
cli_args (dict): Additional command-line arguments
|
cli_args (dict): Additional command-line arguments
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
|
|
||||||
cfg = load_cfg(config)
|
cfg = load_cfg(config)
|
||||||
|
|
||||||
|
|||||||
96
src/axolotl/cli/redaction.py
Normal file
96
src/axolotl/cli/redaction.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
"""Utils for redaction of sensitive information in config."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
# NOTE: Mirrors the telemetry redaction logic; unify the two once telemetry is merged.
# The whitelist YAML ships next to this module.
WHITELIST_PATH = str(Path(__file__).with_name("redaction_whitelist.yaml"))

with open(WHITELIST_PATH, encoding="utf-8") as f:
    WHITELIST = yaml.safe_load(f)

# Org names are compared case-insensitively, so store them lowercased
WHITELIST["organizations"] = {
    org_name.lower() for org_name in WHITELIST["organizations"]
}


# NOTE: Keep these in sync with any config schema changes.
# Exact config keys whose string values are redaction candidates.
FIELDS_TO_REDACT = {
    "base_model",
    "base_model_config",
    "hub_model_id",
    "pretraining_dataset",  # NOTE: this field may be a string or a dictionary.
    "resume_from_checkpoint",
    "tokenizer_config",
}
# Key fragments marking integration-specific settings (matched against the key).
PREFIXES_TO_REDACT = {"comet_", "gradio_", "mlflow_", "wandb_"}
# Key fragments (matched against the lowercased key) hinting at filesystem paths.
PATH_INDICATORS = {"dir", "path"}
|
||||||
|
|
||||||
|
|
||||||
|
def is_whitelisted(value: str) -> bool:
    """
    Check if the org of a model / dataset / etc. identifier is in the whitelist.

    Only values of the exact form `org/name` are considered. Anything with more (or
    fewer) `/`-separated parts, or with an empty name, is treated as a path or an
    unknown identifier and is never whitelisted, so that e.g. a local path such as
    `google/private/checkpoint` is not mistaken for a public hub identifier.

    This logic is borrowed from the telemetry logic. Should be unified with it once
    merged.

    Args:
        value: Value of a redaction-candidate field ("base_model", etc.).

    Returns:
        Boolean indicating whitelist membership.
    """
    # NOTE: A relative local path of the exact form `<whitelisted org>/<name>` is
    # still indistinguishable from a hub identifier and is treated as whitelisted.
    parts = value.split("/")
    if len(parts) != 2 or not parts[1]:
        return False

    # Whitelisted org names are stored lowercased, so compare case-insensitively
    return parts[0].lower() in WHITELIST["organizations"]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_sensitive_key(key: Any) -> bool:
    """Return whether string values stored under `key` are redaction candidates."""
    if not isinstance(key, str):
        return False

    return (
        key in FIELDS_TO_REDACT
        # Matched as substrings of the key, not strictly as prefixes
        or any(prefix in key for prefix in PREFIXES_TO_REDACT)
        or any(indicator in key.lower() for indicator in PATH_INDICATORS)
    )


def _redact_value(value: Any, key: Any = "") -> Any:
    """Recursively copy `value`, replacing sensitive strings with "[REDACTED]"."""
    if isinstance(value, str):
        # Fields with whitelisted orgs don't need to be redacted
        if _is_sensitive_key(key) and not is_whitelisted(value):
            return "[REDACTED]"
        return value

    if isinstance(value, dict):
        # Each nested entry is judged by its own key
        return {k: _redact_value(v, k) for k, v in value.items()}

    if isinstance(value, list):
        # List items have no key of their own, so they inherit the key of the list;
        # otherwise strings in a list under a sensitive key would never be redacted.
        return [_redact_value(item, key) for item in value]

    return value


def redact_sensitive_info(properties: dict[str, Any]) -> dict[str, Any]:
    """
    Redact properties to remove any paths, API keys, etc., so as to avoid collecting
    private or personally identifiable information (PII).

    A string value is replaced with "[REDACTED]" when its key is one of
    `FIELDS_TO_REDACT`, contains one of `PREFIXES_TO_REDACT`, or (case-insensitively)
    contains one of `PATH_INDICATORS`, unless the value is whitelisted according to
    `is_whitelisted`. Nested dictionaries and lists are processed recursively; strings
    inside a list are judged by the key the list is stored under. Values of any other
    type are passed through unchanged. The input is not modified.

    This logic is borrowed from the telemetry logic. It can be unified with it once
    merged.

    Args:
        properties: Dictionary of properties to redact.

    Returns:
        New properties dictionary with redaction applied; an empty dictionary if
        `properties` is empty or `None`.
    """
    if not properties:
        return {}

    # Create new dict with redacted values
    return {k: _redact_value(v, k) for k, v in properties.items()}
|
||||||
17
src/axolotl/cli/redaction_whitelist.yaml
Normal file
17
src/axolotl/cli/redaction_whitelist.yaml
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
organizations:
|
||||||
|
- "axolotl-ai-co"
|
||||||
|
- "meta-llama"
|
||||||
|
- "huggingface"
|
||||||
|
- "nvidia"
|
||||||
|
- "facebook"
|
||||||
|
- "google"
|
||||||
|
- "microsoft"
|
||||||
|
- "deepseek-ai"
|
||||||
|
- "HuggingFaceTB"
|
||||||
|
- "mistralai"
|
||||||
|
- "Qwen"
|
||||||
|
- "unsloth"
|
||||||
|
- "NousResearch"
|
||||||
|
- "allenai"
|
||||||
|
- "amd"
|
||||||
|
- "tiiuae"
|
||||||
@@ -11,7 +11,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
@@ -35,7 +34,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
|||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
patch_optimized_env()
|
patch_optimized_env()
|
||||||
|
|
||||||
print_axolotl_text_art()
|
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
|||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.common.datasets import TrainDatasetMeta
|
from axolotl.common.datasets import TrainDatasetMeta
|
||||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
@@ -545,8 +544,6 @@ def train(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (model, tokenizer) after training
|
Tuple of (model, tokenizer) after training
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
|
|
||||||
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
||||||
(
|
(
|
||||||
trainer,
|
trainer,
|
||||||
|
|||||||
Reference in New Issue
Block a user