Compare commits

..

3 Commits

Author SHA1 Message Date
Wing Lian
1f5c0d3613 fix graph break for compile 2025-05-23 11:50:37 -04:00
Wing Lian
3ae0f7c08e make sure torch_compile is enabled with SAC 2025-05-23 11:15:44 -04:00
Wing Lian
5930c91a12 add support for SAC 2025-05-23 10:33:02 -04:00
4 changed files with 64 additions and 35 deletions

View File

@@ -16,15 +16,24 @@ from transformers.utils import is_torch_bf16_gpu_available
@torch.jit.script
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
max_num = int(torch.max(attention_mask).item())
batch_size, _ = attention_mask.shape
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
for i in range(1, max_num + 1):
mask = attention_mask == i
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
# Keep max_num as a tensor instead of extracting to Python int
max_num = torch.max(attention_mask)
# Create a range tensor for comparison
range_tensor = torch.arange(
1, max_num + 1, device=attention_mask.device, dtype=attention_mask.dtype
)
# Vectorized approach - compare attention_mask with each value in range
mask = attention_mask.unsqueeze(-1) == range_tensor.unsqueeze(0).unsqueeze(0)
# Sum along sequence dimension to get counts
counts = mask.sum(dim=1).to(dtype=torch.int32)
# Flatten and filter non-zero values
result = counts.flatten()
nonzero_indices = torch.nonzero(result).squeeze(-1)
return result[nonzero_indices]
nonzero_mask = result != 0
return result[nonzero_mask]
@torch.jit.script

View File

@@ -521,6 +521,11 @@ def train(
"""
print_axolotl_text_art()
if cfg.activation_memory_budget is not None:
torch._functorch.config.activation_memory_budget = ( # pylint: disable=protected-access
cfg.activation_memory_budget
)
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
(
trainer,

View File

@@ -53,7 +53,7 @@ from axolotl.utils.data.utils import (
retry_on_request_exceptions,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
@@ -66,31 +66,32 @@ LOG = logging.getLogger(__name__)
def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
prompters = []
if not cfg.pretraining_dataset:
if cfg.test_datasets:
train_dataset, _, prompters = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
split="train",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
_, eval_dataset, _ = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
else:
train_dataset, eval_dataset, prompters = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
with zero_first(is_local_main_process()):
if cfg.test_datasets:
train_dataset, _, prompters = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
split="train",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
_, eval_dataset, _ = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
else:
train_dataset, eval_dataset, prompters = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
else:
# Load streaming dataset if pretraining_dataset is given
path = cfg.pretraining_dataset
@@ -271,7 +272,7 @@ def load_tokenized_prepared_datasets(
LOG.info("Loading raw datasets...")
if not cfg.is_preprocess:
LOG.warning(
"Processing datasets during training can lead to VRAM instability. Please use `axolotl preprocess` to prepare your dataset."
"Processing datasets during training can lead to VRAM instability. Please pre-process your dataset."
)
if cfg.seed:

View File

@@ -182,6 +182,7 @@ class AxolotlInputConfig(
default=False
)
gradient_checkpointing_kwargs: dict[str, Any] | None = None
activation_memory_budget: float | None = None
unfrozen_parameters: list[str] | None = None
@@ -1079,6 +1080,19 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_activation_memory_budget_w_compile(cls, data):
if data.get("activation_memory_budget") is not None and not data.get(
"torch_compile"
):
LOG.warning(
"activation_memory_budget is enabled, but torch_compile is not set. "
"Automatically setting torch_compile to true."
)
data["torch_compile"] = True
return data
@model_validator(mode="before")
@classmethod
def check_npu_config(cls, data):