Compare commits
6 Commits
grouped_lr
...
debug-hf-h
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
59047ee6c4 | ||
|
|
c1b920f291 | ||
|
|
3915abee4c | ||
|
|
7a38dbe674 | ||
|
|
e0a2eb2ebd | ||
|
|
d852d7af7a |
@@ -23,7 +23,7 @@ repos:
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/PyCQA/pylint
|
||||
rev: v2.17.4
|
||||
rev: v3.3.0
|
||||
hooks:
|
||||
- id: pylint
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[MASTER]
|
||||
init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"
|
||||
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
|
||||
disable=missing-function-docstring, line-too-long, import-error,
|
||||
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
||||
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
||||
too-many-positional-arguments, possibly-used-before-assignment
|
||||
|
||||
@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
|
||||
torchao==0.7.0
|
||||
schedulefree==1.3.0
|
||||
|
||||
axolotl-contribs-lgpl==0.0.1b2
|
||||
axolotl-contribs-lgpl==0.0.2
|
||||
|
||||
23
setup.py
23
setup.py
@@ -1,4 +1,5 @@
|
||||
"""setup.py for axolotl"""
|
||||
|
||||
import ast
|
||||
import os
|
||||
import platform
|
||||
@@ -29,15 +30,29 @@ def parse_requirements():
|
||||
elif not is_extras and line and line[0] != "#":
|
||||
# Handle standard packages
|
||||
_install_requires.append(line)
|
||||
|
||||
try:
|
||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||
torchao_version = [req for req in _install_requires if "torchao" in req][0]
|
||||
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
|
||||
|
||||
if "Darwin" in platform.system():
|
||||
# don't install xformers on MacOS
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
# skip packages not compatible with OSX
|
||||
skip_packages = [
|
||||
"bitsandbytes",
|
||||
"triton",
|
||||
"mamba-ssm",
|
||||
"flash-attn",
|
||||
"xformers",
|
||||
"autoawq",
|
||||
"liger-kernel",
|
||||
]
|
||||
_install_requires = [
|
||||
req
|
||||
for req in _install_requires
|
||||
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
|
||||
]
|
||||
print(
|
||||
_install_requires, [req in skip_packages for req in _install_requires]
|
||||
)
|
||||
else:
|
||||
# detect the version of torch already installed
|
||||
# and set it so dependencies don't clobber the torch version
|
||||
|
||||
@@ -93,7 +93,7 @@ def evaluate(config: str, accelerate: bool, **kwargs):
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
default=False,
|
||||
help="Use accelerate launch for multi-GPU inference",
|
||||
)
|
||||
@click.option(
|
||||
@@ -124,7 +124,7 @@ def inference(
|
||||
if lora_model_dir:
|
||||
kwargs["lora_model_dir"] = lora_model_dir
|
||||
if base_model:
|
||||
kwargs["output_dir"] = base_model
|
||||
kwargs["base_model"] = base_model
|
||||
|
||||
if accelerate:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
|
||||
|
||||
@@ -68,7 +68,7 @@ from axolotl.utils.callbacks import (
|
||||
)
|
||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
@@ -1834,8 +1834,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||
if self.cfg.chat_template:
|
||||
training_arguments_kwargs["chat_template"] = get_chat_template(
|
||||
self.cfg.chat_template,
|
||||
training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
|
||||
cfg=self.cfg,
|
||||
tokenizer=self.tokenizer,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
@@ -126,7 +127,20 @@ def train(
|
||||
)
|
||||
|
||||
if cfg.fix_untrained_tokens:
|
||||
fix_untrained_tokens(model, tokenizer, train_dataset)
|
||||
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
||||
sig = inspect.signature(fix_untrained_tokens)
|
||||
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
||||
if "token_ids_to_fix" in sig.parameters and isinstance(
|
||||
cfg.fix_untrained_tokens, list
|
||||
):
|
||||
fix_untrained_tokens(
|
||||
model,
|
||||
tokenizer,
|
||||
train_dataset,
|
||||
token_ids_to_fix=cfg.fix_untrained_tokens,
|
||||
)
|
||||
else:
|
||||
fix_untrained_tokens(model, tokenizer, train_dataset)
|
||||
if cfg.local_rank == 0:
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||
|
||||
@@ -43,7 +43,7 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
||||
getattr, self.layers_attribute.split("."), self.trainer.model
|
||||
)
|
||||
LOG.info(
|
||||
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
|
||||
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers * 100 / len(layers)}%) every {self.step_interval} steps"
|
||||
)
|
||||
|
||||
def freeze_all_layers(self):
|
||||
|
||||
@@ -794,7 +794,7 @@ class AxolotlInputConfig(
|
||||
chat_template_jinja: Optional[str] = None
|
||||
default_system_message: Optional[str] = None
|
||||
|
||||
fix_untrained_tokens: Optional[bool] = None
|
||||
fix_untrained_tokens: Optional[Union[int, List[int]]] = None
|
||||
|
||||
# INTERNALS - document for now, generally not set externally
|
||||
is_preprocess: Optional[bool] = None
|
||||
|
||||
@@ -28,8 +28,10 @@ def encode_pretraining(
|
||||
)
|
||||
# Convert to PyTorch tensors
|
||||
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
|
||||
targets = [torch.tensor(seq) for seq in res["input_ids"]]
|
||||
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
|
||||
new_input_ids = []
|
||||
new_labels = []
|
||||
new_attention_mask = []
|
||||
# Append EOS and PAD tokens to input_ids, and correct attention_mask
|
||||
for i, _ in enumerate(input_ids):
|
||||
@@ -40,22 +42,34 @@ def encode_pretraining(
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
targets[i] = torch.cat(
|
||||
(
|
||||
targets[i],
|
||||
torch.tensor([tokenizer.eos_token_id, -100]),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
|
||||
|
||||
# Concatenate tokens so that their lengths are less than max_tokens
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_labels = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
|
||||
for ids, mask in zip(input_ids, attention_mask):
|
||||
for ids, labels, mask in zip(input_ids, targets, attention_mask):
|
||||
if buffer_input_ids.numel() == max_tokens:
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_labels.append(buffer_labels)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_labels = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
else:
|
||||
buffer_input_ids = torch.cat(
|
||||
@@ -69,6 +83,17 @@ def encode_pretraining(
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
buffer_labels = torch.cat(
|
||||
(
|
||||
buffer_labels,
|
||||
torch.full(
|
||||
(max_tokens - buffer_labels.numel(),),
|
||||
-100,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
buffer_attention_mask = torch.cat(
|
||||
(
|
||||
buffer_attention_mask,
|
||||
@@ -81,11 +106,14 @@ def encode_pretraining(
|
||||
dim=0,
|
||||
)
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_labels.append(buffer_labels)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_labels = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
|
||||
if buffer_input_ids.numel() > 0: # for any leftover tokens
|
||||
@@ -101,6 +129,17 @@ def encode_pretraining(
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
buffer_labels = torch.cat(
|
||||
(
|
||||
buffer_labels,
|
||||
torch.full(
|
||||
(max_tokens - buffer_labels.numel(),),
|
||||
-100,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
buffer_attention_mask = torch.cat(
|
||||
(
|
||||
buffer_attention_mask,
|
||||
@@ -113,11 +152,12 @@ def encode_pretraining(
|
||||
dim=0,
|
||||
)
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_labels.append(buffer_labels)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
|
||||
ret = {
|
||||
"input_ids": [seq.tolist() for seq in new_input_ids],
|
||||
"labels": [seq.tolist() for seq in new_input_ids],
|
||||
"labels": [seq.tolist() for seq in new_labels],
|
||||
"attention_mask": [seq.tolist() for seq in new_attention_mask],
|
||||
}
|
||||
|
||||
|
||||
@@ -270,7 +270,7 @@ def load_sharded_model_quant(
|
||||
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
|
||||
|
||||
if cfg.local_rank == 0 and verbose:
|
||||
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
|
||||
print(f"Loaded model weights in {time.time() - start:.3f} seconds")
|
||||
# cleanup any extra memory usage from parallel loading
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@@ -37,7 +37,8 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||
|
||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||
def snapshot_download_w_retry(*args, **kwargs):
|
||||
return snapshot_download(*args, **kwargs)
|
||||
url = snapshot_download(*args, **kwargs)
|
||||
raise f"{args[0]}: {url}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
|
||||
Reference in New Issue
Block a user