Compare commits

..

1 Commits

Author SHA1 Message Date
NanoCode012
348409c2ff fix: num_items_in_batch wrong type in kd trainer loss 2025-05-20 16:56:24 +07:00
4 changed files with 11 additions and 36 deletions

View File

@@ -74,6 +74,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
if num_items_in_batch is None:
num_items_in_batch = -1
if self.args.kd_zscore_base_temp:
loss_kd = topk_kd_loss_with_zscore(
shift_logits,

View File

@@ -16,24 +16,15 @@ from transformers.utils import is_torch_bf16_gpu_available
@torch.jit.script
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
# Keep max_num as a tensor instead of extracting to Python int
max_num = torch.max(attention_mask)
# Create a range tensor for comparison
range_tensor = torch.arange(
1, max_num + 1, device=attention_mask.device, dtype=attention_mask.dtype
)
# Vectorized approach - compare attention_mask with each value in range
mask = attention_mask.unsqueeze(-1) == range_tensor.unsqueeze(0).unsqueeze(0)
# Sum along sequence dimension to get counts
counts = mask.sum(dim=1).to(dtype=torch.int32)
# Flatten and filter non-zero values
max_num = int(torch.max(attention_mask).item())
batch_size, _ = attention_mask.shape
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
for i in range(1, max_num + 1):
mask = attention_mask == i
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
result = counts.flatten()
nonzero_mask = result != 0
return result[nonzero_mask]
nonzero_indices = torch.nonzero(result).squeeze(-1)
return result[nonzero_indices]
@torch.jit.script

View File

@@ -521,11 +521,6 @@ def train(
"""
print_axolotl_text_art()
if cfg.activation_memory_budget is not None:
torch._functorch.config.activation_memory_budget = ( # pylint: disable=protected-access
cfg.activation_memory_budget
)
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
(
trainer,

View File

@@ -182,7 +182,6 @@ class AxolotlInputConfig(
default=False
)
gradient_checkpointing_kwargs: dict[str, Any] | None = None
activation_memory_budget: float | None = None
unfrozen_parameters: list[str] | None = None
@@ -1080,19 +1079,6 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_activation_memory_budget_w_compile(cls, data):
if data.get("activation_memory_budget") is not None and not data.get(
"torch_compile"
):
LOG.warning(
"activation_memory_budget is enabled, but torch_compile is not set. "
"Automatically setting torch_compile to true."
)
data["torch_compile"] = True
return data
@model_validator(mode="before")
@classmethod
def check_npu_config(cls, data):