update transformers to 4.53.1 (#2844) [skip ci]
* update transformers to 4.53.0
* remove attention_mask from signature columns if using packing
* remove attention_mask column from dataloader
* update signature of flash attn forward for ring attn patch
* fix FSDP
* patch ring-flash-attn with upstream signature fix
* fix patch indentation level
* fix the patch
* add batch flattening smoke test with loss check that works in older transformers
* fix patch
* don't drop attention mask for flex
* more fixes
* patch create_causal_mask for packing w flex
* global torch manual_seed fixture
* tweak loss checks
* fix patch and use single batch for flex
* don't need to reload
* fix causal mask patch
* use transformers patch release
* make sure env var is string
* make sure to drop attention mask for flex w packing for latest transformers patch release
* tweak loss
* guard on signature columns before removing attention mask
* bump loss
* set remove isn't chainable
* skip slow mistral test in 2.5.1
src/axolotl/core/attention/__init__.py (new file, 0 lines)

src/axolotl/core/attention/flex_block_mask.py (new file, 162 lines)
@@ -0,0 +1,162 @@
"""
monkeypatch for flex + packing
"""

import sys
from typing import Callable, Optional, Union

import torch
from torch.nn.attention.flex_attention import BlockMask
from transformers import Cache, PretrainedConfig
from transformers.masking_utils import (
    ALL_MASK_ATTENTION_FUNCTIONS,
    _preprocess_mask_arguments,
    and_masks,
    causal_mask_function,
    or_masks,
)
from transformers.utils import is_torch_greater_or_equal

_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)


def create_causal_mask(
    config: PretrainedConfig,
    input_embeds: torch.Tensor,
    attention_mask: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
    or_mask_function: Optional[Callable] = None,
    and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
    """
    Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
    has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
    to what is needed in the `modeling_xxx.py` files).

    Args:
        config (`PretrainedConfig`):
            The model config.
        input_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and dtype.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        cache_position (`torch.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the causal mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
    """
    # If we have an HybridCache structure, here we want to create the mask for the full layers
    if (
        past_key_values
        and hasattr(past_key_values, "is_sliding")
        and False in past_key_values.is_sliding
    ):
        layer_idx = past_key_values.is_sliding.index(False)
    else:
        layer_idx = 0

    original_attention_mask = (
        None
        if attention_mask is None
        else attention_mask.clone().to(cache_position.device)
    )
    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
    )
    if early_exit:
        return attention_mask

    batch_size, total_seq_len = cache_position.shape
    key_length = total_seq_len
    document_ids = torch.nn.functional.pad(
        original_attention_mask, value=0, pad=(0, key_length)
    )

    batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
    if attention_mask is not None:

        def causal_doc_mask_mod(
            batch_idx, head_idx, q_idx, kv_idx
        ):  # pylint: disable=unused-argument
            """
            Defines the logic of a block causal mask by combining both a standard causal mask
            and a block diagonal document mask.
            See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
            for an illustration.
            """
            causal_mask_ = q_idx >= kv_idx  # not valid when decoding
            document_mask = (
                document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
            )
            final_mask = causal_mask_ & document_mask
            return final_mask

        mask_factory_function = causal_doc_mask_mod
    else:
        mask_factory_function = causal_mask_function
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[
        config._attn_implementation  # pylint: disable=protected-access
    ]

    # Do not allow skip if we are compiling (this is to match BC)
    allow_is_causal_skip = (
        not past_key_values.is_compileable if past_key_values is not None else True
    )

    # Allow slight deviations from causal mask
    if or_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError(
                "Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
            )
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
        allow_is_causal_skip = False
    if and_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError(
                "Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
            )
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)
        allow_is_causal_skip = False

    # We now create the mask
    causal_mask = mask_interface(
        batch_size=batch_size,
        cache_position=cache_position,
        kv_length=kv_length,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
        dtype=dtype,  # Additional kwarg for eager
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
    )
    return causal_mask


def patch_create_causal_mask(model_type):
    import transformers.masking_utils

    transformers.masking_utils.create_causal_mask = create_causal_mask

    if model_type:
        try:
            # Dynamically import the module and attention class
            module_path = f"transformers.models.{model_type}.modeling_{model_type}"
            module = __import__(module_path)
            module.create_causal_mask = create_causal_mask
            del sys.modules[module_path]
        except (ImportError, AttributeError) as e:
            raise ValueError(
                f"Could not import attention class for model_type: {model_type}. "
                f"Error: {str(e)}"
            ) from e
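Note (not part of the diff): a minimal sketch of how the causal + document mask combination above behaves for a packed batch, assuming the packed attention_mask stores a per-sample document id (1, 2, 3, ...) for each token. Names here are illustrative only.

# Illustrative sketch: block-causal masking over packed documents
import torch

attention_mask = torch.tensor([[1, 1, 1, 2, 2, 3]])  # one row, three packed samples
document_ids = attention_mask  # plays the role of `document_ids` above, without padding

def causal_doc_mask(batch_idx, q_idx, kv_idx):
    causal = q_idx >= kv_idx                      # standard causal constraint
    same_doc = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
    return causal & same_doc                      # attend only within the same packed sample

seq_len = attention_mask.shape[1]
q = torch.arange(seq_len).view(-1, 1)
kv = torch.arange(seq_len).view(1, -1)
print(causal_doc_mask(0, q, kv).int())  # block-diagonal lower-triangular pattern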
@@ -245,10 +245,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling

        training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
+        training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool(
+            self.cfg.flash_attention
+            or self.cfg.xformers_attention
+            or self.cfg.flex_attention
+        )
        training_arguments_kwargs["multipack_real_batches"] = (
            self.cfg.multipack_real_batches
            if self.cfg.multipack_real_batches is not None
-            else not self.cfg.flash_attention
+            else not (
+                self.cfg.flash_attention
+                or self.cfg.flex_attention
+                or self.cfg.xformers_attention
+            )
        )
        training_arguments_kwargs["eval_sample_packing"] = bool(
            self.cfg.eval_sample_packing
@@ -27,6 +27,7 @@ from typing_extensions import override
from axolotl.core.trainers.mixins import (
    CheckpointSaveMixin,
    OptimizerMixin,
+    PackingMixin,
    RngLoaderMixin,
    SchedulerMixin,
)
@@ -42,7 +43,12 @@ LOG = get_logger(__name__)

class AxolotlTrainer(
-    SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
+    PackingMixin,
+    SchedulerMixin,
+    OptimizerMixin,
+    RngLoaderMixin,
+    CheckpointSaveMixin,
+    Trainer,
):
    """Extend the base Trainer for axolotl helpers"""
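Note (not part of the diff): PackingMixin is listed ahead of Trainer so that its _set_signature_columns_if_needed override is found first in Python's method resolution order. A minimal sketch of that pattern, with hypothetical class names:

# Illustrative sketch: the leftmost base wins under the MRO
class Base:
    def setup(self):
        return "base"

class Mixin(Base):
    def setup(self):
        return "mixin, then " + super().setup()

class Combined(Mixin, Base):
    pass

print(Combined().setup())                          # -> "mixin, then base"
print([c.__name__ for c in Combined.__mro__])      # Combined, Mixin, Base, object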
@@ -206,6 +212,14 @@ class AxolotlTrainer(
        if dataset.column_names and "length" in dataset.column_names:
            dataset = dataset.remove_columns(["length"])
+        if (
+            dataset.column_names
+            and "position_ids" in dataset.column_names
+            and "attention_mask" in dataset.column_names
+            and self.args.sample_packing
+            and self.args.sample_packing_drop_attention_mask
+        ):
+            dataset = dataset.remove_columns(["attention_mask"])

        if isinstance(dataset, datasets.Dataset):
            if is_training:
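Note (not part of the diff): with sample packing, the per-sample boundaries are recoverable from position_ids, which restart at 0 for each packed sample, so the attention_mask column is redundant for backends that rebuild the mask from position_ids. A toy illustration (token ids are made up):

# Illustrative sketch: boundaries of a packed example from position_ids
packed_batch = {
    "input_ids": [101, 7592, 102, 101, 2088, 999, 102],
    "position_ids": [0, 1, 2, 0, 1, 2, 3],  # two packed samples
}
boundaries = [i for i, p in enumerate(packed_batch["position_ids"]) if p == 0]
print(boundaries)  # [0, 3]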
@@ -5,5 +5,6 @@

from .checkpoints import CheckpointSaveMixin
from .optimizer import OptimizerMixin
+from .packing import PackingMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
src/axolotl/core/trainers/mixins/packing.py (new file, 20 lines)
@@ -0,0 +1,20 @@
"""Trainer mixin to support packing"""

from transformers import Trainer


class PackingMixin(Trainer):
    """
    Trainer mixin to support packing
    """

    def _set_signature_columns_if_needed(self):
        super()._set_signature_columns_if_needed()
        if (
            self._signature_columns
            and self.args.sample_packing
            and self.args.sample_packing_drop_attention_mask
        ):
            set_sig_columns = set(self._signature_columns)
            set_sig_columns.remove("attention_mask")
            self._signature_columns = list(set_sig_columns)
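Note (not part of the diff): the three-step form above ("set remove isn't chainable" in the commit message) is needed because set.remove() mutates in place and returns None:

# Illustrative sketch: set.remove() cannot be chained
cols = {"input_ids", "attention_mask", "labels"}
assert cols.remove("attention_mask") is None  # mutates `cols`, returns None
print(sorted(cols))  # ['input_ids', 'labels']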
@@ -42,6 +42,10 @@ class AxolotlTrainingMixins:
        default=None,
        metadata={"help": "The multiprocessing start method to use."},
    )
+    sample_packing_drop_attention_mask: bool = field(
+        default=False,
+        metadata={"help": "Drop attention mask from inputs when using packing."},
+    )
    multipack_real_batches: bool = field(
        default=False,
        metadata={"help": "Use real batches for efficient training."},
@@ -49,11 +49,11 @@ class PatchManager:

    def apply_pre_model_load_patches(self):
        """Apply pre-model load patches based on config."""
-        # self._apply_flex_attention_patches()
        self._apply_flash_attention_patches()
        self._apply_chunked_cross_entropy_patch()
        self._apply_fsdp_patches()
        self._apply_adapter_patches()
+        self._apply_flex_attention_patches()
        self._apply_model_specific_patches()
        self._apply_fp8_patches()
        self._apply_flash_attention_peft_patches()
@@ -97,6 +97,14 @@ class PatchManager:

        patch_accelerate_fsdp2()

+        # if self.cfg.fsdp_config:
+        #     # see transformers#39152
+        #     from axolotl.monkeypatch.trainer_fsdp_optim import (
+        #         patch_training_loop_for_fsdp,
+        #     )
+        #
+        #     patch_training_loop_for_fsdp()

    def _apply_adapter_patches(self):
        """Apply patches for adapter configurations."""
        if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
@@ -107,14 +115,20 @@ class PatchManager:
    def _apply_flex_attention_patches(self):
        """Apply patches for flexible attention."""
        if self.cfg.flex_attention:
-            from axolotl.monkeypatch.attention.flex_attn import (
-                patch_flex_make_mask,
-                patch_flex_wrapper,
-            )
+            # from axolotl.monkeypatch.attention.flex_attn import (
+            #     patch_flex_make_mask,
+            #     patch_flex_wrapper,
+            # )
+            #
+            # flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
+            # patch_flex_wrapper(**flex_attn_compile_kwargs)
+            # patch_flex_make_mask()
+            if self.cfg.sample_packing:
+                from axolotl.core.attention.flex_block_mask import (
+                    patch_create_causal_mask,
+                )

-            flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
-            patch_flex_wrapper(**flex_attn_compile_kwargs)
-            patch_flex_make_mask()
+                patch_create_causal_mask(self.cfg.model_config_type)

    def _apply_model_specific_patches(self):
        """Apply patches specific to model architectures."""
@@ -33,7 +33,7 @@ RING_ATTN_FUNC_MAPPING = {
}


-def create_flash_attn_forward(
+def create_flash_attn_forward_varlen_llama3(
    process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
) -> Callable:
    """
@@ -71,6 +71,7 @@ def create_flash_attn_forward(
        max_length_q: int | None = None,
        max_length_k: int | None = None,
        target_dtype: torch.dtype | None = None,
+        attn_implementation: str | None = None,
        **kwargs,
    ):
        """
@@ -97,6 +98,7 @@ def create_flash_attn_forward(
            max_length_q: Not used in this implementation.
            max_length_k: Not used in this implementation.
            target_dtype: Not used in this implementation.
+            attn_implementation: Not used in this implementation.
            **kwargs: Additional keyword arguments. Not used in this implementation.

        Returns:
@@ -161,7 +163,7 @@ def substitute_hf_flash_attn(
    old_flash_attention_forward = (
        transformers.modeling_flash_attention_utils._flash_attention_forward
    )
-    new_flash_attention_forward = create_flash_attn_forward(
+    new_flash_attention_forward = create_flash_attn_forward_varlen_llama3(
        process_group=process_group, ring_attn_func=ring_attn_func
    )

@@ -9,10 +9,13 @@ sequence parallelism training.
"""

import inspect
+import os
+from typing import Optional

import accelerate
import torch
import torch.distributed as dist
+from transformers.modeling_flash_attention_utils import _flash_supports_window_size

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
@@ -62,6 +65,96 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
    RING_ATTN_GROUP = ring_attn_group


def create_ring_flash_attention_forward(
    process_group: dist.ProcessGroup, heads_k_stride: int
):
    from ring_flash_attn import llama3_flash_attn_varlen_func
    from ring_flash_attn.adapters.hf_adapter import DATA_PARAMS

    def _flash_attention_forward_v3(
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: torch.Tensor,  # pylint: disable=unused-argument
        query_length: int,
        is_causal: bool,
        dropout: float = 0.0,
        position_ids: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        softmax_scale: Optional[float] = None,
        sliding_window: Optional[int] = None,
        use_top_left_mask: bool = False,
        softcap: Optional[float] = None,
        deterministic: bool = None,
        cu_seq_lens_q: Optional[
            torch.LongTensor
        ] = None,  # pylint: disable=unused-argument
        cu_seq_lens_k: Optional[
            torch.LongTensor
        ] = None,  # pylint: disable=unused-argument
        max_length_q: Optional[int] = None,  # pylint: disable=unused-argument
        max_length_k: Optional[int] = None,  # pylint: disable=unused-argument
        target_dtype: Optional[torch.dtype] = None,  # pylint: disable=unused-argument
        attn_implementation: Optional[str] = None,  # pylint: disable=unused-argument
        **kwargs,  # pylint: disable=unused-argument
    ):
        # pylint: disable=duplicate-code
        if not use_top_left_mask:
            causal = is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
            causal = is_causal and query_length != 1

        # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
        use_sliding_windows = (
            _flash_supports_window_size
            and sliding_window is not None
            and key_states.shape[1] > sliding_window
        )
        flash_kwargs = (
            {"window_size": (sliding_window, sliding_window)}
            if use_sliding_windows
            else {}
        )

        if deterministic is None:
            deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
        flash_kwargs["deterministic"] = deterministic
        assert (
            softcap is None
        ), "llama3_flash_attn_varlen_func does not support softcap yet."
        # flash_kwargs["softcap"] = softcap
        flash_kwargs["group"] = process_group

        # not sure why attention_mask can be not None...
        assert causal, "only causal attention is supported yet."
        batch_size = query_states.size(0)
        assert batch_size == 1, "varlen data should be processed in advance."

        attn_output = llama3_flash_attn_varlen_func(
            query_states.squeeze(dim=0),
            key_states.squeeze(dim=0),
            value_states.squeeze(dim=0),
            cu_seqlens_q=DATA_PARAMS["cu_seqlens_q"],
            cu_seqlens_k=DATA_PARAMS["cu_seqlens_k"],
            max_seqlen_q=DATA_PARAMS["max_seqlen_q"],
            max_seqlen_k=DATA_PARAMS["max_seqlen_k"],
            heads_k_stride=heads_k_stride,
            local_k_slice=DATA_PARAMS["local_k_slice"],
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )

        attn_output = attn_output.unsqueeze(dim=0)

        return attn_output

    return [
        _flash_attention_forward_v3,
    ]


def register_ring_attn(
    sequence_parallel_degree: int,
    heads_k_stride: int | None,
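Note (not part of the diff): the varlen kernel above consumes cumulative sequence lengths (cu_seqlens) rather than a padded mask. For a packed batch these can be derived from position_ids; a simplified sketch of roughly what the imported get_cu_seqlens_from_pos_ids helper provides (the exact helper signature is not shown in this diff):

# Illustrative sketch: cu_seqlens from position_ids resets
import torch

position_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])  # three packed samples
starts = (position_ids == 0).nonzero(as_tuple=True)[0]
cu_seqlens = torch.cat([starts, torch.tensor([position_ids.numel()])]).to(torch.int32)
print(cu_seqlens)  # tensor([0, 3, 5, 9], dtype=torch.int32)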
@@ -118,9 +211,20 @@ def register_ring_attn(
    LOG.info(f"Sequence parallel group assignments: {group_assignments}")

    if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
-        from ring_flash_attn import substitute_hf_flash_attn
+        # fmt: off
+        import ring_flash_attn.adapters.hf_adapter

-        substitute_hf_flash_attn(
+        from ring_flash_attn.adapters.hf_adapter import (  # isort: skip # pylint: disable=unused-import
+            create_ring_flash_attention_forward as create_ring_flash_attention_forward_orig,
+        )
+
+        create_ring_flash_attention_forward_orig = (  # noqa: F811,F841
+            create_ring_flash_attention_forward
+        )
+        ring_flash_attn.adapters.hf_adapter.create_ring_flash_attention_forward = create_ring_flash_attention_forward
+        # fmt: on
+
+        ring_flash_attn.adapters.hf_adapter.substitute_hf_flash_attn(
            process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
        )
    elif ring_attn_func is RingAttnFunc.BATCH_RING:
@@ -12,15 +12,13 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)

ORIGINAL_TRAINER_CODE = """
        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled

        if delay_optimizer_creation:
            self.optimizer = self.accelerator.prepare(self.optimizer)
"""

PATCHED_TRAINER_CODE = """
        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled

        if delay_optimizer_creation:
            model = self.accelerator.prepare(self.model)
"""
@@ -203,7 +203,7 @@ class AxolotlInputConfig(
        },
    )
    dataset_processes: int | None = Field(
-        default=min(32, os.cpu_count()),  # type: ignore[type-var]
+        default=min(int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count()),  # type: ignore[type-var]
        json_schema_extra={
            "description": "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set."
        },
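Note (not part of the diff): the new default lets an environment variable cap the number of dataset preprocessing workers, e.g. by exporting AXOLOTL_DATASET_PROCESSES=8 before preprocessing. The resolution is simply:

# Illustrative sketch: env-var override for dataset_processes
import os
procs = min(int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count())
print(procs)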
@@ -535,6 +535,9 @@ def setup_deepspeed_env(cfg, stage=None):

    os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
    os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
+    os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(
+        cfg.gradient_accumulation_steps
+    )
    if stage:
        os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
        if stage == 3:
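Note (not part of the diff): the str() wrapper ("make sure env var is string" in the commit message) is required because os.environ only accepts strings:

# Illustrative sketch: os.environ rejects non-string values
import os

try:
    os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = 4  # type: ignore[assignment]
except TypeError as err:
    print(err)  # str expected, not int
os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(4)  # works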