bartch upgrade dependencies (#3299)
* upgrade dependencies * don't use reset sessions * downgrade transformers, upgrade other deps * upgrade bnb to 0.49.0 * restore s3 cache * explicit use local files w hub * decompress and strip top level dir * use 2 levels for strip components * try to preserve permissions for symlinks * use updated tar * fix #3293 for distributed * downgrade bnb * fast fail after 4 * fix total tokens device * patch accelerate CP/SP (#3309) --------- Co-authored-by: salman <salman.mohammadi@outlook.com>
This commit is contained in:
@@ -356,6 +356,7 @@ class AxolotlTrainer(
|
||||
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
||||
trainable_tokens = (inputs[inputs_key] != -100).sum()
|
||||
total_tokens = inputs[inputs_key].numel()
|
||||
total_tokens = torch.tensor(total_tokens, device=inputs[inputs_key].device)
|
||||
|
||||
if is_distributed():
|
||||
torch.distributed.all_reduce(
|
||||
@@ -375,9 +376,7 @@ class AxolotlTrainer(
|
||||
self.state.tokens["trainable"] = (
|
||||
self.state.tokens["trainable"] + trainable_tokens.detach().cpu()
|
||||
)
|
||||
self.state.tokens["total"] = (
|
||||
self.state.tokens["total"] + torch.as_tensor(total_tokens).cpu()
|
||||
)
|
||||
self.state.tokens["total"] = self.state.tokens["total"] + total_tokens.cpu()
|
||||
# Store per-step trainable tokens for throughput calculation
|
||||
self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()
|
||||
|
||||
|
||||
@@ -75,3 +75,33 @@ def patch_parallelism_config():
|
||||
|
||||
ParallelismConfig._validate_accelerator = _validate_accelerator
|
||||
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
|
||||
|
||||
|
||||
def patch_prepare_cp():
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
|
||||
def patched_prepare_cp(self, *args):
|
||||
if self.parallelism_config.cp_backend == "deepspeed":
|
||||
return args
|
||||
|
||||
from accelerate.big_modeling import _attach_context_parallel_hooks
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||
|
||||
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
|
||||
set_rotate_method(cp_comm_strategy)
|
||||
|
||||
self._cp_context = functools.partial(
|
||||
context_parallel, mesh=self.torch_device_mesh["cp"]
|
||||
)
|
||||
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.nn.Module):
|
||||
_attach_context_parallel_hooks(arg)
|
||||
|
||||
return args
|
||||
|
||||
Accelerator._prepare_cp = patched_prepare_cp
|
||||
|
||||
@@ -645,6 +645,9 @@ def setup_parallelism_envs(cfg):
|
||||
set_accelerate_parallelism_config = True
|
||||
os.environ["PARALLELISM_CONFIG_CP_SIZE"] = str(cfg.context_parallel_size)
|
||||
os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
|
||||
from axolotl.monkeypatch.accelerate.parallelism_config import patch_prepare_cp
|
||||
|
||||
patch_prepare_cp()
|
||||
if set_accelerate_parallelism_config:
|
||||
os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user