Compare commits

...

6 Commits

Author SHA1 Message Date
Wing Lian
1f5c0d3613 fix graph break for compile 2025-05-23 11:50:37 -04:00
Wing Lian
3ae0f7c08e make sure torch_compile is enabled with SAC 2025-05-23 11:15:44 -04:00
Wing Lian
5930c91a12 add support for SAC 2025-05-23 10:33:02 -04:00
Wing Lian
a27b909c5c GRPO fixes (peft) (#2676)
* don't set peft_config on grpo to prevent double peft wrap

* remove overrides needed to support bug

* fix grpo tests

* require more CPU for multigpu to help with torch compile for vllm
2025-05-16 15:47:03 -04:00
xzuyn
6cb07b9d12 Fix for setting adam_beta3 and adam_epsilon2 for CAME Optimizer (#2654) [skip ci]
* make setting `adam_beta3` and `adam_epsilon2` work correctly

* update config docs so users know args are specific to CAME optim

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-16 15:46:50 -04:00
C080
288653adb6 Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifacts setting (#2675) [skip ci]
* Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifacts setting

* cleanup and lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-16 15:46:31 -04:00
10 changed files with 78 additions and 78 deletions

View File

@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
         image=cicd_image,
         gpu=GPU_CONFIG,
         timeout=90 * 60,
-        cpu=8.0,
+        cpu=16.0,
         memory=131072 * N_GPUS,
         volumes=VOLUME_CONFIG,
     )

View File

@@ -633,7 +633,9 @@ weight_decay:
 # adamw hyperparams
 adam_beta1:
 adam_beta2:
+adam_beta3: # only used for CAME Optimizer
 adam_epsilon:
+adam_epsilon2: # only used for CAME Optimizer

 # Gradient clipping max norm
 max_grad_norm:

View File

@@ -387,8 +387,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
         if self.cfg.adam_beta2:
             training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
+        if self.cfg.adam_beta3:
+            training_arguments_kwargs["adam_beta3"] = self.cfg.adam_beta3
         if self.cfg.adam_epsilon:
             training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
+        if self.cfg.adam_epsilon2:
+            training_arguments_kwargs["adam_epsilon2"] = self.cfg.adam_epsilon2
         if self.cfg.max_grad_norm:
             training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
@@ -713,7 +717,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
             beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
-            beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
+            beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999)
             eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
             eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
             adam_kwargs["betas"] = (beta1, beta2, beta3)
@@ -1170,7 +1174,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.eval_dataset:
             trainer_kwargs["eval_dataset"] = self.eval_dataset
         if self.cfg.adapter and self.peft_config:
-            trainer_kwargs["peft_config"] = self.peft_config
+            if self.cfg.rl is not RLType.GRPO:
+                trainer_kwargs["peft_config"] = self.peft_config
         if self.cfg.precompute_ref_log_probs is not None:
             trainer_kwargs["precompute_ref_log_probs"] = (
                 self.cfg.precompute_ref_log_probs
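
The one-character fix in the second hunk (`adam_beta2` → `adam_beta3` inside the `.get()` lookup) is easy to miss: before it, a user-supplied beta3 was silently ignored and CAME always fell back to adam_beta2's value. A minimal sketch of the difference (the numbers are illustrative, not axolotl defaults):

    training_arguments_kwargs = {"adam_beta2": 0.999, "adam_beta3": 0.98}

    # old lookup: wrong key, so the user's adam_beta3 never reaches the optimizer
    beta3_old = training_arguments_kwargs.get("adam_beta2", 0.9999)  # 0.999
    # fixed lookup: the user's value wins, falling back to 0.9999 only if unset
    beta3_new = training_arguments_kwargs.get("adam_beta3", 0.9999)  # 0.98

    assert (beta3_old, beta3_new) == (0.999, 0.98)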

View File

@@ -3,7 +3,6 @@
 # pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
 import warnings

-from contextlib import nullcontext
 from typing import Any

 import datasets
@@ -14,7 +13,7 @@ from accelerate.utils import (
     broadcast_object_list,
     gather,
     gather_object,
-    is_peft_model,
+    is_peft_available,
 )
 from datasets import Dataset, IterableDataset
 from torch import nn
@@ -30,15 +29,13 @@ from transformers import (
     TrainerCallback,
 )
 from transformers.trainer_utils import seed_worker
-from transformers.utils import is_peft_available
 from trl import GRPOTrainer
 from trl.data_utils import (
     apply_chat_template,
     is_conversational,
     maybe_apply_chat_template,
 )
-from trl.extras.profiling import profiling_context, profiling_decorator
-from trl.import_utils import is_deepspeed_available
+from trl.extras.profiling import profiling_context
 from trl.models import unwrap_model_for_generation
 from trl.trainer.grpo_config import GRPOConfig
 from trl.trainer.grpo_trainer import RewardFunc, nanstd
@@ -52,62 +49,12 @@ if is_peft_available():
     # pylint: disable=unused-import
     from peft import PeftConfig

-if is_deepspeed_available():
-    import deepspeed
-

 class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
     """Extend the base GRPOTrainer for axolotl helpers"""

     _tag_names = ["trl", "grpo", "axolotl"]

-    @profiling_decorator
-    def _move_model_to_vllm(self):
-        # For DeepSpeed ZeRO-3, we need to gather all parameters before operations
-        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
-        zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
-        gather_if_zero3 = (
-            deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
-        )
-
-        if is_peft_model(self.model):
-            # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
-            # adapters in a sharded manner is not supported.
-            with gather_if_zero3(list(self.model.parameters())):
-                self.model.merge_adapter()
-
-                # Update vLLM weights while parameters are gathered
-                for name, param in self.model.named_parameters():
-                    # When using PEFT, we need to recover the original parameter name and discard some parameters
-                    name = (
-                        name.removeprefix("base_model.model.")
-                        .removeprefix("base_model.model.")
-                        .replace(".base_layer", "")
-                    )
-                    if self.model.prefix in name:
-                        continue
-                    # When module to save, remove its prefix and discard the original module
-                    if "original_module" in name:
-                        continue
-                    name = name.replace("modules_to_save.default.", "")
-
-                    if self.accelerator.is_main_process:
-                        self.vllm_client.update_named_param(name, param.data)
-
-                # Unmerge adapters while parameters are still gathered
-                self.model.unmerge_adapter()
-                # Parameters will automatically be repartitioned when exiting the context
-        else:
-            # For non-PEFT models, simply gather and update each parameter individually.
-            for name, param in self.model.named_parameters():
-                with gather_if_zero3([param]):
-                    if self.accelerator.is_main_process:
-                        self.vllm_client.update_named_param(name, param.data)
-
-        # Reset cache on main process
-        if self.accelerator.is_main_process:
-            self.vllm_client.reset_prefix_cache()
-

 class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
     """Extend the base GRPOTrainer for sequence parallelism handling"""

View File

@@ -227,6 +227,19 @@ class AxolotlTrainingMixins:
         },
     )
+    adam_beta3: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "The beta3 hyperparameter used in some optimizers such as CAME"
+        },
+    )
+    adam_epsilon2: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
+        },
+    )
     # multi-modal section
     image_size: int | tuple[int, int] | None = field(

View File

@@ -16,15 +16,24 @@ from transformers.utils import is_torch_bf16_gpu_available
 @torch.jit.script
 def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
-    max_num = int(torch.max(attention_mask).item())
-    batch_size, _ = attention_mask.shape
-    counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
-
-    for i in range(1, max_num + 1):
-        mask = attention_mask == i
-        counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
-
+    # Keep max_num as a tensor instead of extracting to Python int
+    max_num = torch.max(attention_mask)
+
+    # Create a range tensor for comparison
+    range_tensor = torch.arange(
+        1, max_num + 1, device=attention_mask.device, dtype=attention_mask.dtype
+    )
+
+    # Vectorized approach - compare attention_mask with each value in range
+    mask = attention_mask.unsqueeze(-1) == range_tensor.unsqueeze(0).unsqueeze(0)
+
+    # Sum along sequence dimension to get counts
+    counts = mask.sum(dim=1).to(dtype=torch.int32)
+
+    # Flatten and filter non-zero values
     result = counts.flatten()
-    nonzero_indices = torch.nonzero(result).squeeze(-1)
-    return result[nonzero_indices]
+    nonzero_mask = result != 0
+    return result[nonzero_mask]


 @torch.jit.script
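
The loop version pulled `max_num` out with `.item()`, which forces a graph break under `torch.compile` (the "fix graph break for compile" commit above targets exactly this); the rewrite stays in tensor ops end to end. A quick equivalence check between the two counting strategies (the sample-packed mask below is illustrative):

    import torch

    # Each row labels tokens with the 1-based index of the packed sequence
    # they belong to; 0 marks padding.
    attention_mask = torch.tensor(
        [[1, 1, 2, 2, 2, 0],
         [1, 1, 1, 1, 2, 2]]
    )

    # Old approach: Python int + loop (the .item() call breaks the graph)
    max_num = int(torch.max(attention_mask).item())
    counts_loop = torch.stack(
        [(attention_mask == i).sum(dim=-1) for i in range(1, max_num + 1)], dim=1
    ).to(torch.int32)

    # New approach: broadcast comparison, everything stays a tensor
    range_tensor = torch.arange(1, attention_mask.max() + 1, dtype=attention_mask.dtype)
    counts_vec = (attention_mask.unsqueeze(-1) == range_tensor).sum(dim=1).to(torch.int32)

    assert torch.equal(counts_loop, counts_vec)  # [[2, 3], [4, 2]]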

View File

@@ -521,6 +521,11 @@ def train(
     """
     print_axolotl_text_art()

+    if cfg.activation_memory_budget is not None:
+        torch._functorch.config.activation_memory_budget = (  # pylint: disable=protected-access
+            cfg.activation_memory_budget
+        )
+
     # Setup model, tokenizer, (causal or RLHF) trainer, etc.
     (
         trainer,
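
This hook is what the SAC (selective activation checkpointing) commits build on: `activation_memory_budget` feeds PyTorch's memory-budget partitioner for compiled graphs, where 1.0 saves activations as usual, 0.0 recomputes as aggressively as possible, and values in between trade memory for recompute. A minimal standalone sketch of the same knob, assuming a recent PyTorch (the toy model is illustrative):

    import torch

    # The same setting train() copies from cfg.activation_memory_budget
    torch._functorch.config.activation_memory_budget = 0.5  # pylint: disable=protected-access

    model = torch.nn.Sequential(
        torch.nn.Linear(128, 128), torch.nn.ReLU(), torch.nn.Linear(128, 128)
    )
    # The budget only affects compiled regions, hence the torch_compile
    # requirement enforced by the config validator further below.
    compiled = torch.compile(model)

    out = compiled(torch.randn(4, 128))
    out.sum().backward()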

View File

@@ -1,6 +1,7 @@
 """MLFlow module for trainer callbacks"""

 import logging
+import os
 from shutil import copyfile
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING
@@ -16,6 +17,11 @@ if TYPE_CHECKING:
 LOG = logging.getLogger("axolotl.callbacks")

+def should_log_artifacts() -> bool:
+    truths = ["TRUE", "1", "YES"]
+    return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths
+

 class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
     # pylint: disable=duplicate-code
     """Callback to save axolotl config to mlflow"""
@@ -32,13 +38,18 @@ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
     ):
         if is_main_process():
             try:
-                with NamedTemporaryFile(
-                    mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
-                ) as temp_file:
-                    copyfile(self.axolotl_config_path, temp_file.name)
-                    mlflow.log_artifact(temp_file.name, artifact_path="")
-                LOG.info(
-                    "The Axolotl config has been saved to the MLflow artifacts."
-                )
+                if should_log_artifacts():
+                    with NamedTemporaryFile(
+                        mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
+                    ) as temp_file:
+                        copyfile(self.axolotl_config_path, temp_file.name)
+                        mlflow.log_artifact(temp_file.name, artifact_path="")
+                    LOG.info(
+                        "The Axolotl config has been saved to the MLflow artifacts."
+                    )
+                else:
+                    LOG.info(
+                        "Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)"
+                    )
             except (FileNotFoundError, ConnectionError) as err:
                 LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")

View File

@@ -182,6 +182,7 @@ class AxolotlInputConfig(
         default=False
     )
     gradient_checkpointing_kwargs: dict[str, Any] | None = None
+    activation_memory_budget: float | None = None
     unfrozen_parameters: list[str] | None = None
@@ -1079,6 +1080,19 @@ class AxolotlInputConfig(
         )
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_activation_memory_budget_w_compile(cls, data):
+        if data.get("activation_memory_budget") is not None and not data.get(
+            "torch_compile"
+        ):
+            LOG.warning(
+                "activation_memory_budget is enabled, but torch_compile is not set. "
+                "Automatically setting torch_compile to true."
+            )
+            data["torch_compile"] = True
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_npu_config(cls, data):
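
The validator runs on the raw config mapping before pydantic parses it, so its effect reduces to a simple dict rewrite (illustrative sketch):

    # Budget set, torch_compile omitted by the user
    data = {"activation_memory_budget": 0.5}

    if data.get("activation_memory_budget") is not None and not data.get("torch_compile"):
        data["torch_compile"] = True  # a warning is logged, then the flag is forced on

    assert data == {"activation_memory_budget": 0.5, "torch_compile": True}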

View File

@@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
     """
     )

-@pytest.mark.skip(reason="flaky test")
 @pytest.mark.parametrize(
     "num_gpus",
     [1, 2],
@@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
         "NCCL_P2P_LEVEL": "LOC",
         **current_env,
         "CUDA_VISIBLE_DEVICES": "1",
-        "VLLM_DISABLE_COMPILE_CACHE": "1",
-        # "VLLM_USE_V1": "0",
     }
     vllm_process = start_vllm(
         cfg.base_model,
@@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
     finally:
         recursive_kill(vllm_process)

-@pytest.mark.skip(reason="flaky test")
 @pytest.mark.parametrize(
     "num_gpus",
     [1, 2],
@@ -325,8 +321,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
         "NCCL_P2P_LEVEL": "LOC",  # nccl can be brittle, assume P2P isn't reliable
         **current_env,
         "CUDA_VISIBLE_DEVICES": "1",
-        "VLLM_DISABLE_COMPILE_CACHE": "1",
-        # "VLLM_USE_V1": "0",
     }
     vllm_process = start_vllm(
         cfg.base_model,