Compare commits
1 Commits
cli-cloud-
...
grouped_lr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1dd7f087b3 |
@@ -217,7 +217,7 @@ If you love axolotl, consider sponsoring the project by reaching out directly to
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
|
- [Modal](https://modal.com/) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
29
docs/lr_groups.qmd
Normal file
29
docs/lr_groups.qmd
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
---
|
||||||
|
title: Learning Rate Groups
|
||||||
|
description: "Setting different learning rates by module name"
|
||||||
|
---
|
||||||
|
|
||||||
|
## Background
|
||||||
|
|
||||||
|
Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or groups of
|
||||||
|
modules in a model.
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
lr_groups:
|
||||||
|
- name: o_proj
|
||||||
|
modules:
|
||||||
|
- self_attn.o_proj.weight
|
||||||
|
lr: 1e-6
|
||||||
|
- name: q_proj
|
||||||
|
modules:
|
||||||
|
- model.layers.2.self_attn.q_proj.weight
|
||||||
|
lr: 1e-5
|
||||||
|
|
||||||
|
learning_rate: 2e-5
|
||||||
|
```
|
||||||
|
|
||||||
|
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
|
||||||
|
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd layer's
|
||||||
|
self attention `q_proj` module.
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
volumes:
|
|
||||||
- name: axolotl-data
|
|
||||||
mount: /workspace/data
|
|
||||||
- name: axolotl-artifacts
|
|
||||||
mount: /workspace/artifacts
|
|
||||||
secrets:
|
|
||||||
- HF_TOKEN
|
|
||||||
- WANDB_API_KEY
|
|
||||||
branch: cli-cloud-modal
|
|
||||||
gpu: h100
|
|
||||||
gpu_count: 1
|
|
||||||
memory: 128
|
|
||||||
timeout: 86400
|
|
||||||
timeout_preprocess: 14400
|
|
||||||
memory_preprocess: 32
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
lm_eval_model: axolotl-ai-co/numina-8b-ep1-exp1
|
|
||||||
lm_eval_tasks:
|
|
||||||
- leaderboard_math_hard
|
|
||||||
lm_eval_batch_size: 64
|
|
||||||
|
|
||||||
apply_chat_template: false
|
|
||||||
wandb_project: numina-kd-experiment
|
|
||||||
wandb_entity: axolotl-ai
|
|
||||||
bf16: true
|
|
||||||
flash_attention: true
|
|
||||||
output_dir: ./outputs/model-evals-out
|
|
||||||
@@ -25,7 +25,6 @@ hf_transfer
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
|
|
||||||
modal==0.70.5
|
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
@@ -54,7 +53,7 @@ zstandard==0.22.0
|
|||||||
fastcore
|
fastcore
|
||||||
|
|
||||||
# lm eval harness
|
# lm eval harness
|
||||||
lm_eval==0.4.7
|
lm_eval==0.4.4
|
||||||
langdetect==1.0.9
|
langdetect==1.0.9
|
||||||
immutabledict==4.2.0
|
immutabledict==4.2.0
|
||||||
antlr4-python3-runtime==4.13.2
|
antlr4-python3-runtime==4.13.2
|
||||||
@@ -62,4 +61,4 @@ antlr4-python3-runtime==4.13.2
|
|||||||
torchao==0.7.0
|
torchao==0.7.0
|
||||||
schedulefree==1.3.0
|
schedulefree==1.3.0
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.2
|
axolotl-contribs-lgpl==0.0.1b2
|
||||||
|
|||||||
17
scripts/motd
17
scripts/motd
@@ -1,15 +1,10 @@
|
|||||||
|
|
||||||
#@@ #@@ @@# @@#
|
dP dP dP
|
||||||
@@ @@ @@ @@ =@@# @@ #@ =@@#.
|
88 88 88
|
||||||
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
|
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
|
||||||
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
|
88' `88 `8bd8' 88' `88 88 88' `88 88 88
|
||||||
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
|
88. .88 .d88b. 88. .88 88 88. .88 88 88
|
||||||
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
|
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
|
||||||
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
|
|
||||||
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
|
|
||||||
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
|
|
||||||
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
|
|
||||||
@@@@ @@@@@@@@@@@@@@@@
|
|
||||||
|
|
||||||
Welcome to the axolotl cloud image! If you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:
|
Welcome to the axolotl cloud image! If you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:
|
||||||
|
|
||||||
|
|||||||
@@ -1,56 +0,0 @@
|
|||||||
"""
|
|
||||||
launch axolotl in supported cloud platforms
|
|
||||||
"""
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from axolotl.cli import print_axolotl_text_art
|
|
||||||
from axolotl.cli.cloud.modal_ import ModalCloud
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
|
|
||||||
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
|
|
||||||
"""Load and validate cloud configuration."""
|
|
||||||
# Load cloud configuration.
|
|
||||||
with open(cloud_config, encoding="utf-8") as file:
|
|
||||||
cloud_cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
|
||||||
return cloud_cfg
|
|
||||||
|
|
||||||
|
|
||||||
def do_cli_preprocess(
|
|
||||||
cloud_config: Union[Path, str],
|
|
||||||
config: Union[Path, str] = Path("examples/"),
|
|
||||||
) -> None:
|
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
|
||||||
cloud = ModalCloud(cloud_cfg)
|
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
|
||||||
config_yaml = file.read()
|
|
||||||
cloud.preprocess(config_yaml)
|
|
||||||
|
|
||||||
|
|
||||||
def do_cli_train(
|
|
||||||
cloud_config: Union[Path, str],
|
|
||||||
config: Union[Path, str] = Path("examples/"),
|
|
||||||
accelerate: bool = True,
|
|
||||||
) -> None:
|
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
|
||||||
cloud = ModalCloud(cloud_cfg)
|
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
|
||||||
config_yaml = file.read()
|
|
||||||
cloud.train(config_yaml, accelerate=accelerate)
|
|
||||||
|
|
||||||
|
|
||||||
def do_cli_lm_eval(
|
|
||||||
cloud_config: Union[Path, str],
|
|
||||||
config: Union[Path, str] = Path("examples/"),
|
|
||||||
) -> None:
|
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
|
||||||
cloud = ModalCloud(cloud_cfg)
|
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
|
||||||
config_yaml = file.read()
|
|
||||||
cloud.lm_eval(config_yaml)
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
"""
|
|
||||||
base class for cloud platforms from cli
|
|
||||||
"""
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
|
|
||||||
class Cloud(ABC):
|
|
||||||
"""
|
|
||||||
Abstract base class for cloud platforms.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def train(self, config_yaml: str, accelerate: bool = True) -> str:
|
|
||||||
pass
|
|
||||||
@@ -1,272 +0,0 @@
|
|||||||
"""
|
|
||||||
Modal Cloud support from CLI
|
|
||||||
"""
|
|
||||||
import copy
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import subprocess # nosec B404
|
|
||||||
from pathlib import Path
|
|
||||||
from random import randint
|
|
||||||
|
|
||||||
import modal
|
|
||||||
|
|
||||||
from axolotl.cli.cloud.base import Cloud
|
|
||||||
|
|
||||||
|
|
||||||
def run_cmd(cmd: str, run_folder: str, volumes=None):
|
|
||||||
"""Run a command inside a folder, with Modal Volume reloading before and commit on success."""
|
|
||||||
# Ensure volumes contain latest files.
|
|
||||||
if volumes:
|
|
||||||
for _, vol in volumes.items():
|
|
||||||
vol.reload()
|
|
||||||
|
|
||||||
# modal workaround so it doesn't use the automounted axolotl
|
|
||||||
new_env = copy.deepcopy(os.environ)
|
|
||||||
if "PYTHONPATH" in new_env:
|
|
||||||
del new_env["PYTHONPATH"]
|
|
||||||
|
|
||||||
# Propagate errors from subprocess.
|
|
||||||
if exit_code := subprocess.call( # nosec B603
|
|
||||||
cmd.split(), cwd=run_folder, env=new_env
|
|
||||||
):
|
|
||||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
|
||||||
|
|
||||||
# Commit writes to volume.
|
|
||||||
if volumes:
|
|
||||||
for _, vol in volumes.items():
|
|
||||||
vol.commit()
|
|
||||||
|
|
||||||
|
|
||||||
class ModalCloud(Cloud):
|
|
||||||
"""
|
|
||||||
Modal Cloud implementation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config, app=None):
|
|
||||||
self.config = config
|
|
||||||
if not app:
|
|
||||||
app = modal.App()
|
|
||||||
self.app = app
|
|
||||||
|
|
||||||
self.volumes = {}
|
|
||||||
if config.volumes:
|
|
||||||
for volume_config in config.volumes:
|
|
||||||
_, mount, vol = self.create_volume(volume_config)
|
|
||||||
self.volumes[mount] = (vol, volume_config)
|
|
||||||
|
|
||||||
def get_env(self):
|
|
||||||
res = {
|
|
||||||
"HF_DATASETS_CACHE": "/workspace/data/huggingface-cache/datasets",
|
|
||||||
"HF_HUB_CACHE": "/workspace/data/huggingface-cache/hub",
|
|
||||||
}
|
|
||||||
|
|
||||||
for key in self.config.get("env", []):
|
|
||||||
if isinstance(key, str):
|
|
||||||
if val := os.environ.get(key, ""):
|
|
||||||
res[key] = val
|
|
||||||
elif isinstance(key, dict):
|
|
||||||
(key_, val) = list(key.items())[0]
|
|
||||||
res[key_] = val
|
|
||||||
return res
|
|
||||||
|
|
||||||
def get_image(self):
|
|
||||||
docker_tag = "main-py3.11-cu124-2.5.1"
|
|
||||||
if self.config.docker_tag:
|
|
||||||
docker_tag = self.config.docker_tag
|
|
||||||
docker_image = f"axolotlai/axolotl:{docker_tag}"
|
|
||||||
|
|
||||||
# grab the sha256 hash from docker hub for this image+tag
|
|
||||||
# this ensures that we always get the latest image for this tag, even if it's already cached
|
|
||||||
try:
|
|
||||||
manifest = subprocess.check_output( # nosec B602
|
|
||||||
f"docker manifest inspect {docker_image}",
|
|
||||||
shell=True,
|
|
||||||
).decode("utf-8")
|
|
||||||
sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
|
|
||||||
except subprocess.CalledProcessError:
|
|
||||||
sha256_hash = None
|
|
||||||
|
|
||||||
# create the image
|
|
||||||
if sha256_hash:
|
|
||||||
image = modal.Image.from_registry(f"axolotlai/axolotl@{sha256_hash}")
|
|
||||||
else:
|
|
||||||
image = modal.Image.from_registry(docker_image)
|
|
||||||
|
|
||||||
# branch
|
|
||||||
if self.config.branch:
|
|
||||||
image = image.dockerfile_commands(
|
|
||||||
[
|
|
||||||
# Random id for cache busting of branch commits
|
|
||||||
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
|
|
||||||
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
|
|
||||||
"RUN cd /workspace/ && git clone https://github.com/winglian/lm-evaluation-harness.git && cd lm-evaluation-harness && pip install -e .[math]",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if env := self.get_env():
|
|
||||||
image = image.env(env)
|
|
||||||
|
|
||||||
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
|
||||||
|
|
||||||
return image
|
|
||||||
|
|
||||||
def get_secrets(self):
|
|
||||||
res = []
|
|
||||||
if self.config.secrets:
|
|
||||||
for key in self.config.get("secrets", []):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
if isinstance(key, str):
|
|
||||||
if val := os.environ.get(key, ""):
|
|
||||||
res.append(modal.Secret.from_dict({key: val}))
|
|
||||||
elif isinstance(key, dict):
|
|
||||||
(key_, val) = list(key.items())[0]
|
|
||||||
res.append(modal.Secret.from_dict({key_: val}))
|
|
||||||
return res
|
|
||||||
|
|
||||||
def create_volume(self, volume_config):
|
|
||||||
name = volume_config.name
|
|
||||||
mount = volume_config.mount
|
|
||||||
return name, mount, modal.Volume.from_name(name, create_if_missing=True)
|
|
||||||
|
|
||||||
def get_ephemeral_disk_size(self):
|
|
||||||
return 1000 * 525 # 1 TiB
|
|
||||||
|
|
||||||
def get_preprocess_timeout(self):
|
|
||||||
if self.config.timeout_preprocess:
|
|
||||||
return int(self.config.timeout_preprocess)
|
|
||||||
return 60 * 60 * 3 # 3 hours
|
|
||||||
|
|
||||||
def get_preprocess_memory(self):
|
|
||||||
memory = 128 # default to 128GiB
|
|
||||||
if self.config.memory:
|
|
||||||
memory = int(self.config.memory)
|
|
||||||
if self.config.memory_preprocess:
|
|
||||||
memory = int(self.config.memory_preprocess)
|
|
||||||
return 1024 * memory
|
|
||||||
|
|
||||||
def get_preprocess_env(self):
|
|
||||||
return self.app.function(
|
|
||||||
image=self.get_image(),
|
|
||||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
|
||||||
cpu=8.0,
|
|
||||||
ephemeral_disk=self.get_ephemeral_disk_size(),
|
|
||||||
memory=self.get_preprocess_memory(),
|
|
||||||
timeout=self.get_preprocess_timeout(),
|
|
||||||
secrets=self.get_secrets(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def preprocess(self, config_yaml: str, *args, **kwargs):
|
|
||||||
modal_fn = self.get_preprocess_env()(_preprocess)
|
|
||||||
with modal.enable_output():
|
|
||||||
with self.app.run(detach=True):
|
|
||||||
modal_fn.remote(
|
|
||||||
config_yaml,
|
|
||||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_train_timeout(self):
|
|
||||||
if self.config.timeout:
|
|
||||||
return int(self.config.timeout)
|
|
||||||
return 60 * 60 * 24 # 24 hours
|
|
||||||
|
|
||||||
def get_train_gpu(self): # pylint: disable=too-many-return-statements
|
|
||||||
count = self.config.gpu_count or 1
|
|
||||||
family = self.config.gpu.lower() or "l40s"
|
|
||||||
|
|
||||||
if family == "l40s":
|
|
||||||
return modal.gpu.L40S(count=count)
|
|
||||||
if family == "a100":
|
|
||||||
return modal.gpu.A100(count=count, size="40GB")
|
|
||||||
if family == "a100-80gb":
|
|
||||||
return modal.gpu.A100(count=count, size="80GB")
|
|
||||||
if family in ["a10", "a10g"]:
|
|
||||||
return modal.gpu.A10G(count=count)
|
|
||||||
if family == "h100":
|
|
||||||
return modal.gpu.H100(count=count)
|
|
||||||
if family == "t4":
|
|
||||||
return modal.gpu.T4(count=count)
|
|
||||||
if family == "l4":
|
|
||||||
return modal.gpu.L4(count=count)
|
|
||||||
raise ValueError(f"Unsupported GPU family: {family}")
|
|
||||||
|
|
||||||
def get_train_memory(self):
|
|
||||||
memory = 128 # default to 128GiB
|
|
||||||
if self.config.memory:
|
|
||||||
memory = int(self.config.memory)
|
|
||||||
return 1024 * memory
|
|
||||||
|
|
||||||
def get_train_env(self):
|
|
||||||
return self.app.function(
|
|
||||||
image=self.get_image(),
|
|
||||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
|
||||||
cpu=16.0,
|
|
||||||
gpu=self.get_train_gpu(),
|
|
||||||
memory=self.get_train_memory(),
|
|
||||||
timeout=self.get_train_timeout(),
|
|
||||||
secrets=self.get_secrets(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def train(self, config_yaml: str, accelerate: bool = True):
|
|
||||||
modal_fn = self.get_train_env()(_train)
|
|
||||||
with modal.enable_output():
|
|
||||||
with self.app.run(detach=True):
|
|
||||||
modal_fn.remote(
|
|
||||||
config_yaml,
|
|
||||||
accelerate=accelerate,
|
|
||||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
|
||||||
)
|
|
||||||
|
|
||||||
def lm_eval(self, config_yaml: str):
|
|
||||||
modal_fn = self.get_train_env()(_lm_eval)
|
|
||||||
with modal.enable_output():
|
|
||||||
with self.app.run(detach=True):
|
|
||||||
modal_fn.remote(
|
|
||||||
config_yaml,
|
|
||||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _preprocess(config_yaml: str, volumes=None):
|
|
||||||
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(
|
|
||||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
|
||||||
) as f_out:
|
|
||||||
f_out.write(config_yaml)
|
|
||||||
run_folder = "/workspace/artifacts/axolotl"
|
|
||||||
run_cmd(
|
|
||||||
"axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
|
|
||||||
run_folder,
|
|
||||||
volumes,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
|
|
||||||
with open(
|
|
||||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
|
||||||
) as f_out:
|
|
||||||
f_out.write(config_yaml)
|
|
||||||
run_folder = "/workspace/artifacts/axolotl"
|
|
||||||
if accelerate:
|
|
||||||
accelerate_args = "--accelerate"
|
|
||||||
else:
|
|
||||||
accelerate_args = "--no-accelerate"
|
|
||||||
run_cmd(
|
|
||||||
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
|
|
||||||
run_folder,
|
|
||||||
volumes,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _lm_eval(config_yaml: str, volumes=None):
|
|
||||||
with open(
|
|
||||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
|
||||||
) as f_out:
|
|
||||||
f_out.write(config_yaml)
|
|
||||||
run_folder = "/workspace/artifacts/axolotl"
|
|
||||||
run_cmd(
|
|
||||||
"axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
|
|
||||||
run_folder,
|
|
||||||
volumes,
|
|
||||||
)
|
|
||||||
@@ -13,7 +13,6 @@ from axolotl.cli.utils import (
|
|||||||
fetch_from_github,
|
fetch_from_github,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
|
|
||||||
@@ -26,21 +25,15 @@ def cli():
|
|||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
|
||||||
@add_options_from_dataclass(PreprocessCliArgs)
|
@add_options_from_dataclass(PreprocessCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
|
def preprocess(config: str, **kwargs):
|
||||||
"""Preprocess datasets before training."""
|
"""Preprocess datasets before training."""
|
||||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
|
||||||
if cloud:
|
from axolotl.cli.preprocess import do_cli
|
||||||
from axolotl.cli.cloud import do_cli_preprocess
|
|
||||||
|
|
||||||
do_cli_preprocess(cloud_config=cloud, config=config)
|
do_cli(config=config, **kwargs)
|
||||||
else:
|
|
||||||
from axolotl.cli.preprocess import do_cli
|
|
||||||
|
|
||||||
do_cli(config=config, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@@ -50,33 +43,25 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
|
|||||||
default=True,
|
default=True,
|
||||||
help="Use accelerate launch for multi-GPU training",
|
help="Use accelerate launch for multi-GPU training",
|
||||||
)
|
)
|
||||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
|
||||||
@add_options_from_dataclass(TrainerCliArgs)
|
@add_options_from_dataclass(TrainerCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
def train(config: str, accelerate: bool, cloud: Optional[str], **kwargs):
|
def train(config: str, accelerate: bool, **kwargs):
|
||||||
"""Train or fine-tune a model."""
|
"""Train or fine-tune a model."""
|
||||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
|
||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
set_pytorch_cuda_alloc_conf()
|
set_pytorch_cuda_alloc_conf()
|
||||||
from axolotl.cli.cloud import do_cli_train
|
|
||||||
|
|
||||||
if accelerate:
|
if accelerate:
|
||||||
if cloud:
|
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
|
||||||
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
|
if config:
|
||||||
else:
|
base_cmd.append(config)
|
||||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
|
cmd = build_command(base_cmd, kwargs)
|
||||||
if config:
|
subprocess.run(cmd, check=True) # nosec B603
|
||||||
base_cmd.append(config)
|
|
||||||
cmd = build_command(base_cmd, kwargs)
|
|
||||||
subprocess.run(cmd, check=True) # nosec B603
|
|
||||||
else:
|
else:
|
||||||
if cloud:
|
from axolotl.cli.train import do_cli
|
||||||
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
|
|
||||||
else:
|
|
||||||
from axolotl.cli.train import do_cli
|
|
||||||
|
|
||||||
do_cli(config=config, **kwargs)
|
do_cli(config=config, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@@ -108,7 +93,7 @@ def evaluate(config: str, accelerate: bool, **kwargs):
|
|||||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
@click.option(
|
@click.option(
|
||||||
"--accelerate/--no-accelerate",
|
"--accelerate/--no-accelerate",
|
||||||
default=False,
|
default=True,
|
||||||
help="Use accelerate launch for multi-GPU inference",
|
help="Use accelerate launch for multi-GPU inference",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
@@ -139,7 +124,7 @@ def inference(
|
|||||||
if lora_model_dir:
|
if lora_model_dir:
|
||||||
kwargs["lora_model_dir"] = lora_model_dir
|
kwargs["lora_model_dir"] = lora_model_dir
|
||||||
if base_model:
|
if base_model:
|
||||||
kwargs["base_model"] = base_model
|
kwargs["output_dir"] = base_model
|
||||||
|
|
||||||
if accelerate:
|
if accelerate:
|
||||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
|
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
|
||||||
@@ -269,9 +254,6 @@ def fetch(directory: str, dest: Optional[str]):
|
|||||||
fetch_from_github(f"{directory}/", dest)
|
fetch_from_github(f"{directory}/", dest)
|
||||||
|
|
||||||
|
|
||||||
cli.add_command(lm_eval)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
cli()
|
cli()
|
||||||
|
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ from axolotl.utils.callbacks import (
|
|||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template
|
||||||
from axolotl.utils.collators import (
|
from axolotl.utils.collators import (
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
@@ -244,6 +244,10 @@ class AxolotlTrainingMixins:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||||
)
|
)
|
||||||
|
lr_groups: Optional[list[dict]] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Specify learning rate groups for with different LRs."},
|
||||||
|
)
|
||||||
embedding_lr: Optional[float] = field(
|
embedding_lr: Optional[float] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||||
@@ -462,11 +466,96 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
)
|
)
|
||||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||||
|
|
||||||
|
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
|
||||||
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
|
params = {
|
||||||
|
"to_weight_decay": {}, # LayerNorm and bias
|
||||||
|
"embeddings": {}, # lm_head, embed_tokens,
|
||||||
|
"no_weight_decay": {},
|
||||||
|
}
|
||||||
|
lr_groups_lookup = {}
|
||||||
|
lr_groups_learning_rates = {}
|
||||||
|
if self.args.lr_groups:
|
||||||
|
for lr_group in self.args.lr_groups:
|
||||||
|
group_name = lr_group["name"]
|
||||||
|
group_modules = lr_group["modules"]
|
||||||
|
for module in group_modules:
|
||||||
|
lr_groups_lookup[module] = group_name
|
||||||
|
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
||||||
|
params[f"to_weight_decay_{group_name}"] = {}
|
||||||
|
|
||||||
|
for name, param in opt_model.named_parameters():
|
||||||
|
if not param.requires_grad:
|
||||||
|
continue
|
||||||
|
if name.endswith("modules_to_save.default.weight") or any(
|
||||||
|
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||||
|
):
|
||||||
|
params["embeddings"][name] = param
|
||||||
|
elif name in decay_parameters:
|
||||||
|
if lr_groups_lookup and any(
|
||||||
|
group_modules in name for group_modules in lr_groups_lookup
|
||||||
|
):
|
||||||
|
lr_group_module = [
|
||||||
|
group_modules
|
||||||
|
for group_modules in lr_groups_lookup
|
||||||
|
if group_modules in name
|
||||||
|
][0]
|
||||||
|
group_name = lr_groups_lookup[lr_group_module]
|
||||||
|
params[f"to_weight_decay_{group_name}"][name] = param
|
||||||
|
else:
|
||||||
|
params["to_weight_decay"][name] = param
|
||||||
|
else:
|
||||||
|
params["no_weight_decay"][name] = param
|
||||||
|
optimizer_grouped_parameters = []
|
||||||
|
if params["to_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["to_weight_decay"].values()),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["embeddings"]:
|
||||||
|
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||||
|
if self.args.embedding_lr_scale:
|
||||||
|
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||||
|
elif self.args.embedding_lr:
|
||||||
|
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["embeddings"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["no_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["no_weight_decay"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for group_name, group_lr in lr_groups_learning_rates.items():
|
||||||
|
if params[f"to_weight_decay_{group_name}"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(
|
||||||
|
params[f"to_weight_decay_{group_name}"].values()
|
||||||
|
),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": group_lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return optimizer_grouped_parameters
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
if (
|
if (
|
||||||
self.args.loraplus_lr_ratio is None
|
self.args.loraplus_lr_ratio is None
|
||||||
and self.args.embedding_lr_scale is None
|
and self.args.embedding_lr_scale is None
|
||||||
and self.args.embedding_lr is None
|
and self.args.embedding_lr is None
|
||||||
|
and self.args.lr_groups is None
|
||||||
and self.args.alternate_optimizer
|
and self.args.alternate_optimizer
|
||||||
not in [
|
not in [
|
||||||
"optimi_adamw",
|
"optimi_adamw",
|
||||||
@@ -480,59 +569,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
|
||||||
params = {
|
|
||||||
"to_weight_decay": {}, # LayerNorm and bias
|
|
||||||
"embeddings": {}, # lm_head, embed_tokens,
|
|
||||||
"no_weight_decay": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||||
self.args,
|
self.args,
|
||||||
opt_model,
|
opt_model,
|
||||||
)
|
)
|
||||||
|
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||||
for name, param in opt_model.named_parameters():
|
opt_model, optimizer_kwargs
|
||||||
if not param.requires_grad:
|
)
|
||||||
continue
|
|
||||||
if name.endswith("modules_to_save.default.weight") or any(
|
|
||||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
|
||||||
):
|
|
||||||
params["embeddings"][name] = param
|
|
||||||
elif name in decay_parameters:
|
|
||||||
params["to_weight_decay"][name] = param
|
|
||||||
else:
|
|
||||||
params["no_weight_decay"][name] = param
|
|
||||||
optimizer_grouped_parameters = []
|
|
||||||
if params["to_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["to_weight_decay"].values()),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["embeddings"]:
|
|
||||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
|
||||||
if self.args.embedding_lr_scale:
|
|
||||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
|
||||||
elif self.args.embedding_lr:
|
|
||||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["embeddings"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["no_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["no_weight_decay"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.loraplus_lr_ratio is not None:
|
if self.args.loraplus_lr_ratio is not None:
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
@@ -549,6 +592,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
elif (
|
elif (
|
||||||
self.args.embedding_lr_scale is not None
|
self.args.embedding_lr_scale is not None
|
||||||
or self.args.embedding_lr is not None
|
or self.args.embedding_lr is not None
|
||||||
|
or self.args.lr_groups is not None
|
||||||
):
|
):
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
@@ -1764,6 +1808,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
] = self.cfg.loraplus_lr_embedding
|
] = self.cfg.loraplus_lr_embedding
|
||||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||||
|
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||||
|
|
||||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
||||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||||
@@ -1834,8 +1879,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||||
if self.cfg.chat_template:
|
if self.cfg.chat_template:
|
||||||
training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
|
training_arguments_kwargs["chat_template"] = get_chat_template(
|
||||||
cfg=self.cfg,
|
self.cfg.chat_template,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
Module for the Plugin for LM Eval Harness
|
Module for the Plugin for LM Eval Harness
|
||||||
"""
|
"""
|
||||||
import subprocess # nosec
|
import subprocess # nosec
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
|
|
||||||
|
|
||||||
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
@@ -18,19 +18,25 @@ class LMEvalPlugin(BasePlugin):
|
|||||||
return "axolotl.integrations.lm_eval.LMEvalArgs"
|
return "axolotl.integrations.lm_eval.LMEvalArgs"
|
||||||
|
|
||||||
def post_train_unload(self, cfg):
|
def post_train_unload(self, cfg):
|
||||||
if cfg.lm_eval_post_train:
|
tasks = ",".join(cfg.lm_eval_tasks)
|
||||||
# pylint: disable=duplicate-code
|
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
|
||||||
for lm_eval_args in build_lm_eval_command(
|
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
|
||||||
cfg.lm_eval_tasks,
|
output_path = cfg.output_dir
|
||||||
bfloat16=cfg.bfloat16 or cfg.bf16,
|
output_path += "" if cfg.output_dir.endswith("/") else "/"
|
||||||
flash_attention=cfg.flash_attention,
|
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
output_dir=cfg.output_dir,
|
subprocess.run( # nosec
|
||||||
batch_size=cfg.lm_eval_batch_size,
|
[
|
||||||
wandb_project=cfg.wandb_project,
|
"lm_eval",
|
||||||
wandb_entity=cfg.wandb_entity,
|
"--model",
|
||||||
model=cfg.lm_eval_model or cfg.hub_model_id,
|
"hf",
|
||||||
):
|
"--model_args",
|
||||||
subprocess.run( # nosec
|
f"pretrained={cfg.output_dir}{fa2}{dtype}",
|
||||||
lm_eval_args,
|
"--tasks",
|
||||||
check=True,
|
tasks,
|
||||||
)
|
"--batch_size",
|
||||||
|
str(cfg.lm_eval_batch_size),
|
||||||
|
"--output_path",
|
||||||
|
output_path,
|
||||||
|
],
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
|||||||
@@ -13,5 +13,3 @@ class LMEvalArgs(BaseModel):
|
|||||||
|
|
||||||
lm_eval_tasks: List[str] = []
|
lm_eval_tasks: List[str] = []
|
||||||
lm_eval_batch_size: Optional[int] = 8
|
lm_eval_batch_size: Optional[int] = 8
|
||||||
lm_eval_post_train: Optional[bool] = True
|
|
||||||
lm_eval_model: Optional[str] = None
|
|
||||||
|
|||||||
@@ -1,113 +0,0 @@
|
|||||||
"""
|
|
||||||
axolotl CLI for running lm_eval tasks
|
|
||||||
"""
|
|
||||||
import subprocess # nosec
|
|
||||||
from collections import defaultdict
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import click
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
|
|
||||||
def build_lm_eval_command(
|
|
||||||
tasks: list[str],
|
|
||||||
bfloat16=True,
|
|
||||||
flash_attention=False,
|
|
||||||
output_dir="./",
|
|
||||||
batch_size=8,
|
|
||||||
wandb_project=None,
|
|
||||||
wandb_entity=None,
|
|
||||||
model=None,
|
|
||||||
revision=None,
|
|
||||||
apply_chat_template=None,
|
|
||||||
fewshot_as_multiturn=None,
|
|
||||||
):
|
|
||||||
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
|
|
||||||
for task in tasks:
|
|
||||||
num_fewshot = "-1"
|
|
||||||
task_parts = task.split(":")
|
|
||||||
task_name = task_parts[0]
|
|
||||||
if len(task_parts) == 2:
|
|
||||||
task_name, num_fewshot = task_parts
|
|
||||||
tasks_by_num_fewshot[str(num_fewshot)].append(task_name)
|
|
||||||
|
|
||||||
for num_fewshot, tasks_list in tasks_by_num_fewshot.items():
|
|
||||||
tasks_str = ",".join(tasks_list)
|
|
||||||
num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
|
|
||||||
pretrained = "pretrained="
|
|
||||||
pretrained += model if model else output_dir
|
|
||||||
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
|
|
||||||
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
|
|
||||||
revision = f",revision={revision}" if revision else ""
|
|
||||||
output_path = output_dir
|
|
||||||
output_path += "" if output_dir.endswith("/") else "/"
|
|
||||||
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
lm_eval_args = [
|
|
||||||
"lm_eval",
|
|
||||||
"--model",
|
|
||||||
"hf",
|
|
||||||
"--model_args",
|
|
||||||
f"{pretrained}{fa2}{dtype}{revision}",
|
|
||||||
"--tasks",
|
|
||||||
tasks_str,
|
|
||||||
"--batch_size",
|
|
||||||
str(batch_size),
|
|
||||||
"--output_path",
|
|
||||||
output_path,
|
|
||||||
]
|
|
||||||
wandb_args = []
|
|
||||||
if wandb_project:
|
|
||||||
wandb_args.append(f"project={wandb_project}")
|
|
||||||
if wandb_entity:
|
|
||||||
wandb_args.append(f"entity={wandb_entity}")
|
|
||||||
if wandb_args:
|
|
||||||
lm_eval_args.append("--wandb_args")
|
|
||||||
lm_eval_args.append(",".join(wandb_args))
|
|
||||||
if apply_chat_template:
|
|
||||||
lm_eval_args.append("--apply_chat_template")
|
|
||||||
if num_fewshot_val:
|
|
||||||
lm_eval_args.append("--num_fewshot")
|
|
||||||
lm_eval_args.append(str(num_fewshot_val))
|
|
||||||
if apply_chat_template and fewshot_as_multiturn:
|
|
||||||
lm_eval_args.append("--fewshot_as_multiturn")
|
|
||||||
|
|
||||||
yield lm_eval_args
|
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
|
||||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
|
||||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
|
||||||
def lm_eval(config: str, cloud: Optional[str] = None):
|
|
||||||
"""
|
|
||||||
use lm eval to evaluate a trained language model
|
|
||||||
"""
|
|
||||||
|
|
||||||
if cloud:
|
|
||||||
from axolotl.cli.cloud import do_cli_lm_eval
|
|
||||||
|
|
||||||
do_cli_lm_eval(cloud_config=cloud, config=config)
|
|
||||||
else:
|
|
||||||
with open(config, encoding="utf-8") as file:
|
|
||||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
for lm_eval_args in build_lm_eval_command(
|
|
||||||
cfg.lm_eval_tasks,
|
|
||||||
bfloat16=cfg.bfloat16 or cfg.bf16,
|
|
||||||
flash_attention=cfg.flash_attention,
|
|
||||||
output_dir=cfg.output_dir,
|
|
||||||
batch_size=cfg.lm_eval_batch_size,
|
|
||||||
wandb_project=cfg.wandb_project,
|
|
||||||
wandb_entity=cfg.wandb_entity,
|
|
||||||
model=cfg.lm_eval_model or cfg.hub_model_id,
|
|
||||||
revision=cfg.revision,
|
|
||||||
apply_chat_template=cfg.apply_chat_template,
|
|
||||||
fewshot_as_multiturn=cfg.fewshot_as_multiturn,
|
|
||||||
):
|
|
||||||
subprocess.run( # nosec
|
|
||||||
lm_eval_args,
|
|
||||||
check=True,
|
|
||||||
)
|
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||||
|
|
||||||
import inspect
|
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
@@ -127,20 +126,7 @@ def train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cfg.fix_untrained_tokens:
|
if cfg.fix_untrained_tokens:
|
||||||
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
fix_untrained_tokens(model, tokenizer, train_dataset)
|
||||||
sig = inspect.signature(fix_untrained_tokens)
|
|
||||||
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
|
||||||
if "token_ids_to_fix" in sig.parameters and isinstance(
|
|
||||||
cfg.fix_untrained_tokens, list
|
|
||||||
):
|
|
||||||
fix_untrained_tokens(
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
train_dataset,
|
|
||||||
token_ids_to_fix=cfg.fix_untrained_tokens,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
fix_untrained_tokens(model, tokenizer, train_dataset)
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
model.save_pretrained(
|
model.save_pretrained(
|
||||||
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||||
|
|||||||
@@ -145,6 +145,14 @@ class UserDefinedPrompterType(BaseModel):
|
|||||||
field: Optional[str] = None
|
field: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class LrGroup(BaseModel):
|
||||||
|
"""Custom learning rate group configuration"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
modules: List[str]
|
||||||
|
lr: float
|
||||||
|
|
||||||
|
|
||||||
class SFTDataset(BaseModel):
|
class SFTDataset(BaseModel):
|
||||||
"""SFT configuration subset"""
|
"""SFT configuration subset"""
|
||||||
|
|
||||||
@@ -466,6 +474,7 @@ class HyperparametersConfig(BaseModel):
|
|||||||
cosine_min_lr_ratio: Optional[float] = None
|
cosine_min_lr_ratio: Optional[float] = None
|
||||||
cosine_constant_lr_ratio: Optional[float] = None
|
cosine_constant_lr_ratio: Optional[float] = None
|
||||||
lr_div_factor: Optional[float] = None
|
lr_div_factor: Optional[float] = None
|
||||||
|
lr_groups: Optional[List[LrGroup]] = None
|
||||||
|
|
||||||
adam_epsilon: Optional[float] = None
|
adam_epsilon: Optional[float] = None
|
||||||
adam_beta1: Optional[float] = None
|
adam_beta1: Optional[float] = None
|
||||||
@@ -794,7 +803,7 @@ class AxolotlInputConfig(
|
|||||||
chat_template_jinja: Optional[str] = None
|
chat_template_jinja: Optional[str] = None
|
||||||
default_system_message: Optional[str] = None
|
default_system_message: Optional[str] = None
|
||||||
|
|
||||||
fix_untrained_tokens: Optional[Union[int, List[int]]] = None
|
fix_untrained_tokens: Optional[bool] = None
|
||||||
|
|
||||||
# INTERNALS - document for now, generally not set externally
|
# INTERNALS - document for now, generally not set externally
|
||||||
is_preprocess: Optional[bool] = None
|
is_preprocess: Optional[bool] = None
|
||||||
|
|||||||
@@ -28,10 +28,8 @@ def encode_pretraining(
|
|||||||
)
|
)
|
||||||
# Convert to PyTorch tensors
|
# Convert to PyTorch tensors
|
||||||
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
|
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
|
||||||
targets = [torch.tensor(seq) for seq in res["input_ids"]]
|
|
||||||
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
|
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
|
||||||
new_input_ids = []
|
new_input_ids = []
|
||||||
new_labels = []
|
|
||||||
new_attention_mask = []
|
new_attention_mask = []
|
||||||
# Append EOS and PAD tokens to input_ids, and correct attention_mask
|
# Append EOS and PAD tokens to input_ids, and correct attention_mask
|
||||||
for i, _ in enumerate(input_ids):
|
for i, _ in enumerate(input_ids):
|
||||||
@@ -42,34 +40,22 @@ def encode_pretraining(
|
|||||||
),
|
),
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
targets[i] = torch.cat(
|
|
||||||
(
|
|
||||||
targets[i],
|
|
||||||
torch.tensor([tokenizer.eos_token_id, -100]),
|
|
||||||
),
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
|
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
|
||||||
|
|
||||||
# Concatenate tokens so that their lengths are less than max_tokens
|
# Concatenate tokens so that their lengths are less than max_tokens
|
||||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||||
buffer_labels = torch.tensor([], dtype=torch.long)
|
|
||||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||||
|
|
||||||
for ids, labels, mask in zip(input_ids, targets, attention_mask):
|
for ids, mask in zip(input_ids, attention_mask):
|
||||||
if buffer_input_ids.numel() == max_tokens:
|
if buffer_input_ids.numel() == max_tokens:
|
||||||
new_input_ids.append(buffer_input_ids)
|
new_input_ids.append(buffer_input_ids)
|
||||||
new_labels.append(buffer_labels)
|
|
||||||
new_attention_mask.append(buffer_attention_mask)
|
new_attention_mask.append(buffer_attention_mask)
|
||||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||||
buffer_labels = torch.tensor([], dtype=torch.long)
|
|
||||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||||
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
|
|
||||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||||
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
|
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
|
||||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||||
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
|
|
||||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||||
else:
|
else:
|
||||||
buffer_input_ids = torch.cat(
|
buffer_input_ids = torch.cat(
|
||||||
@@ -83,17 +69,6 @@ def encode_pretraining(
|
|||||||
),
|
),
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
buffer_labels = torch.cat(
|
|
||||||
(
|
|
||||||
buffer_labels,
|
|
||||||
torch.full(
|
|
||||||
(max_tokens - buffer_labels.numel(),),
|
|
||||||
-100,
|
|
||||||
dtype=torch.long,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
buffer_attention_mask = torch.cat(
|
buffer_attention_mask = torch.cat(
|
||||||
(
|
(
|
||||||
buffer_attention_mask,
|
buffer_attention_mask,
|
||||||
@@ -106,14 +81,11 @@ def encode_pretraining(
|
|||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
new_input_ids.append(buffer_input_ids)
|
new_input_ids.append(buffer_input_ids)
|
||||||
new_labels.append(buffer_labels)
|
|
||||||
new_attention_mask.append(buffer_attention_mask)
|
new_attention_mask.append(buffer_attention_mask)
|
||||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||||
buffer_labels = torch.tensor([], dtype=torch.long)
|
|
||||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||||
|
|
||||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||||
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
|
|
||||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||||
|
|
||||||
if buffer_input_ids.numel() > 0: # for any leftover tokens
|
if buffer_input_ids.numel() > 0: # for any leftover tokens
|
||||||
@@ -129,17 +101,6 @@ def encode_pretraining(
|
|||||||
),
|
),
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
buffer_labels = torch.cat(
|
|
||||||
(
|
|
||||||
buffer_labels,
|
|
||||||
torch.full(
|
|
||||||
(max_tokens - buffer_labels.numel(),),
|
|
||||||
-100,
|
|
||||||
dtype=torch.long,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
buffer_attention_mask = torch.cat(
|
buffer_attention_mask = torch.cat(
|
||||||
(
|
(
|
||||||
buffer_attention_mask,
|
buffer_attention_mask,
|
||||||
@@ -152,12 +113,11 @@ def encode_pretraining(
|
|||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
new_input_ids.append(buffer_input_ids)
|
new_input_ids.append(buffer_input_ids)
|
||||||
new_labels.append(buffer_labels)
|
|
||||||
new_attention_mask.append(buffer_attention_mask)
|
new_attention_mask.append(buffer_attention_mask)
|
||||||
|
|
||||||
ret = {
|
ret = {
|
||||||
"input_ids": [seq.tolist() for seq in new_input_ids],
|
"input_ids": [seq.tolist() for seq in new_input_ids],
|
||||||
"labels": [seq.tolist() for seq in new_labels],
|
"labels": [seq.tolist() for seq in new_input_ids],
|
||||||
"attention_mask": [seq.tolist() for seq in new_attention_mask],
|
"attention_mask": [seq.tolist() for seq in new_attention_mask],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user