Compare commits

...

10 Commits

Author SHA1 Message Date
Wing Lian
2491303c46 improve handling of train len 2025-06-06 22:07:29 -07:00
Wing Lian
2c66483a47 default to dropping last batch in multipack batch sampler 2025-06-05 16:00:24 -07:00
Wing Lian
01382b9a79 fix rebase issues 2025-06-05 15:31:28 -07:00
Wing Lian
cfcd69df0d rename vars for consistency 2025-06-05 15:29:21 -07:00
Wing Lian
2302b14a84 fix to remove attention_mask 2025-06-05 15:29:20 -07:00
Wing Lian
a8e2bddd19 increase hyperparams_count for gradients for added normalize_topk 2025-06-05 15:29:20 -07:00
Wing Lian
d55a51623f more KD updates 2025-06-05 15:29:20 -07:00
Wing Lian
73a84ad0dd post-rebase lint 2025-06-05 15:29:20 -07:00
Wing Lian
3cffe881bb accept compressed responses for smaller wire payload 2025-06-05 15:29:20 -07:00
Wing Lian
e77d62933d Fix decay 2025-06-05 15:29:19 -07:00
14 changed files with 226 additions and 102 deletions

View File

@@ -21,11 +21,6 @@ from axolotl.core.trainers import (
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.core.training_args import (
AxolotlPRMConfig,
AxolotlRewardConfig,
AxolotlTrainingArguments,
)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback

View File

@@ -33,6 +33,7 @@ from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
from axolotl.utils import get_not_null
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -101,7 +102,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
)
batch_max_len = train_batch_size * self.args.max_seq_length
return MultipackBatchSampler(
sampler = MultipackBatchSampler(
base_sampler,
lengths=get_dataset_lengths(dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
@@ -113,6 +114,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
drop_last=True,
)
len(sampler)
return sampler
def _get_train_sampler(
self, train_dataset: Optional[Dataset] = None
) -> Optional[Sampler]:
@@ -220,7 +224,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
}
if not isinstance(dataset, torch.utils.data.IterableDataset):
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["drop_last"] = get_not_null(
self.args.dataloader_drop_last, True
)
if sampler_fn is not None:
sampler = sampler_fn(dataset)
if isinstance(sampler, BatchSampler):

View File

@@ -7,8 +7,6 @@ from typing import Optional
from PIL.Image import Resampling
from axolotl.utils.schemas.enums import RingAttnFunc
@dataclass
class AxolotlTrainingMixins:

View File

@@ -22,7 +22,6 @@ from __future__ import annotations
import collections
import importlib
import logging
import traceback
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
@@ -165,7 +164,6 @@ class BasePlugin:
trainer: The trainer object for training.
"""
def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument):
"""
Returns custom training arguments to set on TrainingArgs.
@@ -178,7 +176,7 @@ class BasePlugin:
"""
def get_collator_cls_and_kwargs(
self, cfg: DictDefault, is_eval: bool=False
self, cfg: DictDefault, is_eval: bool = False
): # pylint: disable=unused-argument):
"""
Returns a custom class for the collator.
@@ -399,7 +397,6 @@ class PluginManager: # pylint: disable=too-many-public-methods
training_args = []
for plugin in self.plugins.values():
training_args_from_plugin = plugin.get_training_args_mixin()
print(f"Training args from plugin: {plugin.__class__.__name__}")
if training_args_from_plugin is not None:
training_args.append(training_args_from_plugin)
return training_args

View File

@@ -49,6 +49,7 @@ class KDPlugin(BasePlugin):
"kd_alpha": cfg.kd_alpha,
"kd_temperature": cfg.kd_temperature,
"kd_beta": cfg.kd_beta,
"kd_normalize_topk": cfg.kd_normalize_topk,
}
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
@@ -72,6 +73,7 @@ class KDPlugin(BasePlugin):
"kd_temperature": cfg.kd_temperature,
"kd_online_server": cfg.kd_online_server,
"kd_online_timeout": cfg.kd_online_timeout,
"kd_normalize_topk": cfg.kd_normalize_topk,
}
if use_batch_sampler_collator:

View File

@@ -42,6 +42,9 @@ class KDArgs(BaseModel):
kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
kd_normalize_topk: bool | None = (
None # whether to normalize student logits during KD
)
# TODO online kd
kd_online_server_base_url: str | None = None
@@ -67,3 +70,6 @@ class KDTrainingArgsMixin:
kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
kd_normalize_topk: float | None = (
None # whether to normalize student logits during KD
)

View File

@@ -29,7 +29,7 @@ class KDTemperatureSchedulerCallback(TrainerCallback):
# This factor goes from 1 (at progress=0) to 0 (at progress=1)
decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
self.temperature = self.temperature_start - (
(self.temperature_start - self.temperature_min) * decay_factor
(self.temperature_start - self.temperature_min) * (1.0 - decay_factor)
)
if hasattr(self.trainer.data_collator, "kd_temperature"):

