Compare commits

...

6 Commits

Author SHA1 Message Date
Wing Lian
1f5c0d3613 fix graph break for compile 2025-05-23 11:50:37 -04:00
Wing Lian
3ae0f7c08e make sure torch_compile is enabled with SAC 2025-05-23 11:15:44 -04:00
Wing Lian
5930c91a12 add support for SAC 2025-05-23 10:33:02 -04:00
Wing Lian
a27b909c5c GRPO fixes (peft) (#2676)
* don't set peft_config on grpo to prevent double peft wrap

* remove overrides needed to support bug

* fix grpo tests

* require more CPU for multigpu to help with torch compile for vllm
2025-05-16 15:47:03 -04:00
xzuyn
6cb07b9d12 Fix for setting adam_beta3 and adam_epsilon2 for CAME Optimizer (#2654) [skip ci]
* make setting `adam_beta3` and `adam_epsilon2` work correctly

* update config docs so users know args are specific to CAME optim

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-16 15:46:50 -04:00
C080
288653adb6 Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifa… (#2675) [skip ci]
* Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifacts setting

* cleanup and lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-16 15:46:31 -04:00
10 changed files with 78 additions and 78 deletions

View File

@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
image=cicd_image,
gpu=GPU_CONFIG,
timeout=90 * 60,
cpu=8.0,
cpu=16.0,
memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,
)

View File

@@ -633,7 +633,9 @@ weight_decay:
# adamw hyperparams
adam_beta1:
adam_beta2:
adam_beta3: # only used for CAME Optimizer
adam_epsilon:
adam_epsilon2: # only used for CAME Optimizer
# Gradient clipping max norm
max_grad_norm:

View File

@@ -387,8 +387,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
if self.cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
if self.cfg.adam_beta3:
training_arguments_kwargs["adam_beta3"] = self.cfg.adam_beta3
if self.cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
if self.cfg.adam_epsilon2:
training_arguments_kwargs["adam_epsilon2"] = self.cfg.adam_epsilon2
if self.cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
@@ -713,7 +717,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999)
eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
adam_kwargs["betas"] = (beta1, beta2, beta3)
@@ -1170,7 +1174,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.rl is not RLType.GRPO:
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs

View File

@@ -3,7 +3,6 @@
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings
from contextlib import nullcontext
from typing import Any
import datasets
@@ -14,7 +13,7 @@ from accelerate.utils import (
broadcast_object_list,
gather,
gather_object,
is_peft_model,
is_peft_available,
)
from datasets import Dataset, IterableDataset
from torch import nn
@@ -30,15 +29,13 @@ from transformers import (
TrainerCallback,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_peft_available
from trl import GRPOTrainer
from trl.data_utils import (
apply_chat_template,
is_conversational,
maybe_apply_chat_template,
)
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import is_deepspeed_available
from trl.extras.profiling import profiling_context
from trl.models import unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc, nanstd
@@ -52,62 +49,12 @@ if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig
if is_deepspeed_available():
import deepspeed
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""Extend the base GRPOTrainer for axolotl helpers"""
_tag_names = ["trl", "grpo", "axolotl"]
@profiling_decorator
def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
gather_if_zero3 = (
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
)
if is_peft_model(self.model):
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
# adapters in a sharded manner is not supported.
with gather_if_zero3(list(self.model.parameters())):
self.model.merge_adapter()
# Update vLLM weights while parameters are gathered
for name, param in self.model.named_parameters():
# When using PEFT, we need to recover the original parameter name and discard some parameters
name = (
name.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", "")
)
if self.model.prefix in name:
continue
# When module to save, remove its prefix and discard the original module
if "original_module" in name:
continue
name = name.replace("modules_to_save.default.", "")
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
# Unmerge adapters while parameters are still gathered
self.model.unmerge_adapter()
# Parameters will automatically be repartitioned when exiting the context
else:
# For non-PEFT models, simply gather and update each parameter individually.
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
# Reset cache on main process
if self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""

View File

@@ -227,6 +227,19 @@ class AxolotlTrainingMixins:
},
)
adam_beta3: Optional[float] = field(
default=None,
metadata={
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
},
)
adam_epsilon2: Optional[float] = field(
default=None,
metadata={
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(

View File

@@ -16,15 +16,24 @@ from transformers.utils import is_torch_bf16_gpu_available
@torch.jit.script
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
max_num = int(torch.max(attention_mask).item())
batch_size, _ = attention_mask.shape
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
for i in range(1, max_num + 1):
mask = attention_mask == i
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
# Keep max_num as a tensor instead of extracting to Python int
max_num = torch.max(attention_mask)
# Create a range tensor for comparison
range_tensor = torch.arange(
1, max_num + 1, device=attention_mask.device, dtype=attention_mask.dtype
)
# Vectorized approach - compare attention_mask with each value in range
mask = attention_mask.unsqueeze(-1) == range_tensor.unsqueeze(0).unsqueeze(0)
# Sum along sequence dimension to get counts
counts = mask.sum(dim=1).to(dtype=torch.int32)
# Flatten and filter non-zero values
result = counts.flatten()
nonzero_indices = torch.nonzero(result).squeeze(-1)
return result[nonzero_indices]
nonzero_mask = result != 0
return result[nonzero_mask]
@torch.jit.script

View File

@@ -521,6 +521,11 @@ def train(
"""
print_axolotl_text_art()
if cfg.activation_memory_budget is not None:
torch._functorch.config.activation_memory_budget = ( # pylint: disable=protected-access
cfg.activation_memory_budget
)
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
(
trainer,

View File

@@ -1,6 +1,7 @@
"""MLFlow module for trainer callbacks"""
import logging
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
@@ -16,6 +17,11 @@ if TYPE_CHECKING:
LOG = logging.getLogger("axolotl.callbacks")
def should_log_artifacts() -> bool:
truths = ["TRUE", "1", "YES"]
return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
# pylint: disable=duplicate-code
"""Callback to save axolotl config to mlflow"""
@@ -32,13 +38,18 @@ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
):
if is_main_process():
try:
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
if should_log_artifacts():
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
)
else:
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
"Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)"
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")

View File

@@ -182,6 +182,7 @@ class AxolotlInputConfig(
default=False
)
gradient_checkpointing_kwargs: dict[str, Any] | None = None
activation_memory_budget: float | None = None
unfrozen_parameters: list[str] | None = None
@@ -1079,6 +1080,19 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_activation_memory_budget_w_compile(cls, data):
if data.get("activation_memory_budget") is not None and not data.get(
"torch_compile"
):
LOG.warning(
"activation_memory_budget is enabled, but torch_compile is not set. "
"Automatically setting torch_compile to true."
)
data["torch_compile"] = True
return data
@model_validator(mode="before")
@classmethod
def check_npu_config(cls, data):

View File

@@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"""
)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
}
vllm_process = start_vllm(
cfg.base_model,
@@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
finally:
recursive_kill(vllm_process)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -325,8 +321,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
}
vllm_process = start_vllm(
cfg.base_model,