Compare commits
1 Commit
sac...no-zero-ds
| Author | SHA1 | Date | |
|---|---|---|---|
| | 9f1d548534 | | |
@@ -16,24 +16,15 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
|
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
|
||||||
# Keep max_num as a tensor instead of extracting to Python int
|
max_num = int(torch.max(attention_mask).item())
|
||||||
max_num = torch.max(attention_mask)
|
batch_size, _ = attention_mask.shape
|
||||||
|
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
|
||||||
# Create a range tensor for comparison
|
for i in range(1, max_num + 1):
|
||||||
range_tensor = torch.arange(
|
mask = attention_mask == i
|
||||||
1, max_num + 1, device=attention_mask.device, dtype=attention_mask.dtype
|
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
|
||||||
)
|
|
||||||
|
|
||||||
# Vectorized approach - compare attention_mask with each value in range
|
|
||||||
mask = attention_mask.unsqueeze(-1) == range_tensor.unsqueeze(0).unsqueeze(0)
|
|
||||||
|
|
||||||
# Sum along sequence dimension to get counts
|
|
||||||
counts = mask.sum(dim=1).to(dtype=torch.int32)
|
|
||||||
|
|
||||||
# Flatten and filter non-zero values
|
|
||||||
result = counts.flatten()
|
result = counts.flatten()
|
||||||
nonzero_mask = result != 0
|
nonzero_indices = torch.nonzero(result).squeeze(-1)
|
||||||
return result[nonzero_mask]
|
return result[nonzero_indices]
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
|
|||||||
@@ -521,11 +521,6 @@ def train(
|
|||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
|
|
||||||
if cfg.activation_memory_budget is not None:
|
|
||||||
torch._functorch.config.activation_memory_budget = ( # pylint: disable=protected-access
|
|
||||||
cfg.activation_memory_budget
|
|
||||||
)
|
|
||||||
|
|
||||||
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
||||||
(
|
(
|
||||||
trainer,
|
trainer,
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ from axolotl.utils.data.utils import (
|
|||||||
retry_on_request_exceptions,
|
retry_on_request_exceptions,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_local_main_process, zero_first
|
from axolotl.utils.distributed import is_local_main_process
|
||||||
from axolotl.utils.trainer import (
|
from axolotl.utils.trainer import (
|
||||||
calculate_total_num_steps,
|
calculate_total_num_steps,
|
||||||
process_datasets_for_packing,
|
process_datasets_for_packing,
|
||||||
@@ -66,32 +66,31 @@ LOG = logging.getLogger(__name__)
|
|||||||
def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
||||||
prompters = []
|
prompters = []
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
with zero_first(is_local_main_process()):
|
if cfg.test_datasets:
|
||||||
if cfg.test_datasets:
|
train_dataset, _, prompters = load_prepare_datasets(
|
||||||
train_dataset, _, prompters = load_prepare_datasets(
|
tokenizer,
|
||||||
tokenizer,
|
cfg,
|
||||||
cfg,
|
DEFAULT_DATASET_PREPARED_PATH,
|
||||||
DEFAULT_DATASET_PREPARED_PATH,
|
split="train",
|
||||||
split="train",
|
processor=processor,
|
||||||
processor=processor,
|
preprocess_iterable=preprocess_iterable,
|
||||||
preprocess_iterable=preprocess_iterable,
|
)
|
||||||
)
|
_, eval_dataset, _ = load_prepare_datasets(
|
||||||
_, eval_dataset, _ = load_prepare_datasets(
|
tokenizer,
|
||||||
tokenizer,
|
cfg,
|
||||||
cfg,
|
DEFAULT_DATASET_PREPARED_PATH,
|
||||||
DEFAULT_DATASET_PREPARED_PATH,
|
split="test",
|
||||||
split="test",
|
processor=processor,
|
||||||
processor=processor,
|
preprocess_iterable=preprocess_iterable,
|
||||||
preprocess_iterable=preprocess_iterable,
|
)
|
||||||
)
|
else:
|
||||||
else:
|
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
||||||
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
tokenizer,
|
||||||
tokenizer,
|
cfg,
|
||||||
cfg,
|
DEFAULT_DATASET_PREPARED_PATH,
|
||||||
DEFAULT_DATASET_PREPARED_PATH,
|
processor=processor,
|
||||||
processor=processor,
|
preprocess_iterable=preprocess_iterable,
|
||||||
preprocess_iterable=preprocess_iterable,
|
)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Load streaming dataset if pretraining_dataset is given
|
# Load streaming dataset if pretraining_dataset is given
|
||||||
path = cfg.pretraining_dataset
|
path = cfg.pretraining_dataset
|
||||||
@@ -272,7 +271,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
LOG.info("Loading raw datasets...")
|
LOG.info("Loading raw datasets...")
|
||||||
if not cfg.is_preprocess:
|
if not cfg.is_preprocess:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Processing datasets during training can lead to VRAM instability. Please pre-process your dataset."
|
"Processing datasets during training can lead to VRAM instability. Please use `axolotl preprocess` to prepare your dataset."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.seed:
|
if cfg.seed:
|
||||||
|
|||||||
@@ -182,7 +182,6 @@ class AxolotlInputConfig(
|
|||||||
default=False
|
default=False
|
||||||
)
|
)
|
||||||
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
||||||
activation_memory_budget: float | None = None
|
|
||||||
|
|
||||||
unfrozen_parameters: list[str] | None = None
|
unfrozen_parameters: list[str] | None = None
|
||||||
|
|
||||||
@@ -1080,19 +1079,6 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_activation_memory_budget_w_compile(cls, data):
|
|
||||||
if data.get("activation_memory_budget") is not None and not data.get(
|
|
||||||
"torch_compile"
|
|
||||||
):
|
|
||||||
LOG.warning(
|
|
||||||
"activation_memory_budget is enabled, but torch_compile is not set. "
|
|
||||||
"Automatically setting torch_compile to true."
|
|
||||||
)
|
|
||||||
data["torch_compile"] = True
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_npu_config(cls, data):
|
def check_npu_config(cls, data):
|
||||||
|
|||||||
Reference in New Issue
Block a user