View File

@@ -12,6 +12,7 @@ import torch
from orjson import orjson
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
from axolotl.integrations.kd.utils import normalize_logprobs
from axolotl.utils.data.utils import retry_on_request_exceptions
LOG = logging.getLogger(__name__)
@@ -58,6 +59,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
kd_online_server: Optional[str] = "vllm",
kd_online_timeout: Optional[int] = 120,
kd_cache_dir: Optional[str] = None,
kd_normalize_topk: Optional[bool] = True,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
@@ -78,6 +80,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
self.http_session = requests.Session()
self.kd_online_timeout = kd_online_timeout
self.kd_cache_dir = kd_cache_dir
self.kd_normalize_topk = kd_normalize_topk
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
"""
@@ -88,70 +91,15 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
[-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
)
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
# This should ideally be handled by the caller ensuring correct padding/truncation first
if len(raw_logprobs) != self.kd_online_topk:
# This case should be rare if pre-padding/truncation is done correctly
LOG.warning(
f"Logprobs length mismatch in _normalize_logprobs. "
f"Expected {self.kd_online_topk}, got {len(raw_logprobs)}. Will pad/truncate."
)
padded_logprobs = raw_logprobs[: self.kd_online_topk]
if len(padded_logprobs) < self.kd_online_topk:
padded_logprobs.extend(
[-float("inf")] * (self.kd_online_topk - len(padded_logprobs))
)
raw_logprobs = padded_logprobs
try:
position_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
# Convert logprobs at T_online to probabilities
# use log sum exp trick to avoid underflow
position_logprobs_lse = torch.logsumexp(
position_logprobs_tensor, dim=-1, keepdim=True
)
teacher_probs_t_online = torch.exp(
position_logprobs_tensor - position_logprobs_lse
)
# Normalize probabilities (sum to 1)
# This is important if the top-k from server aren't a full distribution
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=0, keepdim=True)
if teacher_probs_t_online_sum.item() > 1e-9:
teacher_probs_t_online = (
teacher_probs_t_online / teacher_probs_t_online_sum
)
else:
# If sum is zero, create uniform distribution to avoid NaN/Inf later
# This can happen if all raw_logprobs are -inf
if self.kd_online_topk > 0:
teacher_probs_t_online = (
torch.ones_like(teacher_probs_t_online) / self.kd_online_topk
)
# else: leave as is, will result in -inf logprobs
#
# teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum(
# dim=0, keepdim=True
# )
final_logprobs_tensor = torch.log(teacher_probs_t_online)
return final_logprobs_tensor.tolist()
except Exception as e: # pylint: disable=broad-exception-caught
LOG.error(
f"Error during online logprob scaling: {e}. Returning raw logprobs.",
exc_info=True,
)
# Fallback to (padded/truncated) raw logprobs if scaling fails
return raw_logprobs
raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist()
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_sglang(
self, batch_input_ids: List[List[int]], labels: List[List[int]]
):
"""
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
Fetches logprobs from an online teacher served by sglang for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys.
"""
api_endpoint = f"{self.kd_online_server_base_url}/generate"
@@ -267,10 +215,18 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
current_target_token_ids.append(
pos_token_ids[: self.kd_online_topk]
)
scaled_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(scaled_logprobs_for_position)
if self.kd_normalize_topk:
normalized_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(
normalized_logprobs_for_position
)
else:
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
# Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
@@ -336,11 +292,12 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
ret_data_target_mask: List[List[List[int]]] = []
try:
headers = {"Accept-Encoding": "deflate, gzip, br, zstd"}
response = self.http_session.post(
api_endpoint,
json=payload,
headers=headers,
timeout=self.kd_online_timeout,
# json_decoder=orjson.loads,
)
response.raise_for_status()
api_data: dict = orjson.loads(response.content)
@@ -441,12 +398,17 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
pos_token_ids[: self.kd_online_topk]
)
# normalized_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk])
# current_target_logprobs.append(normalized_logprobs_for_position)
# don't normalize for now as the probs seem to sum to 1.0 already
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
if self.kd_normalize_topk:
normalized_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(
normalized_logprobs_for_position
)
else:
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
# Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:

View File

@@ -8,6 +8,8 @@ from liger_kernel.chunked_loss.fused_linear_distillation import (
LigerFusedLinearDistillationBase,
)
from axolotl.integrations.kd.utils import normalize_logprobs
class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
"""
@@ -21,6 +23,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
beta: float = 0.0,
normalize_topk: bool = True,
) -> torch.Tensor:
"""
Compute Top-K KL divergence loss for a chunk.
@@ -33,9 +36,11 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
0.0 for Forward KL (P_teacher || P_student).
1.0 for Reverse KL (P_student || P_teacher).
0.5 for Symmetric KL (average of Forward and Reverse).
normalize_topk: Whether to normalize the log probabilities
Returns:
Sum of KL divergence losses for the chunk.
"""
topk = target_token_ids_chunk.shape[-1]
student_logits_temp_scaled = ( # [chunk_size, vocab_size]
student_logits_temp_scaled.float()
)
@@ -56,18 +61,24 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
student_logits_topk_temp_scaled - student_lse
)
# we have the top-k student logprobs, normalize them
if normalize_topk:
student_logprobs_topk_temp_scaled = normalize_logprobs(
student_logprobs_topk_temp_scaled, topk
)
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
target_logprobs_valid = target_logprobs_chunk[valid_mask]
teacher_logprobs_valid = target_logprobs_chunk[valid_mask]
# Teacher probabilities P(y|x_teacher) from logprobs
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
teacher_probs_valid = target_logprobs_valid.exp()
teacher_probs_valid = teacher_logprobs_valid.exp()
# Student probabilities P_student from log P_student
student_probs_topk_valid = student_logprobs_topk_valid.exp()
kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
# kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
@@ -75,18 +86,33 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
# Here, target_logprobs_valid are log_softmax_teacher.
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
if beta < 1.0: # Contribution from Forward KL
if beta == 0.0: # Contribution from Forward KL
fwd_kl_per_token = teacher_probs_valid * (
target_logprobs_valid - student_logprobs_topk_valid
teacher_logprobs_valid - student_logprobs_topk_valid
)
kd_loss_per_token += (1.0 - beta) * fwd_kl_per_token
if beta > 0.0: # Contribution from Reverse KL
kd_loss = fwd_kl_per_token.sum()
elif beta == 1.0: # Contribution from Reverse KL
rev_kl_per_token = student_probs_topk_valid * (
student_logprobs_topk_valid - target_logprobs_valid
student_logprobs_topk_valid - teacher_logprobs_valid
)
kd_loss_per_token += beta * rev_kl_per_token
kd_loss = kd_loss_per_token.sum()
kd_loss = rev_kl_per_token.sum()
else:
# JSD - Jensen-Shannon Divergence / Symmetric
mean_probs = (
1 - beta
) * student_probs_topk_valid + beta * teacher_probs_valid
log_mean_probs = mean_probs.log()
student_kl = F.kl_div(
log_mean_probs,
student_logprobs_topk_valid,
reduction="sum",
log_target=True,
)
teacher_kl = F.kl_div(
log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True
)
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
kd_loss = jsd_loss
return kd_loss
@@ -109,6 +135,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
compute_ce_loss: bool = True,
temperature: float = 1.0,
beta: float = 0.0,
normalize_topk: bool = True,
):
# Compute student logits for the chunk from hidden states and LM head
# student_input_chunk: [chunk_size, hidden_dim]
@@ -144,6 +171,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
target_logprobs_chunk,
target_mask_chunk,
beta=beta,
normalize_topk=normalize_topk,
)
return soft_loss, ce_loss
@@ -167,6 +195,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
compiled: bool = False,
chunk_size: int = 1024,
compute_ce_loss: bool = True,
normalize_topk: bool = True,
):
CHUNK_SIZE = chunk_size # pylint: disable=invalid-name
grad_weight_acc = torch.zeros_like(student_lm_head_weight)
@@ -211,6 +240,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
compute_ce_loss=compute_ce_loss,
temperature=temperature,
beta=beta,
normalize_topk=normalize_topk,
)
def accumulate_chunk_grads(
@@ -307,11 +337,11 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
ctx.hyperparams_count = 8 # Corresponds to number of hyperparams after main tensors in fwd signature
ctx.hyperparams_count = 9 # Corresponds to number of hyperparams after main tensors in fwd signature
ctx.bias_was_none = student_lm_head_bias is None
ctx.orig_dims = (B, N, D, K)
# since this is packed, there is simply a single batch, so batchmean reduciton of kl-div is simply the accumulated sum
# since this is packed, there is simply a single batch, so batchmean reduction of kl-div is simply the accumulated sum
# we still need to scale the kd_loss by the temp^2
kd_loss_acc = kd_loss_acc * (temperature**2)
final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
@@ -397,6 +427,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
compiled: bool = True,
chunk_size: int = 1024,
compute_ce_loss: bool = True,
normalize_topk: bool = True,
):
super().__init__()
if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
@@ -412,6 +443,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
self.compiled = compiled
self.chunk_size = chunk_size
self.compute_ce_loss = compute_ce_loss
self.normalize_topk = normalize_topk
if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
print(
@@ -449,4 +481,5 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
self.compiled,
self.chunk_size,
self.compute_ce_loss,
self.normalize_topk,
)

