bartch upgrade dependencies (#3299)

* upgrade dependencies

* don't use reset sessions

* downgrade transformers, upgrade other deps

* upgrade bnb to 0.49.0

* restore s3 cache

* explicit use local files w hub

* decompress and strip top level dir

* use 2 levels for strip components

* try to preserve permissions for symlinks

* use updated tar

* fix #3293 for distributed

* downgrade bnb

* fast fail after 4

* fix total tokens device

* patch accelerate CP/SP (#3309)

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
This commit is contained in:
Wing Lian
2025-12-30 09:02:49 -05:00
committed by GitHub
parent 66a3de3629
commit 11c0b5b256
9 changed files with 66 additions and 26 deletions

View File

@@ -356,6 +356,7 @@ class AxolotlTrainer(
inputs_key = "labels" if "labels" in inputs else "input_ids"
trainable_tokens = (inputs[inputs_key] != -100).sum()
total_tokens = inputs[inputs_key].numel()
total_tokens = torch.tensor(total_tokens, device=inputs[inputs_key].device)
if is_distributed():
torch.distributed.all_reduce(
@@ -375,9 +376,7 @@ class AxolotlTrainer(
self.state.tokens["trainable"] = (
self.state.tokens["trainable"] + trainable_tokens.detach().cpu()
)
self.state.tokens["total"] = (
self.state.tokens["total"] + torch.as_tensor(total_tokens).cpu()
)
self.state.tokens["total"] = self.state.tokens["total"] + total_tokens.cpu()
# Store per-step trainable tokens for throughput calculation
self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()

View File

@@ -75,3 +75,33 @@ def patch_parallelism_config():
ParallelismConfig._validate_accelerator = _validate_accelerator
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
def patch_prepare_cp():
import functools
import torch
from accelerate import Accelerator
def patched_prepare_cp(self, *args):
if self.parallelism_config.cp_backend == "deepspeed":
return args
from accelerate.big_modeling import _attach_context_parallel_hooks
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
set_rotate_method(cp_comm_strategy)
self._cp_context = functools.partial(
context_parallel, mesh=self.torch_device_mesh["cp"]
)
for arg in args:
if isinstance(arg, torch.nn.Module):
_attach_context_parallel_hooks(arg)
return args
Accelerator._prepare_cp = patched_prepare_cp

View File

@@ -645,6 +645,9 @@ def setup_parallelism_envs(cfg):
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_CP_SIZE"] = str(cfg.context_parallel_size)
os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
from axolotl.monkeypatch.accelerate.parallelism_config import patch_prepare_cp
patch_prepare_cp()
if set_accelerate_parallelism_config:
os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"