Compare commits
13 Commits
hymba_mult...cli-cloud-
| Author | SHA1 | Date |
|---|---|---|
|  | ee20600b9a |  |
|  | fd91de3ea6 |  |
|  | 530bf77cf9 |  |
|  | bfc91a91ca |  |
|  | 5c226b600d |  |
|  | af66f7c274 |  |
|  | 079f94ee99 |  |
|  | 981ad965d0 |  |
|  | 7ba701a355 |  |
|  | 0390bce7aa |  |
|  | 2741d8de23 |  |
|  | 27a88f37cd |  |
|  | 6da8abc01f |  |
@@ -217,7 +217,7 @@ If you love axolotl, consider sponsoring the project by reaching out directly to

---

-- [Modal](https://modal.com/) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
+- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.

---
examples/cloud/modal.yaml (new file, 15 lines)
@@ -0,0 +1,15 @@
volumes:
  - name: axolotl-data
    mount: /workspace/data
  - name: axolotl-artifacts
    mount: /workspace/artifacts
secrets:
  - HF_TOKEN
  - WANDB_API_KEY
branch: cli-cloud-modal
gpu: h100
gpu_count: 1
memory: 128
timeout: 86400
timeout_preprocess: 14400
memory_preprocess: 32
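For orientation, a minimal sketch of how this cloud config is consumed; it mirrors `load_cloud_cfg` from `src/axolotl/cli/cloud/__init__.py` later in this diff (the path is assumed to be relative to the repo root):

```python
import yaml

# Read the Modal settings the same way the CLI does; keys match the file above.
with open("examples/cloud/modal.yaml", encoding="utf-8") as f:
    cloud_cfg = yaml.safe_load(f)

print(cloud_cfg["gpu"], cloud_cfg["gpu_count"])  # -> h100 1
print(cloud_cfg["volumes"][0]["mount"])          # -> /workspace/data
```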
@@ -1,58 +0,0 @@ (deleted file)
base_model: nvidia/Hymba-1.5B-Base

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out

sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

trust_remote_code: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 5
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>
@@ -1,73 +0,0 @@ (deleted file)
base_model: nvidia/Hymba-1.5B-Base

load_in_8bit: false
load_in_4bit: True
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out

sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

trust_remote_code: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 5
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>
lm_eval-kd.yaml (new file, 11 lines)
@@ -0,0 +1,11 @@
lm_eval_model: axolotl-ai-co/numina-8b-ep1-exp1
lm_eval_tasks:
  - leaderboard_math_hard
lm_eval_batch_size: 64

apply_chat_template: false
wandb_project: numina-kd-experiment
wandb_entity: axolotl-ai
bf16: true
flash_attention: true
output_dir: ./outputs/model-evals-out
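A hedged sketch of running this standalone eval config through the `lm-eval` command added later in this diff (invocation shape assumed from the click command there; file names mirror the ones in this PR):

```python
import subprocess

# Locally: runs lm_eval against the configured model.
subprocess.run(["axolotl", "lm-eval", "lm_eval-kd.yaml"], check=True)

# On Modal: add the cloud config from examples/cloud/modal.yaml.
subprocess.run(
    ["axolotl", "lm-eval", "lm_eval-kd.yaml", "--cloud", "examples/cloud/modal.yaml"],
    check=True,
)
```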
@@ -25,6 +25,7 @@ hf_transfer
sentencepiece
gradio==3.50.2

+modal==0.70.5
pydantic==2.6.3
addict
fire

@@ -53,7 +54,7 @@ zstandard==0.22.0
fastcore

# lm eval harness
-lm_eval==0.4.4
+lm_eval==0.4.7
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
scripts/motd (17 changes)
@@ -1,10 +1,15 @@

-dP dP dP
-88 88 88
-.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
-88' `88 `8bd8' 88' `88 88 88' `88 88 88
-88. .88 .d88b. 88. .88 88 88. .88 88 88
-`88888P8 dP' `dP `88888P' dP `88888P' dP dP
+#@@ #@@ @@# @@#
+@@ @@ @@ @@ =@@# @@ #@ =@@#.
+@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
+#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
+@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
+@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
+@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
+=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
+@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
+=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
+@@@@ @@@@@@@@@@@@@@@@

Welcome to the axolotl cloud image! If you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:
src/axolotl/cli/cloud/__init__.py (new file, 56 lines)
@@ -0,0 +1,56 @@
"""
launch axolotl in supported cloud platforms
"""
from pathlib import Path
from typing import Union

import yaml

from axolotl.cli import print_axolotl_text_art
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault


def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
    """Load and validate cloud configuration."""
    # Load cloud configuration.
    with open(cloud_config, encoding="utf-8") as file:
        cloud_cfg: DictDefault = DictDefault(yaml.safe_load(file))
    return cloud_cfg


def do_cli_preprocess(
    cloud_config: Union[Path, str],
    config: Union[Path, str] = Path("examples/"),
) -> None:
    print_axolotl_text_art()
    cloud_cfg = load_cloud_cfg(cloud_config)
    cloud = ModalCloud(cloud_cfg)
    with open(config, "r", encoding="utf-8") as file:
        config_yaml = file.read()
    cloud.preprocess(config_yaml)


def do_cli_train(
    cloud_config: Union[Path, str],
    config: Union[Path, str] = Path("examples/"),
    accelerate: bool = True,
) -> None:
    print_axolotl_text_art()
    cloud_cfg = load_cloud_cfg(cloud_config)
    cloud = ModalCloud(cloud_cfg)
    with open(config, "r", encoding="utf-8") as file:
        config_yaml = file.read()
    cloud.train(config_yaml, accelerate=accelerate)


def do_cli_lm_eval(
    cloud_config: Union[Path, str],
    config: Union[Path, str] = Path("examples/"),
) -> None:
    print_axolotl_text_art()
    cloud_cfg = load_cloud_cfg(cloud_config)
    cloud = ModalCloud(cloud_cfg)
    with open(config, "r", encoding="utf-8") as file:
        config_yaml = file.read()
    cloud.lm_eval(config_yaml)
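These entry points can also be driven directly from Python; a minimal sketch (the training config path below is hypothetical):

```python
from axolotl.cli.cloud import do_cli_train

# Equivalent to `axolotl train my_config.yaml --cloud examples/cloud/modal.yaml`.
do_cli_train(
    cloud_config="examples/cloud/modal.yaml",
    config="my_config.yaml",  # hypothetical axolotl training config
    accelerate=True,
)
```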
src/axolotl/cli/cloud/base.py (new file, 18 lines)
@@ -0,0 +1,18 @@
"""
base class for cloud platforms from cli
"""
from abc import ABC, abstractmethod


class Cloud(ABC):
    """
    Abstract base class for cloud platforms.
    """

    @abstractmethod
    def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
        pass

    @abstractmethod
    def train(self, config_yaml: str, accelerate: bool = True) -> str:
        pass
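The ABC keeps the CLI decoupled from Modal specifics. As a sketch of what a hypothetical second backend would have to provide (this class is illustrative only and not part of the PR):

```python
from axolotl.cli.cloud.base import Cloud


class EchoCloud(Cloud):
    """Hypothetical no-op backend, used only to illustrate the interface."""

    def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
        print("would run `axolotl preprocess` remotely")

    def train(self, config_yaml: str, accelerate: bool = True) -> str:
        print(f"would run `axolotl train` remotely (accelerate={accelerate})")
        return "job-id-0"  # placeholder job handle
```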
src/axolotl/cli/cloud/modal_.py (new file, 272 lines)
@@ -0,0 +1,272 @@
"""
Modal Cloud support from CLI
"""
import copy
import json
import os
import subprocess  # nosec B404
from pathlib import Path
from random import randint

import modal

from axolotl.cli.cloud.base import Cloud


def run_cmd(cmd: str, run_folder: str, volumes=None):
    """Run a command inside a folder, with Modal Volume reloading before and commit on success."""
    # Ensure volumes contain latest files.
    if volumes:
        for _, vol in volumes.items():
            vol.reload()

    # modal workaround so it doesn't use the automounted axolotl
    new_env = copy.deepcopy(os.environ)
    if "PYTHONPATH" in new_env:
        del new_env["PYTHONPATH"]

    # Propagate errors from subprocess.
    if exit_code := subprocess.call(  # nosec B603
        cmd.split(), cwd=run_folder, env=new_env
    ):
        exit(exit_code)  # pylint: disable=consider-using-sys-exit

    # Commit writes to volume.
    if volumes:
        for _, vol in volumes.items():
            vol.commit()


class ModalCloud(Cloud):
    """
    Modal Cloud implementation.
    """

    def __init__(self, config, app=None):
        self.config = config
        if not app:
            app = modal.App()
        self.app = app

        self.volumes = {}
        if config.volumes:
            for volume_config in config.volumes:
                _, mount, vol = self.create_volume(volume_config)
                self.volumes[mount] = (vol, volume_config)

    def get_env(self):
        res = {
            "HF_DATASETS_CACHE": "/workspace/data/huggingface-cache/datasets",
            "HF_HUB_CACHE": "/workspace/data/huggingface-cache/hub",
        }

        for key in self.config.get("env", []):
            if isinstance(key, str):
                if val := os.environ.get(key, ""):
                    res[key] = val
            elif isinstance(key, dict):
                (key_, val) = list(key.items())[0]
                res[key_] = val
        return res

    def get_image(self):
        docker_tag = "main-py3.11-cu124-2.5.1"
        if self.config.docker_tag:
            docker_tag = self.config.docker_tag
        docker_image = f"axolotlai/axolotl:{docker_tag}"

        # grab the sha256 hash from docker hub for this image+tag
        # this ensures that we always get the latest image for this tag, even if it's already cached
        try:
            manifest = subprocess.check_output(  # nosec B602
                f"docker manifest inspect {docker_image}",
                shell=True,
            ).decode("utf-8")
            sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
        except subprocess.CalledProcessError:
            sha256_hash = None

        # create the image
        if sha256_hash:
            image = modal.Image.from_registry(f"axolotlai/axolotl@{sha256_hash}")
        else:
            image = modal.Image.from_registry(docker_image)

        # branch
        if self.config.branch:
            image = image.dockerfile_commands(
                [
                    # Random id for cache busting of branch commits
                    f"RUN echo '{str(randint(0, 1000000))}'",  # nosec B311
                    f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
                    "RUN cd /workspace/ && git clone https://github.com/winglian/lm-evaluation-harness.git && cd lm-evaluation-harness && pip install -e .[math]",
                ]
            )

        if env := self.get_env():
            image = image.env(env)

        image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")

        return image

    def get_secrets(self):
        res = []
        if self.config.secrets:
            for key in self.config.get("secrets", []):
                # pylint: disable=duplicate-code
                if isinstance(key, str):
                    if val := os.environ.get(key, ""):
                        res.append(modal.Secret.from_dict({key: val}))
                elif isinstance(key, dict):
                    (key_, val) = list(key.items())[0]
                    res.append(modal.Secret.from_dict({key_: val}))
        return res

    def create_volume(self, volume_config):
        name = volume_config.name
        mount = volume_config.mount
        return name, mount, modal.Volume.from_name(name, create_if_missing=True)

    def get_ephemeral_disk_size(self):
        return 1000 * 525  # 1 TiB

    def get_preprocess_timeout(self):
        if self.config.timeout_preprocess:
            return int(self.config.timeout_preprocess)
        return 60 * 60 * 3  # 3 hours

    def get_preprocess_memory(self):
        memory = 128  # default to 128GiB
        if self.config.memory:
            memory = int(self.config.memory)
        if self.config.memory_preprocess:
            memory = int(self.config.memory_preprocess)
        return 1024 * memory

    def get_preprocess_env(self):
        return self.app.function(
            image=self.get_image(),
            volumes={k: v[0] for k, v in self.volumes.items()},
            cpu=8.0,
            ephemeral_disk=self.get_ephemeral_disk_size(),
            memory=self.get_preprocess_memory(),
            timeout=self.get_preprocess_timeout(),
            secrets=self.get_secrets(),
        )

    def preprocess(self, config_yaml: str, *args, **kwargs):
        modal_fn = self.get_preprocess_env()(_preprocess)
        with modal.enable_output():
            with self.app.run(detach=True):
                modal_fn.remote(
                    config_yaml,
                    volumes={k: v[0] for k, v in self.volumes.items()},
                    *args,
                    **kwargs,
                )

    def get_train_timeout(self):
        if self.config.timeout:
            return int(self.config.timeout)
        return 60 * 60 * 24  # 24 hours

    def get_train_gpu(self):  # pylint: disable=too-many-return-statements
        count = self.config.gpu_count or 1
        family = self.config.gpu.lower() or "l40s"

        if family == "l40s":
            return modal.gpu.L40S(count=count)
        if family == "a100":
            return modal.gpu.A100(count=count, size="40GB")
        if family == "a100-80gb":
            return modal.gpu.A100(count=count, size="80GB")
        if family in ["a10", "a10g"]:
            return modal.gpu.A10G(count=count)
        if family == "h100":
            return modal.gpu.H100(count=count)
        if family == "t4":
            return modal.gpu.T4(count=count)
        if family == "l4":
            return modal.gpu.L4(count=count)
        raise ValueError(f"Unsupported GPU family: {family}")

    def get_train_memory(self):
        memory = 128  # default to 128GiB
        if self.config.memory:
            memory = int(self.config.memory)
        return 1024 * memory

    def get_train_env(self):
        return self.app.function(
            image=self.get_image(),
            volumes={k: v[0] for k, v in self.volumes.items()},
            cpu=16.0,
            gpu=self.get_train_gpu(),
            memory=self.get_train_memory(),
            timeout=self.get_train_timeout(),
            secrets=self.get_secrets(),
        )

    def train(self, config_yaml: str, accelerate: bool = True):
        modal_fn = self.get_train_env()(_train)
        with modal.enable_output():
            with self.app.run(detach=True):
                modal_fn.remote(
                    config_yaml,
                    accelerate=accelerate,
                    volumes={k: v[0] for k, v in self.volumes.items()},
                )

    def lm_eval(self, config_yaml: str):
        modal_fn = self.get_train_env()(_lm_eval)
        with modal.enable_output():
            with self.app.run(detach=True):
                modal_fn.remote(
                    config_yaml,
                    volumes={k: v[0] for k, v in self.volumes.items()},
                )


def _preprocess(config_yaml: str, volumes=None):
    Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
    with open(
        "/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
    ) as f_out:
        f_out.write(config_yaml)
    run_folder = "/workspace/artifacts/axolotl"
    run_cmd(
        "axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
        run_folder,
        volumes,
    )


def _train(config_yaml: str, accelerate: bool = True, volumes=None):
    with open(
        "/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
    ) as f_out:
        f_out.write(config_yaml)
    run_folder = "/workspace/artifacts/axolotl"
    if accelerate:
        accelerate_args = "--accelerate"
    else:
        accelerate_args = "--no-accelerate"
    run_cmd(
        f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
        run_folder,
        volumes,
    )


def _lm_eval(config_yaml: str, volumes=None):
    with open(
        "/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
    ) as f_out:
        f_out.write(config_yaml)
    run_folder = "/workspace/artifacts/axolotl"
    run_cmd(
        "axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
        run_folder,
        volumes,
    )
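To see how the YAML above maps onto Modal resources, a small sketch using only the config-to-GPU plumbing (importing this module requires `modal` installed; constructing ModalCloud builds a `modal.App` locally but launches nothing):

```python
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault

cloud = ModalCloud(DictDefault({"gpu": "h100", "gpu_count": 1}))
print(cloud.get_train_gpu())      # modal.gpu.H100(count=1)
print(cloud.get_train_timeout())  # 86400 (24h default when `timeout` is unset)
print(cloud.get_train_memory())   # 131072 (MiB, from the 128 GiB default)
```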
@@ -13,6 +13,7 @@ from axolotl.cli.utils import (
    fetch_from_github,
)
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
+from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig

@@ -25,15 +26,21 @@ def cli():

@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
+@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
-def preprocess(config: str, **kwargs):
+def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
    """Preprocess datasets before training."""
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

-    from axolotl.cli.preprocess import do_cli
+    if cloud:
+        from axolotl.cli.cloud import do_cli_preprocess

-    do_cli(config=config, **kwargs)
+        do_cli_preprocess(cloud_config=cloud, config=config)
+    else:
+        from axolotl.cli.preprocess import do_cli
+
+        do_cli(config=config, **kwargs)


@cli.command()
@@ -43,25 +50,33 @@ def preprocess(config: str, **kwargs):
    default=True,
    help="Use accelerate launch for multi-GPU training",
)
+@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
-def train(config: str, accelerate: bool, **kwargs):
+def train(config: str, accelerate: bool, cloud: Optional[str], **kwargs):
    """Train or fine-tune a model."""
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    # Enable expandable segments for cuda allocation to improve VRAM usage
    set_pytorch_cuda_alloc_conf()

+    from axolotl.cli.cloud import do_cli_train
+
    if accelerate:
-        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
-        if config:
-            base_cmd.append(config)
-        cmd = build_command(base_cmd, kwargs)
-        subprocess.run(cmd, check=True)  # nosec B603
+        if cloud:
+            do_cli_train(cloud_config=cloud, config=config, accelerate=True)
+        else:
+            base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
+            if config:
+                base_cmd.append(config)
+            cmd = build_command(base_cmd, kwargs)
+            subprocess.run(cmd, check=True)  # nosec B603
    else:
-        from axolotl.cli.train import do_cli
-
-        do_cli(config=config, **kwargs)
+        if cloud:
+            do_cli_train(cloud_config=cloud, config=config, accelerate=False)
+        else:
+            from axolotl.cli.train import do_cli
+
+            do_cli(config=config, **kwargs)


@cli.command()
@@ -254,6 +269,9 @@ def fetch(directory: str, dest: Optional[str]):
    fetch_from_github(f"{directory}/", dest)


+cli.add_command(lm_eval)
+
+
def main():
    cli()
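A quick sketch of exercising the new `--cloud` flag through click's test runner (the CLI module path is assumed from this diff, and both paths are hypothetical; `click.Path(exists=True)` validates them at runtime):

```python
from click.testing import CliRunner

from axolotl.cli.main import cli  # module path assumed

runner = CliRunner()
# Routes training through ModalCloud instead of the local trainer.
result = runner.invoke(
    cli, ["train", "config.yaml", "--cloud", "examples/cloud/modal.yaml"]
)
print(result.exit_code)
```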
@@ -2,9 +2,9 @@
Module for the Plugin for LM Eval Harness
"""
import subprocess  # nosec
-from datetime import datetime

from axolotl.integrations.base import BasePlugin
+from axolotl.integrations.lm_eval.cli import build_lm_eval_command

from .args import LMEvalArgs  # pylint: disable=unused-import. # noqa: F401

@@ -18,25 +18,19 @@ class LMEvalPlugin(BasePlugin):
        return "axolotl.integrations.lm_eval.LMEvalArgs"

    def post_train_unload(self, cfg):
-        tasks = ",".join(cfg.lm_eval_tasks)
-        fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
-        dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
-        output_path = cfg.output_dir
-        output_path += "" if cfg.output_dir.endswith("/") else "/"
-        output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
-        subprocess.run(  # nosec
-            [
-                "lm_eval",
-                "--model",
-                "hf",
-                "--model_args",
-                f"pretrained={cfg.output_dir}{fa2}{dtype}",
-                "--tasks",
-                tasks,
-                "--batch_size",
-                str(cfg.lm_eval_batch_size),
-                "--output_path",
-                output_path,
-            ],
-            check=True,
-        )
+        if cfg.lm_eval_post_train:
+            # pylint: disable=duplicate-code
+            for lm_eval_args in build_lm_eval_command(
+                cfg.lm_eval_tasks,
+                bfloat16=cfg.bfloat16 or cfg.bf16,
+                flash_attention=cfg.flash_attention,
+                output_dir=cfg.output_dir,
+                batch_size=cfg.lm_eval_batch_size,
+                wandb_project=cfg.wandb_project,
+                wandb_entity=cfg.wandb_entity,
+                model=cfg.lm_eval_model or cfg.hub_model_id,
+            ):
+                subprocess.run(  # nosec
+                    lm_eval_args,
+                    check=True,
+                )
@@ -13,3 +13,5 @@ class LMEvalArgs(BaseModel):

    lm_eval_tasks: List[str] = []
    lm_eval_batch_size: Optional[int] = 8
+    lm_eval_post_train: Optional[bool] = True
+    lm_eval_model: Optional[str] = None
src/axolotl/integrations/lm_eval/cli.py (new file, 113 lines)
@@ -0,0 +1,113 @@
"""
axolotl CLI for running lm_eval tasks
"""
import subprocess  # nosec
from collections import defaultdict
from datetime import datetime
from typing import Optional

import click
import yaml

from axolotl.utils.dict import DictDefault


def build_lm_eval_command(
    tasks: list[str],
    bfloat16=True,
    flash_attention=False,
    output_dir="./",
    batch_size=8,
    wandb_project=None,
    wandb_entity=None,
    model=None,
    revision=None,
    apply_chat_template=None,
    fewshot_as_multiturn=None,
):
    tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
    for task in tasks:
        num_fewshot = "-1"
        task_parts = task.split(":")
        task_name = task_parts[0]
        if len(task_parts) == 2:
            task_name, num_fewshot = task_parts
        tasks_by_num_fewshot[str(num_fewshot)].append(task_name)

    for num_fewshot, tasks_list in tasks_by_num_fewshot.items():
        tasks_str = ",".join(tasks_list)
        num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
        pretrained = "pretrained="
        pretrained += model if model else output_dir
        fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
        dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
        revision = f",revision={revision}" if revision else ""
        output_path = output_dir
        output_path += "" if output_dir.endswith("/") else "/"
        output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
        lm_eval_args = [
            "lm_eval",
            "--model",
            "hf",
            "--model_args",
            f"{pretrained}{fa2}{dtype}{revision}",
            "--tasks",
            tasks_str,
            "--batch_size",
            str(batch_size),
            "--output_path",
            output_path,
        ]
        wandb_args = []
        if wandb_project:
            wandb_args.append(f"project={wandb_project}")
        if wandb_entity:
            wandb_args.append(f"entity={wandb_entity}")
        if wandb_args:
            lm_eval_args.append("--wandb_args")
            lm_eval_args.append(",".join(wandb_args))
        if apply_chat_template:
            lm_eval_args.append("--apply_chat_template")
        if num_fewshot_val:
            lm_eval_args.append("--num_fewshot")
            lm_eval_args.append(str(num_fewshot_val))
        if apply_chat_template and fewshot_as_multiturn:
            lm_eval_args.append("--fewshot_as_multiturn")

        yield lm_eval_args


@click.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
def lm_eval(config: str, cloud: Optional[str] = None):
    """
    use lm eval to evaluate a trained language model
    """

    if cloud:
        from axolotl.cli.cloud import do_cli_lm_eval

        do_cli_lm_eval(cloud_config=cloud, config=config)
    else:
        with open(config, encoding="utf-8") as file:
            cfg: DictDefault = DictDefault(yaml.safe_load(file))

        # pylint: disable=duplicate-code
        for lm_eval_args in build_lm_eval_command(
            cfg.lm_eval_tasks,
            bfloat16=cfg.bfloat16 or cfg.bf16,
            flash_attention=cfg.flash_attention,
            output_dir=cfg.output_dir,
            batch_size=cfg.lm_eval_batch_size,
            wandb_project=cfg.wandb_project,
            wandb_entity=cfg.wandb_entity,
            model=cfg.lm_eval_model or cfg.hub_model_id,
            revision=cfg.revision,
            apply_chat_template=cfg.apply_chat_template,
            fewshot_as_multiturn=cfg.fewshot_as_multiturn,
        ):
            subprocess.run(  # nosec
                lm_eval_args,
                check=True,
            )
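For a concrete sense of the generated invocation, a sketch feeding `build_lm_eval_command` the values from `lm_eval-kd.yaml` above (the output path embeds the current timestamp):

```python
from axolotl.integrations.lm_eval.cli import build_lm_eval_command

for args in build_lm_eval_command(
    ["leaderboard_math_hard"],
    bfloat16=True,
    flash_attention=True,
    output_dir="./outputs/model-evals-out",
    batch_size=64,
    wandb_project="numina-kd-experiment",
    wandb_entity="axolotl-ai",
    model="axolotl-ai-co/numina-8b-ep1-exp1",
):
    print(args)
# ['lm_eval', '--model', 'hf', '--model_args',
#  'pretrained=axolotl-ai-co/numina-8b-ep1-exp1,attn_implementation=flash_attention_2,dtype=bfloat16',
#  '--tasks', 'leaderboard_math_hard', '--batch_size', '64',
#  '--output_path', './outputs/model-evals-out/lm_eval_results/<timestamp>',
#  '--wandb_args', 'project=numina-kd-experiment,entity=axolotl-ai']
```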
@@ -25,7 +25,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
    "gemmoe",
    "starcoder2",
    "deepseek_v2",
-    "hymba",
]

@@ -31,7 +31,6 @@ _CHAT_TEMPLATES = {
"qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
"exaone": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]\n' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '\n' }}{% else %}{{ '[|endofturn|]\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}",
"metharme": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}",
"hymba": "{{'<extra_id_0>System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n<tool> ' + tool|tojson + ' </tool>' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n<context> ' + context.strip() + ' </context>' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'<extra_id_1>Assistant\n'}}{%- endif %}",
|
|
||||||
}
@@ -1629,19 +1629,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
        else:
            data["torch_compile"] = False
        return data
-
-    @model_validator(mode="before")
-    @classmethod
-    def check_hymba_torch_version(cls, data):
-        if "hymba" in data.get("base_model", {}).lower():
-            env_capabilities = data.get("env_capabilities", {})
-            torch_version = env_capabilities.get("torch_version")
-
-            if torch_version is None:
-                import torch
-
-                torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
-
-            if version.parse(torch_version) < version.parse("2.5.0"):
-                raise ValueError("Hymba requires torch version >= 2.5")
-            return data
@@ -409,7 +409,6 @@ class ModelLoader:
            and self.cfg.sample_packing
        ):
            if "auto_map" in self.model_config:
-                # some model config objects are not subscriptable
                try:
                    auto_map_config = self.model_config["auto_map"]
                except TypeError:
@@ -67,8 +67,8 @@ class TestCustomOptimizers(unittest.TestCase):
        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "adapter_model.bin").exists()

-    @require_torch_2_5_1
    @with_temp_dir
+    @require_torch_2_5_1
    def test_adopt_adamw(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
@@ -14,7 +14,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

-from .utils import check_tensorboard, require_torch_2_5_1, with_temp_dir
+from .utils import check_tensorboard, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -68,129 +68,3 @@ class TestPackedLlama(unittest.TestCase):
        check_tensorboard(
            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
        )
-
-
-class TestUnpackedHymba(unittest.TestCase):
-    """
-    Test case for Unpacked training of hymba models
-    """
-
-    @require_torch_2_5_1
-    @with_temp_dir
-    def test_loss_unpacked(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "nvidia/Hymba-1.5B-Base",
-                "trust_remote_code": True,
-                "load_in_4bit": True,
-                "adapter": "qlora",
-                "lora_r": 32,
-                "lora_alpha": 16,
-                "lora_dropout": 0.05,
-                "lora_target_modules": [
-                    "gate_proj",
-                    "down_proj",
-                    "up_proj",
-                    "q_proj",
-                    "v_proj",
-                    "k_proj",
-                    "o_proj",
-                ],
-                "sequence_len": 1024,
-                "sample_packing": False,
-                "flash_attention": True,
-                "val_set_size": 0.0,
-                "datasets": [
-                    {
-                        "path": "vicgalle/alpaca-gpt4",
-                        "type": "alpaca",
-                    },
-                ],
-                "num_epochs": 1,
-                "micro_batch_size": 2,
-                "gradient_accumulation_steps": 4,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
-                "lr_scheduler": "cosine",
-                "max_steps": 5,
-                "use_tensorboard": True,
-            }
-        )
-        if is_torch_bf16_gpu_available():
-            cfg.bf16 = True
-        else:
-            cfg.fp16 = True
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-
-        check_tensorboard(
-            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
-        )
-
-
-class TestPackedHymba(unittest.TestCase):
-    """
-    Test case for Packed training of hymba models
-    """
-
-    @require_torch_2_5_1
-    @with_temp_dir
-    def test_loss_packed(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "nvidia/Hymba-1.5B-Base",
-                "trust_remote_code": True,
-                "load_in_4bit": True,
-                "adapter": "qlora",
-                "lora_r": 32,
-                "lora_alpha": 16,
-                "lora_dropout": 0.05,
-                "lora_target_modules": [
-                    "gate_proj",
-                    "down_proj",
-                    "up_proj",
-                    "q_proj",
-                    "v_proj",
-                    "k_proj",
-                    "o_proj",
-                ],
-                "sequence_len": 1024,
-                "sample_packing": True,
-                "flash_attention": True,
-                "val_set_size": 0.0,
-                "datasets": [
-                    {
-                        "path": "vicgalle/alpaca-gpt4",
-                        "type": "alpaca",
-                    },
-                ],
-                "num_epochs": 1,
-                "micro_batch_size": 2,
-                "gradient_accumulation_steps": 4,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
-                "lr_scheduler": "cosine",
-                "max_steps": 5,
-                "use_tensorboard": True,
-            }
-        )
-        if is_torch_bf16_gpu_available():
-            cfg.bf16 = True
-        else:
-            cfg.fp16 = True
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-
-        check_tensorboard(
-            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
-        )