Compare commits
1 commit
debug-hf-h...grouped_lr

| Author | SHA1 | Date |
|---|---|---|
| | 1dd7f087b3 | |
@@ -23,7 +23,7 @@ repos:
     hooks:
       - id: flake8
   - repo: https://github.com/PyCQA/pylint
-    rev: v3.3.0
+    rev: v2.17.4
     hooks:
       - id: pylint
   - repo: https://github.com/pre-commit/mirrors-mypy
@@ -1,5 +1,5 @@
 [MASTER]
-init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
+init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"

 [TYPECHECK]

@@ -12,4 +12,3 @@ generated-members=numpy.*, torch.*
 disable=missing-function-docstring, line-too-long, import-error,
         too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
         too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
-        too-many-positional-arguments, possibly-used-before-assignment
docs/lr_groups.qmd (new file, 29 lines)

@@ -0,0 +1,29 @@
+---
+title: Learning Rate Groups
+description: "Setting different learning rates by module name"
+---
+
+## Background
+
+Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or group of
+modules in a model.
+
+## Example
+
+```yaml
+lr_groups:
+  - name: o_proj
+    modules:
+      - self_attn.o_proj.weight
+    lr: 1e-6
+  - name: q_proj
+    modules:
+      - model.layers.2.self_attn.q_proj.weight
+    lr: 1e-5
+
+learning_rate: 2e-5
+```
+
+In this example, we have a default learning rate of 2e-5 across the entire model, a separate learning rate
+of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd layer's
+self attention `q_proj` module.
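The new documentation page describes the feature declaratively; the trainer changes further down implement it by turning each group into an extra optimizer parameter group with its own learning rate. A minimal hedged sketch of that idea in isolation (the toy model and the group contents are illustrative, not Axolotl's actual classes):

```python
# Minimal sketch: map name-substring "lr_groups" onto torch optimizer param groups.
# The group names, module patterns, and toy model below are illustrative only.
import torch
import torch.nn as nn

lr_groups = [
    {"name": "o_proj", "modules": ["self_attn.o_proj.weight"], "lr": 1e-6},
]
default_lr = 2e-5

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))  # stand-in model

grouped = {g["name"]: [] for g in lr_groups}
defaults = []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # assign the parameter to the first group whose module pattern is a substring of its name
    match = next((g for g in lr_groups if any(m in name for m in g["modules"])), None)
    (grouped[match["name"]] if match else defaults).append(param)

param_groups = [{"params": defaults, "lr": default_lr}] + [
    {"params": grouped[g["name"]], "lr": g["lr"]} for g in lr_groups if grouped[g["name"]]
]
optimizer = torch.optim.AdamW(param_groups)
print([(len(pg["params"]), pg["lr"]) for pg in param_groups])
```

Parameters that match no group simply fall into the default group at the global `learning_rate`.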
@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
 torchao==0.7.0
 schedulefree==1.3.0

-axolotl-contribs-lgpl==0.0.2
+axolotl-contribs-lgpl==0.0.1b2
setup.py (23 changed lines)

@@ -1,5 +1,4 @@
 """setup.py for axolotl"""

 import ast
 import os
 import platform

@@ -30,29 +29,15 @@ def parse_requirements():
             elif not is_extras and line and line[0] != "#":
                 # Handle standard packages
                 _install_requires.append(line)

     try:
         xformers_version = [req for req in _install_requires if "xformers" in req][0]
         torchao_version = [req for req in _install_requires if "torchao" in req][0]
         autoawq_version = [req for req in _install_requires if "autoawq" in req][0]

         if "Darwin" in platform.system():
-            # skip packages not compatible with OSX
-            skip_packages = [
-                "bitsandbytes",
-                "triton",
-                "mamba-ssm",
-                "flash-attn",
-                "xformers",
-                "autoawq",
-                "liger-kernel",
-            ]
-            _install_requires = [
-                req
-                for req in _install_requires
-                if re.split(r"[>=<]", req)[0].strip() not in skip_packages
-            ]
-            print(
-                _install_requires, [req in skip_packages for req in _install_requires]
-            )
+            # don't install xformers on MacOS
+            _install_requires.pop(_install_requires.index(xformers_version))
         else:
             # detect the version of torch already installed
             # and set it so dependencies don't clobber the torch version
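On the base branch, `setup.py` filters Mac-incompatible requirements by bare package name instead of popping xformers alone. A small hedged sketch of that filtering pattern on its own (the requirement strings here are invented for illustration, not the real requirements.txt):

```python
# Sketch of the name-based requirement filtering shown above; the requirement
# list is illustrative only.
import platform
import re

_install_requires = ["torch==2.5.1", "xformers==0.0.28", "bitsandbytes>=0.45.0"]
skip_packages = ["bitsandbytes", "triton", "mamba-ssm", "flash-attn", "xformers", "autoawq", "liger-kernel"]

if "Darwin" in platform.system():
    # drop any requirement whose bare name (text before >, =, or <) is in skip_packages
    _install_requires = [
        req
        for req in _install_requires
        if re.split(r"[>=<]", req)[0].strip() not in skip_packages
    ]

print(_install_requires)
```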
@@ -93,7 +93,7 @@ def evaluate(config: str, accelerate: bool, **kwargs):
 @click.argument("config", type=click.Path(exists=True, path_type=str))
 @click.option(
     "--accelerate/--no-accelerate",
-    default=False,
+    default=True,
     help="Use accelerate launch for multi-GPU inference",
 )
 @click.option(

@@ -124,7 +124,7 @@ def inference(
     if lora_model_dir:
         kwargs["lora_model_dir"] = lora_model_dir
     if base_model:
-        kwargs["base_model"] = base_model
+        kwargs["output_dir"] = base_model

     if accelerate:
         base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
@@ -68,7 +68,7 @@ from axolotl.utils.callbacks import (
 )
 from axolotl.utils.callbacks.lisa import lisa_callback_factory
 from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
-from axolotl.utils.chat_templates import get_chat_template_from_config
+from axolotl.utils.chat_templates import get_chat_template
 from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,
@@ -244,6 +244,10 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "Scale the learning rate for the embedding layers."},
     )
+    lr_groups: Optional[list[dict]] = field(
+        default=None,
+        metadata={"help": "Specify learning rate groups for with different LRs."},
+    )
     embedding_lr: Optional[float] = field(
         default=None,
         metadata={"help": "absolute learning rate for the embedding layers."},
@@ -462,35 +466,23 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             )
         return super()._wrap_model(model, training=training, dataloader=dataloader)

-    def create_optimizer(self):
-        if (
-            self.args.loraplus_lr_ratio is None
-            and self.args.embedding_lr_scale is None
-            and self.args.embedding_lr is None
-            and self.args.alternate_optimizer
-            not in [
-                "optimi_adamw",
-                "ao_adamw_8bit",
-                "ao_adamw_4bit",
-                "ao_adamw_fp8",
-                "adopt_adamw",
-            ]
-        ):
-            return super().create_optimizer()
-
-        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+    def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
         decay_parameters = self.get_decay_parameter_names(opt_model)
         params = {
             "to_weight_decay": {},  # LayerNorm and bias
             "embeddings": {},  # lm_head, embed_tokens,
             "no_weight_decay": {},
         }
-        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
-            self.args,
-            opt_model,
-        )
+        lr_groups_lookup = {}
+        lr_groups_learning_rates = {}
+        if self.args.lr_groups:
+            for lr_group in self.args.lr_groups:
+                group_name = lr_group["name"]
+                group_modules = lr_group["modules"]
+                for module in group_modules:
+                    lr_groups_lookup[module] = group_name
+                lr_groups_learning_rates[group_name] = lr_group["lr"]
+                params[f"to_weight_decay_{group_name}"] = {}

         for name, param in opt_model.named_parameters():
             if not param.requires_grad:
@@ -500,6 +492,17 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             ):
                 params["embeddings"][name] = param
             elif name in decay_parameters:
+                if lr_groups_lookup and any(
+                    group_modules in name for group_modules in lr_groups_lookup
+                ):
+                    lr_group_module = [
+                        group_modules
+                        for group_modules in lr_groups_lookup
+                        if group_modules in name
+                    ][0]
+                    group_name = lr_groups_lookup[lr_group_module]
+                    params[f"to_weight_decay_{group_name}"][name] = param
+                else:
                 params["to_weight_decay"][name] = param
             else:
                 params["no_weight_decay"][name] = param
@@ -533,6 +536,46 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
                     "lr": optimizer_kwargs["lr"],
                 }
             )
+        for group_name, group_lr in lr_groups_learning_rates.items():
+            if params[f"to_weight_decay_{group_name}"]:
+                optimizer_grouped_parameters.append(
+                    {
+                        "params": list(
+                            params[f"to_weight_decay_{group_name}"].values()
+                        ),
+                        "weight_decay": self.args.weight_decay,
+                        "lr": group_lr,
+                    }
+                )
+
+        return optimizer_grouped_parameters
+
+    def create_optimizer(self):
+        if (
+            self.args.loraplus_lr_ratio is None
+            and self.args.embedding_lr_scale is None
+            and self.args.embedding_lr is None
+            and self.args.lr_groups is None
+            and self.args.alternate_optimizer
+            not in [
+                "optimi_adamw",
+                "ao_adamw_8bit",
+                "ao_adamw_4bit",
+                "ao_adamw_fp8",
+                "adopt_adamw",
+            ]
+        ):
+            return super().create_optimizer()
+
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
+                self.args,
+                opt_model,
+            )
+            optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
+                opt_model, optimizer_kwargs
+            )
+
             if self.args.loraplus_lr_ratio is not None:
                 loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -549,6 +592,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             elif (
                 self.args.embedding_lr_scale is not None
                 or self.args.embedding_lr is not None
+                or self.args.lr_groups is not None
             ):
                 self.optimizer = (  # pylint: disable=attribute-defined-outside-init
                     optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
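The group assignment above is a plain substring test on parameter names, so a short pattern like `self_attn.o_proj.weight` catches that projection in every layer, while a fully qualified name such as `model.layers.2.self_attn.q_proj.weight` pins exactly one layer. A small hedged sketch of that lookup, with made-up parameter names in the Llama-style naming scheme:

```python
# Sketch of the substring-based lookup used above; the parameter names are
# invented examples, not read from a real model.
from typing import Optional

lr_groups_lookup = {
    "self_attn.o_proj.weight": "o_proj",                  # matches every layer's o_proj
    "model.layers.2.self_attn.q_proj.weight": "q_proj",   # matches only layer 2's q_proj
}

def group_for(param_name: str) -> Optional[str]:
    """Return the first lr group whose module pattern occurs in the parameter name."""
    for pattern, group_name in lr_groups_lookup.items():
        if pattern in param_name:
            return group_name
    return None

print(group_for("model.layers.0.self_attn.o_proj.weight"))  # -> "o_proj"
print(group_for("model.layers.2.self_attn.q_proj.weight"))  # -> "q_proj"
print(group_for("model.layers.3.self_attn.q_proj.weight"))  # -> None (default lr applies)
```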
@@ -1764,6 +1808,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ] = self.cfg.loraplus_lr_embedding
         training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
         training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
+        training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups

         if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
             training_arguments_kwargs["lr_scheduler_type"] = "cosine"

@@ -1834,8 +1879,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["model_type"] = self.cfg.model_config_type
         training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
         if self.cfg.chat_template:
-            training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
-                cfg=self.cfg,
+            training_arguments_kwargs["chat_template"] = get_chat_template(
+                self.cfg.chat_template,
                 tokenizer=self.tokenizer,
             )

@@ -1,6 +1,5 @@
 """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""

-import inspect
 import os
 import signal
 import sys

@@ -127,19 +126,6 @@ def train(
         )

     if cfg.fix_untrained_tokens:
-        # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
-        sig = inspect.signature(fix_untrained_tokens)
-        # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
-        if "token_ids_to_fix" in sig.parameters and isinstance(
-            cfg.fix_untrained_tokens, list
-        ):
-            fix_untrained_tokens(
-                model,
-                tokenizer,
-                train_dataset,
-                token_ids_to_fix=cfg.fix_untrained_tokens,
-            )
-        else:
         fix_untrained_tokens(model, tokenizer, train_dataset)
         if cfg.local_rank == 0:
             model.save_pretrained(
@@ -145,6 +145,14 @@ class UserDefinedPrompterType(BaseModel):
     field: Optional[str] = None


+class LrGroup(BaseModel):
+    """Custom learning rate group configuration"""
+
+    name: str
+    modules: List[str]
+    lr: float
+
+
 class SFTDataset(BaseModel):
     """SFT configuration subset"""


@@ -466,6 +474,7 @@ class HyperparametersConfig(BaseModel):
     cosine_min_lr_ratio: Optional[float] = None
     cosine_constant_lr_ratio: Optional[float] = None
     lr_div_factor: Optional[float] = None
+    lr_groups: Optional[List[LrGroup]] = None

     adam_epsilon: Optional[float] = None
     adam_beta1: Optional[float] = None

@@ -794,7 +803,7 @@ class AxolotlInputConfig(
     chat_template_jinja: Optional[str] = None
     default_system_message: Optional[str] = None

-    fix_untrained_tokens: Optional[Union[int, List[int]]] = None
+    fix_untrained_tokens: Optional[bool] = None

     # INTERNALS - document for now, generally not set externally
     is_preprocess: Optional[bool] = None
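The `LrGroup` model above is what validates each entry under `lr_groups` in the YAML config. A brief hedged sketch of how such a model validates the documented example, using standalone pydantic rather than Axolotl's full config loader:

```python
# Standalone sketch of validating the documented lr_groups example against the
# LrGroup schema shown above; this bypasses Axolotl's config pipeline.
from typing import List

from pydantic import BaseModel


class LrGroup(BaseModel):
    """Custom learning rate group configuration"""

    name: str
    modules: List[str]
    lr: float


raw_groups = [
    {"name": "o_proj", "modules": ["self_attn.o_proj.weight"], "lr": 1e-6},
    {"name": "q_proj", "modules": ["model.layers.2.self_attn.q_proj.weight"], "lr": 1e-5},
]

groups = [LrGroup(**entry) for entry in raw_groups]
print(groups[0].lr)  # 1e-06; a missing "lr" or a non-numeric value raises a ValidationError
```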
@@ -28,10 +28,8 @@ def encode_pretraining(
     )
     # Convert to PyTorch tensors
     input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
-    targets = [torch.tensor(seq) for seq in res["input_ids"]]
     attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
     new_input_ids = []
-    new_labels = []
     new_attention_mask = []
     # Append EOS and PAD tokens to input_ids, and correct attention_mask
     for i, _ in enumerate(input_ids):

@@ -42,34 +40,22 @@ def encode_pretraining(
             ),
             dim=0,
         )
-        targets[i] = torch.cat(
-            (
-                targets[i],
-                torch.tensor([tokenizer.eos_token_id, -100]),
-            ),
-            dim=0,
-        )
         attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)

     # Concatenate tokens so that their lengths are less than max_tokens
     buffer_input_ids = torch.tensor([], dtype=torch.long)
-    buffer_labels = torch.tensor([], dtype=torch.long)
     buffer_attention_mask = torch.tensor([], dtype=torch.long)

-    for ids, labels, mask in zip(input_ids, targets, attention_mask):
+    for ids, mask in zip(input_ids, attention_mask):
         if buffer_input_ids.numel() == max_tokens:
             new_input_ids.append(buffer_input_ids)
-            new_labels.append(buffer_labels)
             new_attention_mask.append(buffer_attention_mask)
             buffer_input_ids = torch.tensor([], dtype=torch.long)
-            buffer_labels = torch.tensor([], dtype=torch.long)
             buffer_attention_mask = torch.tensor([], dtype=torch.long)
             buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
-            buffer_labels = torch.cat((buffer_labels, labels), dim=0)
             buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
         elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
             buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
-            buffer_labels = torch.cat((buffer_labels, labels), dim=0)
             buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
         else:
             buffer_input_ids = torch.cat(

@@ -83,17 +69,6 @@ def encode_pretraining(
                 ),
                 dim=0,
             )
-            buffer_labels = torch.cat(
-                (
-                    buffer_labels,
-                    torch.full(
-                        (max_tokens - buffer_labels.numel(),),
-                        -100,
-                        dtype=torch.long,
-                    ),
-                ),
-                dim=0,
-            )
             buffer_attention_mask = torch.cat(
                 (
                     buffer_attention_mask,

@@ -106,14 +81,11 @@ def encode_pretraining(
                 dim=0,
             )
             new_input_ids.append(buffer_input_ids)
-            new_labels.append(buffer_labels)
             new_attention_mask.append(buffer_attention_mask)
             buffer_input_ids = torch.tensor([], dtype=torch.long)
-            buffer_labels = torch.tensor([], dtype=torch.long)
             buffer_attention_mask = torch.tensor([], dtype=torch.long)

             buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
-            buffer_labels = torch.cat((buffer_labels, labels), dim=0)
             buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)

     if buffer_input_ids.numel() > 0:  # for any leftover tokens

@@ -129,17 +101,6 @@ def encode_pretraining(
             ),
             dim=0,
         )
-        buffer_labels = torch.cat(
-            (
-                buffer_labels,
-                torch.full(
-                    (max_tokens - buffer_labels.numel(),),
-                    -100,
-                    dtype=torch.long,
-                ),
-            ),
-            dim=0,
-        )
         buffer_attention_mask = torch.cat(
             (
                 buffer_attention_mask,

@@ -152,12 +113,11 @@ def encode_pretraining(
             dim=0,
         )
         new_input_ids.append(buffer_input_ids)
-        new_labels.append(buffer_labels)
         new_attention_mask.append(buffer_attention_mask)

     ret = {
         "input_ids": [seq.tolist() for seq in new_input_ids],
-        "labels": [seq.tolist() for seq in new_labels],
+        "labels": [seq.tolist() for seq in new_input_ids],
         "attention_mask": [seq.tolist() for seq in new_attention_mask],
     }

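The pretraining change above drops the separately padded `targets`/`new_labels` buffers and emits `labels` as a plain copy of `input_ids`. For causal-LM pretraining that is sufficient because Hugging Face causal models shift the labels internally when computing the loss. A minimal hedged sketch of that shift (the token ids and logits are stand-ins, not real tokenizer or model output):

```python
# Why labels can simply mirror input_ids for causal-LM pretraining:
# the model shifts them by one position inside its loss computation.
import torch

input_ids = torch.tensor([[101, 7592, 2088, 102]])  # made-up token ids
labels = input_ids.clone()  # same ids; positions to ignore would be set to -100

logits = torch.randn(1, 4, 32000)  # stand-in for model logits over a 32k vocab
shift_logits = logits[:, :-1, :]   # position t predicts token t+1
shift_labels = labels[:, 1:]
loss = torch.nn.functional.cross_entropy(
    shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)
)
print(loss)
```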
@@ -37,8 +37,7 @@ def retry_on_request_exceptions(max_retries=3, delay=1):

 @retry_on_request_exceptions(max_retries=3, delay=5)
 def snapshot_download_w_retry(*args, **kwargs):
-    url = snapshot_download(*args, **kwargs)
-    raise f"{args[0]}: {url}"
+    return snapshot_download(*args, **kwargs)


 @pytest.fixture(scope="session", autouse=True)