Compare commits

..

6 Commits

Author SHA1 Message Date
Wing Lian
59047ee6c4 dump snapshot location for caching 2025-01-09 11:26:33 -05:00
salman
c1b920f291 Fixing OSX installation (#2231)
* bumping version, removing non-osx compatible deps

* updating pylintrc

* fixing linters

* reverting changes
2025-01-07 13:42:01 +00:00
Wing Lian
3915abee4c make sure padding is labeled as -100 for pretraining (#2227) 2024-12-31 15:22:18 -05:00
NJordan72
7a38dbe674 fix: allow trainer builder to use custom jinja chat template (#2219)
* fix: allow trainer builder to use custom jinja chat template

* chore: use get_chat_template_from_config

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>

* fix: swap imports

---------

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>
2024-12-24 16:18:50 -05:00
Wing Lian
e0a2eb2ebd fix untrained tokens if specified explicitly from a list (#2210) 2024-12-23 09:08:28 -05:00
Wing Lian
d852d7af7a inference - don't default w accelerate, fix base model (#2216) [skip ci] 2024-12-23 07:48:41 -05:00
13 changed files with 139 additions and 151 deletions

View File

@@ -23,7 +23,7 @@ repos:
hooks: hooks:
- id: flake8 - id: flake8
- repo: https://github.com/PyCQA/pylint - repo: https://github.com/PyCQA/pylint
rev: v2.17.4 rev: v3.3.0
hooks: hooks:
- id: pylint - id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy

View File

@@ -1,5 +1,5 @@
[MASTER] [MASTER]
init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))" init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
[TYPECHECK] [TYPECHECK]
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error, disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods, too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation, too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-positional-arguments, possibly-used-before-assignment

View File

@@ -1,29 +0,0 @@
---
title: Learning Rate Groups
description: "Setting different learning rates by module name"
---
## Background
Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or groups of
modules in a model.
## Example
```yaml
lr_groups:
- name: o_proj
modules:
- self_attn.o_proj.weight
lr: 1e-6
- name: q_proj
modules:
- model.layers.2.self_attn.q_proj.weight
lr: 1e-5
learning_rate: 2e-5
```
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's
self attention `q_proj` module.

View File

@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0 torchao==0.7.0
schedulefree==1.3.0 schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.1b2 axolotl-contribs-lgpl==0.0.2

View File

@@ -1,4 +1,5 @@
"""setup.py for axolotl""" """setup.py for axolotl"""
import ast import ast
import os import os
import platform import platform
@@ -29,15 +30,29 @@ def parse_requirements():
elif not is_extras and line and line[0] != "#": elif not is_extras and line and line[0] != "#":
# Handle standard packages # Handle standard packages
_install_requires.append(line) _install_requires.append(line)
try: try:
xformers_version = [req for req in _install_requires if "xformers" in req][0] xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0] torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0] autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system(): if "Darwin" in platform.system():
# don't install xformers on MacOS # skip packages not compatible with OSX
_install_requires.pop(_install_requires.index(xformers_version)) skip_packages = [
"bitsandbytes",
"triton",
"mamba-ssm",
"flash-attn",
"xformers",
"autoawq",
"liger-kernel",
]
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
print(
_install_requires, [req in skip_packages for req in _install_requires]
)
else: else:
# detect the version of torch already installed # detect the version of torch already installed
# and set it so dependencies don't clobber the torch version # and set it so dependencies don't clobber the torch version

View File

