Compare commits

...

4 Commits

Author SHA1 Message Date
Dan Saunders
b594f18f6e just redact api keys 2025-06-24 14:19:00 -04:00
Dan Saunders
700791deb9 Merge branch 'main' into dump-config 2025-06-23 09:46:08 -04:00
Dan Saunders
d6d2cc673b remove none-valued config before dumping 2025-06-23 13:35:53 +00:00
Dan Saunders
1d8f500709 deepspeed fix (#2820) 2025-06-23 09:07:57 -04:00
5 changed files with 26 additions and 120 deletions

View File

@@ -13,7 +13,6 @@ import torch
import yaml
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.redaction import redact_sensitive_info
from axolotl.integrations.base import PluginManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
@@ -29,6 +28,8 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__)
API_KEY_FIELDS = {"comet_api_key"}
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
"""
@@ -234,10 +235,15 @@ def load_cfg(
setup_comet_env_vars(cfg)
plugin_set_cfg(cfg)
redacted_cfg = redact_sensitive_info(cfg)
cfg_to_log = {
k: "[REDACTED]" if k in API_KEY_FIELDS else v
for k, v in cfg.items()
if v is not None
}
LOG.info(
"config:\n%s",
json.dumps(redacted_cfg, indent=2, default=str, sort_keys=True),
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
)
return cfg

View File

@@ -1,96 +0,0 @@
"""Utils for redaction of sensitive information in config."""
from pathlib import Path
from typing import Any
import yaml
# NOTE: Borrowed from the telemetry logic. Should be unified with it once merged.
WHITELIST_PATH = str(Path(__file__).parent / "redaction_whitelist.yaml")
with open(WHITELIST_PATH, encoding="utf-8") as f:
WHITELIST = yaml.safe_load(f)
# Send org strings to lowercase since model names are case insensitive
WHITELIST["organizations"] = {org.lower() for org in WHITELIST["organizations"]}
# NOTE: Need to keep these up to date with any config schema changes.
FIELDS_TO_REDACT = {
"base_model",
"tokenizer_config",
"base_model_config",
"pretraining_dataset", # NOTE: this field may be a string or a dictionary.
"resume_from_checkpoint",
"hub_model_id",
}
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"}
PATH_INDICATORS = {"path", "dir"}
def is_whitelisted(value: str) -> bool:
"""
Check if model / dataset / etc. org is in whitelist.
This logic is borrowed from the telemetry logic. Should be unified with it once
merged.
Args:
value: Value for one of `FIELDS_WITH_ORGS` ("base_model", etc.).
Returns:
Boolean indicating whitelist membership.
"""
# NOTE: This membership-checking logic can be improved.
# What happens when a local model path matches a whitelisted org?
parts = value.split("/")
if len(parts) < 2:
return False
org = parts[0]
whitelisted = org.lower() in WHITELIST["organizations"]
return whitelisted
def redact_sensitive_info(properties: dict[str, Any]) -> dict[str, Any]:
"""
Redact properties to remove any paths, API keys, etc., so as to avoid collecting
private or personally identifiable information (PII).
This logic is borrowed from the telemetry logic. It can be unified with it once
merged.
Args:
properties: Dictionary of properties to redact.
Returns:
Properties dictionary with redaction applied.
"""
if not properties:
return {}
def redact_value(value: Any, key: str = "") -> Any:
"""Recursively sanitize values, redacting those with path-like keys"""
if isinstance(key, str) and isinstance(value, str):
# Other redaction special cases
if (
key in FIELDS_TO_REDACT
or any(prefix in key for prefix in PREFIXES_TO_REDACT)
or any(indicator in key.lower() for indicator in PATH_INDICATORS)
):
# Fields with whitelisted orgs don't need to be redacted
if not is_whitelisted(value):
return "[REDACTED]"
# Handle nested values
if isinstance(value, dict):
return {k: redact_value(v, k) for k, v in value.items()}
if isinstance(value, list):
return [redact_value(item) for item in value]
return value
# Create new dict with redacted values
redacted = {k: redact_value(v, k) for k, v in properties.items()}
return redacted

View File

@@ -1,17 +0,0 @@
organizations:
- "axolotl-ai-co"
- "meta-llama"
- "huggingface"
- "nvidia"
- "facebook"
- "google"
- "microsoft"
- "deepseek-ai"
- "HuggingFaceTB"
- "mistralai"
- "Qwen"
- "unsloth"
- "NousResearch"
- "allenai"
- "amd"
- "tiiuae"

View File

@@ -46,16 +46,23 @@ def get_current_device() -> int:
return 0
def init_distributed_state():
global distributed_state # pylint: disable=global-statement
if distributed_state is None:
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
def get_distributed_state() -> PartialState | None:
return distributed_state
def is_distributed() -> bool:
"""Check if distributed training is initialized."""
global distributed_state # pylint: disable=global-statement
init_distributed_state()
if distributed_state is None:
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
return False
return distributed_state.use_distributed and distributed_state.initialized

View File

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -537,6 +537,12 @@ def setup_deepspeed_env(cfg, stage=None):
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
# to model load.
init_distributed_state()
# If we don't assign this, it doesn't actually get set in the accelerate weakref
_ = HfTrainerDeepSpeedConfig(cfg.deepspeed)