Compare commits

..

3 Commits

Author SHA1 Message Date
Wing Lian
1f5c0d3613 fix graph break for compile 2025-05-23 11:50:37 -04:00
Wing Lian
3ae0f7c08e make sure torch_compile is enabled with SAC 2025-05-23 11:15:44 -04:00
Wing Lian
5930c91a12 add support for SAC 2025-05-23 10:33:02 -04:00
4 changed files with 64 additions and 35 deletions

View File

@@ -16,15 +16,24 @@ from transformers.utils import is_torch_bf16_gpu_available
@torch.jit.script @torch.jit.script
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor: def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
max_num = int(torch.max(attention_mask).item()) # Keep max_num as a tensor instead of extracting to Python int
batch_size, _ = attention_mask.shape max_num = torch.max(attention_mask)
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
for i in range(1, max_num + 1): # Create a range tensor for comparison
mask = attention_mask == i range_tensor = torch.arange(
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32) 1, max_num + 1, device=attention_mask.device, dtype=attention_mask.dtype
)
# Vectorized approach - compare attention_mask with each value in range
mask = attention_mask.unsqueeze(-1) == range_tensor.unsqueeze(0).unsqueeze(0)
# Sum along sequence dimension to get counts
counts = mask.sum(dim=1).to(dtype=torch.int32)
# Flatten and filter non-zero values
result = counts.flatten() result = counts.flatten()
nonzero_indices = torch.nonzero(result).squeeze(-1) nonzero_mask = result != 0
return result[nonzero_indices] return result[nonzero_mask]
@torch.jit.script @torch.jit.script

View File

@@ -521,6 +521,11 @@ def train(
""" """
print_axolotl_text_art() print_axolotl_text_art()
if cfg.activation_memory_budget is not None:
torch._functorch.config.activation_memory_budget = ( # pylint: disable=protected-access
cfg.activation_memory_budget
)
# Setup model, tokenizer, (causal or RLHF) trainer, etc. # Setup model, tokenizer, (causal or RLHF) trainer, etc.
( (
trainer, trainer,

View File

@@ -53,7 +53,7 @@ from axolotl.utils.data.utils import (
retry_on_request_exceptions, retry_on_request_exceptions,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import ( from axolotl.utils.trainer import (
calculate_total_num_steps, calculate_total_num_steps,
process_datasets_for_packing, process_datasets_for_packing,
@@ -66,31 +66,32 @@ LOG = logging.getLogger(__name__)
def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
prompters = [] prompters = []
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
if cfg.test_datasets: with zero_first(is_local_main_process()):
train_dataset, _, prompters = load_prepare_datasets( if cfg.test_datasets:
tokenizer, train_dataset, _, prompters = load_prepare_datasets(
cfg, tokenizer,
DEFAULT_DATASET_PREPARED_PATH, cfg,
split="train", DEFAULT_DATASET_PREPARED_PATH,
processor=processor, split="train",
preprocess_iterable=preprocess_iterable, processor=processor,
) preprocess_iterable=preprocess_iterable,
_, eval_dataset, _ = load_prepare_datasets( )
tokenizer, _, eval_dataset, _ = load_prepare_datasets(
cfg, tokenizer,
DEFAULT_DATASET_PREPARED_PATH, cfg,
split="test", DEFAULT_DATASET_PREPARED_PATH,
processor=processor, split="test",
preprocess_iterable=preprocess_iterable, processor=processor,
) preprocess_iterable=preprocess_iterable,
else: )
train_dataset, eval_dataset, prompters = load_prepare_datasets( else:
tokenizer, train_dataset, eval_dataset, prompters = load_prepare_datasets(
cfg, tokenizer,
DEFAULT_DATASET_PREPARED_PATH, cfg,
processor=processor, DEFAULT_DATASET_PREPARED_PATH,
preprocess_iterable=preprocess_iterable, processor=processor,
) preprocess_iterable=preprocess_iterable,
)
else: else:
# Load streaming dataset if pretraining_dataset is given # Load streaming dataset if pretraining_dataset is given
path = cfg.pretraining_dataset path = cfg.pretraining_dataset
@@ -271,7 +272,7 @@ def load_tokenized_prepared_datasets(
LOG.info("Loading raw datasets...") LOG.info("Loading raw datasets...")
if not cfg.is_preprocess: if not cfg.is_preprocess:
LOG.warning( LOG.warning(
"Processing datasets during training can lead to VRAM instability. Please use `axolotl preprocess` to prepare your dataset." "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset."
) )
if cfg.seed: if cfg.seed:

View File

@@ -182,6 +182,7 @@ class AxolotlInputConfig(
default=False default=False
) )
gradient_checkpointing_kwargs: dict[str, Any] | None = None gradient_checkpointing_kwargs: dict[str, Any] | None = None
activation_memory_budget: float | None = None
unfrozen_parameters: list[str] | None = None unfrozen_parameters: list[str] | None = None
@@ -1079,6 +1080,19 @@ class AxolotlInputConfig(
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_activation_memory_budget_w_compile(cls, data):
if data.get("activation_memory_budget") is not None and not data.get(
"torch_compile"
):
LOG.warning(
"activation_memory_budget is enabled, but torch_compile is not set. "
"Automatically setting torch_compile to true."
)
data["torch_compile"] = True
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_npu_config(cls, data): def check_npu_config(cls, data):