@@ -93,7 +93,7 @@ def evaluate(config: str, accelerate: bool, **kwargs):
@click.argument("config", type=click.Path(exists=True, path_type=str)) @click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option( @click.option(
"--accelerate/--no-accelerate", "--accelerate/--no-accelerate",
default=True, default=False,
help="Use accelerate launch for multi-GPU inference", help="Use accelerate launch for multi-GPU inference",
) )
@click.option( @click.option(
@@ -124,7 +124,7 @@ def inference(
if lora_model_dir: if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir kwargs["lora_model_dir"] = lora_model_dir
if base_model: if base_model:
kwargs["output_dir"] = base_model kwargs["base_model"] = base_model
if accelerate: if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"] base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]

View File

@@ -68,7 +68,7 @@ from axolotl.utils.callbacks import (
) )
from axolotl.utils.callbacks.lisa import lisa_callback_factory from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import ( from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq, DataCollatorForSeq2Seq,
@@ -244,10 +244,6 @@ class AxolotlTrainingMixins:
default=None, default=None,
metadata={"help": "Scale the learning rate for the embedding layers."}, metadata={"help": "Scale the learning rate for the embedding layers."},
) )
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field( embedding_lr: Optional[float] = field(
default=None, default=None,
metadata={"help": "absolute learning rate for the embedding layers."}, metadata={"help": "absolute learning rate for the embedding layers."},
@@ -466,96 +462,11 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
) )
return super()._wrap_model(model, training=training, dataloader=dataloader) return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
if lr_groups_lookup and any(
group_modules in name for group_modules in lr_groups_lookup
):
lr_group_module = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
][0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
return optimizer_grouped_parameters
def create_optimizer(self): def create_optimizer(self):
if ( if (
self.args.loraplus_lr_ratio is None self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.args.alternate_optimizer and self.args.alternate_optimizer
not in [ not in [
"optimi_adamw", "optimi_adamw",
@@ -569,13 +480,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args, self.args,
opt_model, opt_model,
) )
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs for name, param in opt_model.named_parameters():
) if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
if self.args.loraplus_lr_ratio is not None: if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -592,7 +549,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
elif ( elif (
self.args.embedding_lr_scale is not None self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None or self.args.embedding_lr is not None
or self.args.lr_groups is not None
): ):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
@@ -1808,7 +1764,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.loraplus_lr_embedding ] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]: if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine" training_arguments_kwargs["lr_scheduler_type"] = "cosine"
@@ -1879,8 +1834,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template: if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template( training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
self.cfg.chat_template, cfg=self.cfg,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
) )

View File

@@ -1,5 +1,6 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import inspect
import os import os
import signal import signal
import sys import sys
@@ -126,7 +127,20 @@ def train(
) )
if cfg.fix_untrained_tokens: if cfg.fix_untrained_tokens:
fix_untrained_tokens(model, tokenizer, train_dataset) # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0: if cfg.local_rank == 0:
model.save_pretrained( model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization str(Path(cfg.output_dir)), safe_serialization=safe_serialization

View File

@@ -43,7 +43,7 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
getattr, self.layers_attribute.split("."), self.trainer.model getattr, self.layers_attribute.split("."), self.trainer.model
) )
LOG.info( LOG.info(
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps" f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers * 100 / len(layers)}%) every {self.step_interval} steps"
) )
def freeze_all_layers(self): def freeze_all_layers(self):

View File