View File

@@ -35,6 +35,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
self.args.kd_temperature,
self.args.kd_beta,
compute_ce_loss=bool(self.args.kd_ce_alpha),
normalize_topk=self.args.kd_normalize_topk,
)
def _set_signature_columns_if_needed(self):
@@ -62,6 +63,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
Subclass and override for custom behavior.
"""
if (
self.args.sample_packing
and hasattr(inputs, "attention_mask")
and hasattr(inputs, "position_ids")
):
del inputs["attention_mask"]
if self.model_accepts_loss_kwargs:
loss_kwargs = {}

View File

@@ -0,0 +1,100 @@
"""Helper KD utils"""
import math
from typing import List, Union
import numpy as np
import torch
from torch import FloatTensor, Tensor
def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
"""
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
"""
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
# This should ideally be handled by the caller ensuring correct padding/truncation first
if logprobs.shape[-1] != topk:
# pad last dimension of logprobs to match topk length with -inf
padding_len = topk - logprobs.shape[-1]
padding_tensor = torch.full(
(
*logprobs.shape[:-1],
padding_len,
), # Takes all dimensions of logprobs except the last, then appends padding_needed
float("-inf"),
dtype=logprobs.dtype,
device=logprobs.device,
)
logprobs = torch.cat((logprobs, padding_tensor), dim=-1)
# Convert logprobs at T_online to probabilities
# use log sum exp trick to avoid underflow
position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True)
teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse)
# Normalize probabilities (sum to 1)
# This is important if the top-k from server aren't a full distribution
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True)
teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum
final_logprobs_tensor = torch.log(teacher_probs_t_online)
return final_logprobs_tensor
def strided_chunk_views(
tensor: Union[np.ndarray, torch.Tensor],
chunks: int,
dim: int = 0,
stride: int = 1,
chunk_size: int | None = None,
) -> List[Union[np.ndarray, torch.Tensor]]:
"""
Split a tensor into chunks along a dimension with striding, prioritizing views over copies.
Args:
tensor: Input tensor (numpy array or torch tensor)
chunks: Number of chunks to create
dim: Dimension along which to chunk (default: 0)
stride: Stride between chunk starting positions (default: 1)
chunk_size: Size of each chunk. If None, calculated automatically (default: None)
Returns:
List of tensor chunks (views when possible, copies when necessary)
"""
# Get the size of the specified dimension
dim_size = tensor.shape[dim]
# Calculate chunk size if not provided
if chunk_size is None:
chunk_size = (dim_size + chunks - 1) // chunks # Ceiling division
chunks_list = []
for i in range(chunks):
start_idx = i * stride
end_idx = min(start_idx + chunk_size, dim_size)
# Break if we've gone beyond the tensor
if start_idx >= dim_size:
break
# Create slice objects for all dimensions
slices = [slice(None)] * tensor.ndim
slices[dim] = slice(start_idx, end_idx)
chunk = tensor[tuple(slices)]
chunks_list.append(chunk)
return chunks_list
def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1):
dim_size = input_tensor.shape[dim]
stride = math.ceil(dim_size / chunks)
return strided_chunk_views(
input_tensor, chunks, dim, stride=stride, chunk_size=stride + overlap
)

View File

@@ -52,3 +52,10 @@ def patch_optimized_env():
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
set_pytorch_cuda_alloc_conf()
def get_not_null(value, default=None):
"""
return the value if it's not None, otherwise return the default value
"""
return value if value is not None else default

View File

@@ -258,7 +258,7 @@ class MultipackBatchSampler(BatchSampler):
batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = False, # Whether to drop final batches (might be incomplete)
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
num_count_samples: int = 16, # Number of times to estimate batch count
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
@@ -443,10 +443,18 @@ class MultipackBatchSampler(BatchSampler):
if self._len_across_ranks is None:
# Sample multiple times to get stable estimate
len_batches = min( # pylint: disable=consider-using-generator
[len(self._batches) for _ in range(self.num_count_samples)]
)
_sampled_lens = []
for _ in range(self.num_count_samples):
self._batches = None # Reset cached batches
_sampled_lens.append(len(self.generate_batches(set_stats=False)))
len_batches = min(_sampled_lens)
# Gather minimum across all ranks
self._len_across_ranks = self.gather_len_batches(len_batches)
if self._len_across_ranks is None:
self._len_across_ranks = self.gather_len_batches(len_batches)
else:
self._len_across_ranks = min(
self._len_across_ranks, self.gather_len_batches(len_batches)
)
return self._len_across_ranks

View File

@@ -481,6 +481,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
)
)
if cfg.dataloader_drop_last:
# drop the last batch for each epoch
total_num_steps -= int(math.ceil(cfg.num_epochs))
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -628,7 +631,7 @@ def setup_trainer(
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
if (
cfg.torch_compile