@@ -145,14 +145,6 @@ class UserDefinedPrompterType(BaseModel):
field: Optional[str] = None field: Optional[str] = None
class LrGroup(BaseModel):
"""Custom learning rate group configuration"""
name: str
modules: List[str]
lr: float
class SFTDataset(BaseModel): class SFTDataset(BaseModel):
"""SFT configuration subset""" """SFT configuration subset"""
@@ -474,7 +466,6 @@ class HyperparametersConfig(BaseModel):
cosine_min_lr_ratio: Optional[float] = None cosine_min_lr_ratio: Optional[float] = None
cosine_constant_lr_ratio: Optional[float] = None cosine_constant_lr_ratio: Optional[float] = None
lr_div_factor: Optional[float] = None lr_div_factor: Optional[float] = None
lr_groups: Optional[List[LrGroup]] = None
adam_epsilon: Optional[float] = None adam_epsilon: Optional[float] = None
adam_beta1: Optional[float] = None adam_beta1: Optional[float] = None
@@ -803,7 +794,7 @@ class AxolotlInputConfig(
chat_template_jinja: Optional[str] = None chat_template_jinja: Optional[str] = None
default_system_message: Optional[str] = None default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None fix_untrained_tokens: Optional[Union[int, List[int]]] = None
# INTERNALS - document for now, generally not set externally # INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None is_preprocess: Optional[bool] = None

View File

@@ -28,8 +28,10 @@ def encode_pretraining(
) )
# Convert to PyTorch tensors # Convert to PyTorch tensors
input_ids = [torch.tensor(seq) for seq in res["input_ids"]] input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
targets = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]] attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
new_input_ids = [] new_input_ids = []
new_labels = []
new_attention_mask = [] new_attention_mask = []
# Append EOS and PAD tokens to input_ids, and correct attention_mask # Append EOS and PAD tokens to input_ids, and correct attention_mask
for i, _ in enumerate(input_ids): for i, _ in enumerate(input_ids):
@@ -40,22 +42,34 @@ def encode_pretraining(
), ),
dim=0, dim=0,
) )
targets[i] = torch.cat(
(
targets[i],
torch.tensor([tokenizer.eos_token_id, -100]),
),
dim=0,
)
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0) attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
# Concatenate tokens so that their lengths are less than max_tokens # Concatenate tokens so that their lengths are less than max_tokens
buffer_input_ids = torch.tensor([], dtype=torch.long) buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long) buffer_attention_mask = torch.tensor([], dtype=torch.long)
for ids, mask in zip(input_ids, attention_mask): for ids, labels, mask in zip(input_ids, targets, attention_mask):
if buffer_input_ids.numel() == max_tokens: if buffer_input_ids.numel() == max_tokens:
new_input_ids.append(buffer_input_ids) new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask) new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long) buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long) buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
elif buffer_input_ids.numel() + ids.numel() <= max_tokens: elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
else: else:
buffer_input_ids = torch.cat( buffer_input_ids = torch.cat(
@@ -69,6 +83,17 @@ def encode_pretraining(
), ),
dim=0, dim=0,
) )
buffer_labels = torch.cat(
(
buffer_labels,
torch.full(
(max_tokens - buffer_labels.numel(),),
-100,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat( buffer_attention_mask = torch.cat(
( (
buffer_attention_mask, buffer_attention_mask,
@@ -81,11 +106,14 @@ def encode_pretraining(
dim=0, dim=0,
) )
new_input_ids.append(buffer_input_ids) new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask) new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long) buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long) buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
if buffer_input_ids.numel() > 0: # for any leftover tokens if buffer_input_ids.numel() > 0: # for any leftover tokens
@@ -101,6 +129,17 @@ def encode_pretraining(
), ),
dim=0, dim=0,
) )
buffer_labels = torch.cat(
(
buffer_labels,
torch.full(
(max_tokens - buffer_labels.numel(),),
-100,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat( buffer_attention_mask = torch.cat(
( (
buffer_attention_mask, buffer_attention_mask,
@@ -113,11 +152,12 @@ def encode_pretraining(
dim=0, dim=0,
) )
new_input_ids.append(buffer_input_ids) new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask) new_attention_mask.append(buffer_attention_mask)
ret = { ret = {
"input_ids": [seq.tolist() for seq in new_input_ids], "input_ids": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_input_ids], "labels": [seq.tolist() for seq in new_labels],
"attention_mask": [seq.tolist() for seq in new_attention_mask], "attention_mask": [seq.tolist() for seq in new_attention_mask],
} }

View File

@@ -270,7 +270,7 @@ def load_sharded_model_quant(
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config) model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
if cfg.local_rank == 0 and verbose: if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds") print(f"Loaded model weights in {time.time() - start:.3f} seconds")
# cleanup any extra memory usage from parallel loading # cleanup any extra memory usage from parallel loading
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@@ -37,7 +37,8 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
@retry_on_request_exceptions(max_retries=3, delay=5) @retry_on_request_exceptions(max_retries=3, delay=5)
def snapshot_download_w_retry(*args, **kwargs): def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs) url = snapshot_download(*args, **kwargs)
raise f"{args[0]}: {url}"
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)