Compare commits
12 Commits
mm2
...
feature/en
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d1f36d7b78 | ||
|
|
87248027d0 | ||
|
|
d0d22b7812 | ||
|
|
68db5b1b67 | ||
|
|
2fbc6b0c64 | ||
|
|
8159cbd1ab | ||
|
|
979534c851 | ||
|
|
6d3caadf90 | ||
|
|
dee77232fe | ||
|
|
a560593b1d | ||
|
|
e8d3da0081 | ||
|
|
4ca0a47cfb |
8
.github/workflows/base.yml
vendored
8
.github/workflows/base.yml
vendored
@@ -28,7 +28,13 @@ jobs:
|
|||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.0
|
pytorch: 2.4.1
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
- cuda: "124"
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.4.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
4
.github/workflows/main.yml
vendored
4
.github/workflows/main.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.0
|
pytorch: 2.4.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
@@ -84,7 +84,7 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.0
|
pytorch: 2.4.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
4
.github/workflows/nightlies.yml
vendored
4
.github/workflows/nightlies.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.0
|
pytorch: 2.4.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
@@ -83,7 +83,7 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.0
|
pytorch: 2.4.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
4
.github/workflows/tests-nightly.yml
vendored
4
.github/workflows/tests-nightly.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.10", "3.11"]
|
python_version: ["3.10", "3.11"]
|
||||||
pytorch_version: ["2.3.1", "2.4.0"]
|
pytorch_version: ["2.3.1", "2.4.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -91,7 +91,7 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.0
|
pytorch: 2.4.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
|
|||||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -36,7 +36,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.10", "3.11"]
|
python_version: ["3.10", "3.11"]
|
||||||
pytorch_version: ["2.3.1", "2.4.0"]
|
pytorch_version: ["2.3.1", "2.4.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -94,7 +94,7 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.0
|
pytorch: 2.4.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
[settings]
|
[settings]
|
||||||
profile=black
|
profile=black
|
||||||
known_third_party=wandb
|
known_third_party=wandb,comet_ml
|
||||||
|
|||||||
18
README.md
18
README.md
@@ -14,7 +14,7 @@ Features:
|
|||||||
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
|
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
|
||||||
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
|
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
|
||||||
- Easily run with Docker locally or on the cloud
|
- Easily run with Docker locally or on the cloud
|
||||||
- Log results and optionally checkpoints to wandb or mlflow
|
- Log results and optionally checkpoints to wandb, mlflow or Comet
|
||||||
- And more!
|
- And more!
|
||||||
|
|
||||||
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
|
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
|
||||||
@@ -515,6 +515,22 @@ wandb_name:
|
|||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
```
|
```
|
||||||
|
|
||||||
|
##### Comet Logging
|
||||||
|
|
||||||
|
Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
|
||||||
|
|
||||||
|
- wandb options
|
||||||
|
```yaml
|
||||||
|
use_comet:
|
||||||
|
comet_api_key:
|
||||||
|
comet_workspace:
|
||||||
|
comet_project_name:
|
||||||
|
comet_experiment_key:
|
||||||
|
comet_mode:
|
||||||
|
comet_online:
|
||||||
|
comet_experiment_config:
|
||||||
|
```
|
||||||
|
|
||||||
##### Special Tokens
|
##### Special Tokens
|
||||||
|
|
||||||
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
|
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ datasets:
|
|||||||
shards: # Optional[int] number of shards to split data into
|
shards: # Optional[int] number of shards to split data into
|
||||||
name: # Optional[str] name of dataset configuration to load
|
name: # Optional[str] name of dataset configuration to load
|
||||||
train_on_split: train # Optional[str] name of dataset split to load from
|
train_on_split: train # Optional[str] name of dataset split to load from
|
||||||
|
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
||||||
|
|
||||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
@@ -267,6 +268,18 @@ mlflow_tracking_uri: # URI to mlflow
|
|||||||
mlflow_experiment_name: # Your experiment name
|
mlflow_experiment_name: # Your experiment name
|
||||||
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
||||||
|
|
||||||
|
# Comet configuration if you're using it
|
||||||
|
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
|
||||||
|
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
|
||||||
|
use_comet: # Enable or disable Comet integration.
|
||||||
|
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
|
||||||
|
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
|
||||||
|
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
|
||||||
|
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
|
||||||
|
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
|
||||||
|
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
|
||||||
|
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
|
||||||
|
|
||||||
# Where to save the full-finetuned model to
|
# Where to save the full-finetuned model to
|
||||||
output_dir: ./completed-model
|
output_dir: ./completed-model
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ flash-attn==2.6.3
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
xformers==0.0.27
|
xformers==0.0.28.post1
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
colorama
|
colorama
|
||||||
@@ -46,3 +46,9 @@ gcsfs>=2024.5.0
|
|||||||
trl==0.9.6
|
trl==0.9.6
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|
||||||
|
# lm eval harness
|
||||||
|
lm_eval==0.4.4
|
||||||
|
langdetect==1.0.9
|
||||||
|
immutabledict==4.2.0
|
||||||
|
antlr4-python3-runtime==4.13.2
|
||||||
|
|||||||
7
setup.py
7
setup.py
@@ -49,10 +49,17 @@ def parse_requirements():
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid version format")
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
|
if (major, minor) >= (2, 4):
|
||||||
|
if patch == 0:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.27")
|
||||||
if (major, minor) >= (2, 3):
|
if (major, minor) >= (2, 3):
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.26.post1")
|
_install_requires.append("xformers>=0.0.26.post1")
|
||||||
|
else:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.27")
|
||||||
elif (major, minor) >= (2, 2):
|
elif (major, minor) >= (2, 2):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.25.post1")
|
_install_requires.append("xformers>=0.0.25.post1")
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from axolotl.integrations.base import PluginManager
|
|||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
|
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
normalize_config,
|
normalize_config,
|
||||||
@@ -54,8 +55,22 @@ LOG = logging.getLogger("axolotl.scripts")
|
|||||||
|
|
||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
|
|
||||||
|
AXOLOTL_LOGO = """
|
||||||
|
#@@ #@@ @@# @@#
|
||||||
|
@@ @@ @@ @@ =@@# @@ #@ =@@#.
|
||||||
|
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
|
||||||
|
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
|
||||||
|
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
|
||||||
|
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||||
|
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
|
||||||
|
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||||
|
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
|
||||||
|
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
|
||||||
|
@@@@ @@@@@@@@@@@@@@@@
|
||||||
|
"""
|
||||||
|
|
||||||
def print_axolotl_text_art(suffix=None):
|
|
||||||
|
def print_legacy_axolotl_text_art(suffix=None):
|
||||||
font = "nancyj"
|
font = "nancyj"
|
||||||
ascii_text = " axolotl"
|
ascii_text = " axolotl"
|
||||||
if suffix:
|
if suffix:
|
||||||
@@ -68,6 +83,13 @@ def print_axolotl_text_art(suffix=None):
|
|||||||
print_dep_versions()
|
print_dep_versions()
|
||||||
|
|
||||||
|
|
||||||
|
def print_axolotl_text_art(
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
if is_main_process():
|
||||||
|
print(AXOLOTL_LOGO)
|
||||||
|
|
||||||
|
|
||||||
def print_dep_versions():
|
def print_dep_versions():
|
||||||
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
|
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
|
||||||
max_len = max(len(pkg) for pkg in packages)
|
max_len = max(len(pkg) for pkg in packages)
|
||||||
@@ -421,6 +443,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
setup_mlflow_env_vars(cfg)
|
setup_mlflow_env_vars(cfg)
|
||||||
|
|
||||||
|
setup_comet_env_vars(cfg)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,13 +3,11 @@ CLI to run training on a model
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple, Union
|
from typing import Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
check_accelerate_default_config,
|
check_accelerate_default_config,
|
||||||
@@ -20,6 +18,7 @@ from axolotl.cli import (
|
|||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.prompt_strategies.sharegpt import (
|
from axolotl.prompt_strategies.sharegpt import (
|
||||||
register_chatml_template,
|
register_chatml_template,
|
||||||
register_llama3_template,
|
register_llama3_template,
|
||||||
@@ -39,7 +38,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
return do_train(parsed_cfg, parsed_cli_args)
|
return do_train(parsed_cfg, parsed_cli_args)
|
||||||
|
|
||||||
|
|
||||||
def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
def do_train(cfg, cli_args) -> None:
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
@@ -64,7 +63,13 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
|
||||||
|
del model
|
||||||
|
del tokenizer
|
||||||
|
|
||||||
|
plugin_manager.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ from trl.trainer.utils import pad_to_length
|
|||||||
|
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||||
from axolotl.utils import is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
@@ -1111,6 +1111,12 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
if self.cfg.use_comet and is_comet_available():
|
||||||
|
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
||||||
|
|
||||||
|
callbacks.append(
|
||||||
|
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||||
|
)
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
@@ -1179,6 +1185,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer, self.tokenizer, "mlflow"
|
trainer, self.tokenizer, "mlflow"
|
||||||
)
|
)
|
||||||
callbacks.append(LogPredictionCallback(self.cfg))
|
callbacks.append(LogPredictionCallback(self.cfg))
|
||||||
|
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
|
||||||
|
LogPredictionCallback = log_prediction_callback_factory(
|
||||||
|
trainer, self.tokenizer, "comet_ml"
|
||||||
|
)
|
||||||
|
callbacks.append(LogPredictionCallback(self.cfg))
|
||||||
|
|
||||||
if self.cfg.do_bench_eval:
|
if self.cfg.do_bench_eval:
|
||||||
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
||||||
@@ -1430,6 +1441,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
report_to.append("mlflow")
|
report_to.append("mlflow")
|
||||||
if self.cfg.use_tensorboard:
|
if self.cfg.use_tensorboard:
|
||||||
report_to.append("tensorboard")
|
report_to.append("tensorboard")
|
||||||
|
if self.cfg.use_comet:
|
||||||
|
report_to.append("comet_ml")
|
||||||
|
|
||||||
training_arguments_kwargs["report_to"] = report_to
|
training_arguments_kwargs["report_to"] = report_to
|
||||||
training_arguments_kwargs["run_name"] = (
|
training_arguments_kwargs["run_name"] = (
|
||||||
|
|||||||
@@ -159,6 +159,29 @@ class BasePlugin:
|
|||||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def post_train(self, cfg, model):
|
||||||
|
"""
|
||||||
|
Performs actions after training is complete.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
cfg (dict): The axolotl configuration
|
||||||
|
model (object): The loaded model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
def post_train_unload(self, cfg):
|
||||||
|
"""
|
||||||
|
Performs actions after training is complete and the model is unloaded.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
cfg (dict): The configuration for the plugin.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def load_plugin(plugin_name: str) -> BasePlugin:
|
def load_plugin(plugin_name: str) -> BasePlugin:
|
||||||
"""
|
"""
|
||||||
@@ -381,3 +404,17 @@ class PluginManager:
|
|||||||
for plugin in self.plugins:
|
for plugin in self.plugins:
|
||||||
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
|
def post_train_unload(self, cfg):
|
||||||
|
"""
|
||||||
|
Calls the post_train_unload method of all registered plugins.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
cfg (dict): The configuration for the plugins.
|
||||||
|
model (object): The loaded model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
for plugin in self.plugins:
|
||||||
|
plugin.post_train_unload(cfg)
|
||||||
|
|||||||
13
src/axolotl/integrations/lm_eval/README.md
Normal file
13
src/axolotl/integrations/lm_eval/README.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# LM Eval Harness
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.lm_eval.LMEvalPlugin
|
||||||
|
|
||||||
|
lm_eval_tasks:
|
||||||
|
- gsm8k
|
||||||
|
- hellaswag
|
||||||
|
- arc_easy
|
||||||
|
```
|
||||||
42
src/axolotl/integrations/lm_eval/__init__.py
Normal file
42
src/axolotl/integrations/lm_eval/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""
|
||||||
|
Module for the Plugin for LM Eval Harness
|
||||||
|
"""
|
||||||
|
import subprocess # nosec
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
|
||||||
|
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
class LMEvalPlugin(BasePlugin):
|
||||||
|
"""
|
||||||
|
Plugin for LM Evaluation Harness integraton with Axolotl.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_input_args(self):
|
||||||
|
return "axolotl.integrations.lm_eval.LMEvalArgs"
|
||||||
|
|
||||||
|
def post_train_unload(self, cfg):
|
||||||
|
tasks = ",".join(cfg.lm_eval_tasks)
|
||||||
|
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
|
||||||
|
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
|
||||||
|
output_path = cfg.output_dir
|
||||||
|
output_path += "" if cfg.output_dir.endswith("/") else "/"
|
||||||
|
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
subprocess.run( # nosec
|
||||||
|
[
|
||||||
|
"lm_eval",
|
||||||
|
"--model",
|
||||||
|
"hf",
|
||||||
|
"--model_args",
|
||||||
|
f"pretrained={cfg.output_dir}{fa2}{dtype}",
|
||||||
|
"--tasks",
|
||||||
|
tasks,
|
||||||
|
"--batch_size",
|
||||||
|
str(cfg.lm_eval_batch_size),
|
||||||
|
"--output_path",
|
||||||
|
output_path,
|
||||||
|
],
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
15
src/axolotl/integrations/lm_eval/args.py
Normal file
15
src/axolotl/integrations/lm_eval/args.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
Module for handling lm eval harness input arguments.
|
||||||
|
"""
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class LMEvalArgs(BaseModel):
|
||||||
|
"""
|
||||||
|
Input args for lm eval harness
|
||||||
|
"""
|
||||||
|
|
||||||
|
lm_eval_tasks: List[str] = []
|
||||||
|
lm_eval_batch_size: Optional[int] = 8
|
||||||
@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
|
|||||||
def reset_optimizer(
|
def reset_optimizer(
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
*,
|
*,
|
||||||
reset_params: list[str], # where str is the key to a torch.nn.Parameter
|
reset_params: List[str], # where str is the key to a torch.nn.Parameter
|
||||||
optimizer_state_keys: list[str],
|
optimizer_state_keys: List[str],
|
||||||
prune_ratio: float = 0.9,
|
prune_ratio: float = 0.9,
|
||||||
):
|
):
|
||||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
Basic utils for Axolotl
|
Basic utils for Axolotl
|
||||||
"""
|
"""
|
||||||
import importlib
|
import importlib.util
|
||||||
|
|
||||||
|
|
||||||
def is_mlflow_available():
|
def is_mlflow_available():
|
||||||
return importlib.util.find_spec("mlflow") is not None
|
return importlib.util.find_spec("mlflow") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_comet_available():
|
||||||
|
return importlib.util.find_spec("comet_ml") is not None
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||||
|
|
||||||
from axolotl.utils import is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
@@ -462,7 +462,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
|||||||
references=[[r] for r in references],
|
references=[[r] for r in references],
|
||||||
predictions=predictions,
|
predictions=predictions,
|
||||||
)
|
)
|
||||||
scores[metric_name] = score
|
scores["eval_" + metric_name] = score
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
def predict_with_generate():
|
def predict_with_generate():
|
||||||
@@ -747,6 +747,15 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
|
|||||||
artifact_file="PredictionsVsGroundTruth.json",
|
artifact_file="PredictionsVsGroundTruth.json",
|
||||||
tracking_uri=tracking_uri,
|
tracking_uri=tracking_uri,
|
||||||
)
|
)
|
||||||
|
elif logger == "comet_ml" and is_comet_available():
|
||||||
|
import comet_ml
|
||||||
|
|
||||||
|
experiment = comet_ml.get_running_experiment()
|
||||||
|
if experiment:
|
||||||
|
experiment.log_table(
|
||||||
|
f"{name} - Predictions vs Ground Truth.csv",
|
||||||
|
pd.DataFrame(table_data),
|
||||||
|
)
|
||||||
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
log_table_from_dataloader("Eval", eval_dataloader)
|
log_table_from_dataloader("Eval", eval_dataloader)
|
||||||
|
|||||||
43
src/axolotl/utils/callbacks/comet_.py
Normal file
43
src/axolotl/utils/callbacks/comet_.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""Comet module for trainer callbacks"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import comet_ml
|
||||||
|
from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.callbacks")
|
||||||
|
|
||||||
|
|
||||||
|
class SaveAxolotlConfigtoCometCallback(TrainerCallback):
|
||||||
|
"""Callback to save axolotl config to comet"""
|
||||||
|
|
||||||
|
def __init__(self, axolotl_config_path):
|
||||||
|
self.axolotl_config_path = axolotl_config_path
|
||||||
|
|
||||||
|
def on_train_begin(
|
||||||
|
self,
|
||||||
|
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
|
||||||
|
state: TrainerState, # pylint: disable=unused-argument
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
if is_main_process():
|
||||||
|
try:
|
||||||
|
comet_experiment = comet_ml.start(source="axolotl")
|
||||||
|
comet_experiment.log_other("Created from", "axolotl")
|
||||||
|
comet_experiment.log_asset(
|
||||||
|
self.axolotl_config_path,
|
||||||
|
file_name="axolotl-config",
|
||||||
|
)
|
||||||
|
LOG.info(
|
||||||
|
"The Axolotl config has been saved to the Comet Experiment under assets."
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
|
LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
|
||||||
|
return control
|
||||||
@@ -5,7 +5,9 @@ These templates are used for formatting messages in a conversation.
|
|||||||
|
|
||||||
CHAT_TEMPLATES = {
|
CHAT_TEMPLATES = {
|
||||||
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
"mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1...
|
||||||
|
"mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large...
|
||||||
|
"mistral_v3_tekken": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3-Tekken: Nemo, Pixtral...
|
||||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||||
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
||||||
|
|||||||
93
src/axolotl/utils/comet_.py
Normal file
93
src/axolotl/utils/comet_.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""Module for wandb utilities"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.utils.comet_")
|
||||||
|
|
||||||
|
COMET_ENV_MAPPING_OVERRIDE = {
|
||||||
|
"comet_mode": "COMET_START_MODE",
|
||||||
|
"comet_online": "COMET_START_ONLINE",
|
||||||
|
}
|
||||||
|
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
|
||||||
|
"auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
|
||||||
|
"auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
|
||||||
|
"auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
|
||||||
|
"auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
|
||||||
|
"auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
|
||||||
|
"auto_log_co2": "COMET_AUTO_LOG_CO2",
|
||||||
|
"auto_metric_logging": "COMET_AUTO_LOG_METRICS",
|
||||||
|
"auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
|
||||||
|
"auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
|
||||||
|
"auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
|
||||||
|
"comet_disabled": "COMET_AUTO_LOG_DISABLE",
|
||||||
|
"display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
|
||||||
|
"distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
|
||||||
|
"log_code": "COMET_AUTO_LOG_CODE",
|
||||||
|
"log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
|
||||||
|
"log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
|
||||||
|
"log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
|
||||||
|
"log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
|
||||||
|
"log_env_host": "COMET_AUTO_LOG_ENV_HOST",
|
||||||
|
"log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
|
||||||
|
"log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
|
||||||
|
"log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
|
||||||
|
"log_graph": "COMET_AUTO_LOG_GRAPH",
|
||||||
|
"name": "COMET_START_EXPERIMENT_NAME",
|
||||||
|
"offline_directory": "COMET_OFFLINE_DIRECTORY",
|
||||||
|
"parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
|
||||||
|
"tags": "COMET_START_EXPERIMENT_TAGS",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def python_value_to_environ_value(python_value):
|
||||||
|
if isinstance(python_value, bool):
|
||||||
|
if python_value is True:
|
||||||
|
return "true"
|
||||||
|
|
||||||
|
return "false"
|
||||||
|
|
||||||
|
if isinstance(python_value, int):
|
||||||
|
return str(python_value)
|
||||||
|
|
||||||
|
if isinstance(python_value, list): # Comet only have one list of string parameter
|
||||||
|
return ",".join(map(str, python_value))
|
||||||
|
|
||||||
|
return python_value
|
||||||
|
|
||||||
|
|
||||||
|
def setup_comet_env_vars(cfg: DictDefault):
|
||||||
|
# TODO, we need to convert Axolotl configuration to environment variables
|
||||||
|
# as Transformers integration are call first and would create an
|
||||||
|
# Experiment first
|
||||||
|
|
||||||
|
for key in cfg.keys():
|
||||||
|
if key.startswith("comet_") and key != "comet_experiment_config":
|
||||||
|
value = cfg.get(key, "")
|
||||||
|
|
||||||
|
if value is not None and value != "":
|
||||||
|
env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
|
||||||
|
final_value = python_value_to_environ_value(value)
|
||||||
|
os.environ[env_variable_name] = final_value
|
||||||
|
|
||||||
|
if cfg.comet_experiment_config:
|
||||||
|
for key, value in cfg.comet_experiment_config.items():
|
||||||
|
if value is not None and value != "":
|
||||||
|
config_env_variable_name = (
|
||||||
|
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
|
||||||
|
)
|
||||||
|
|
||||||
|
if config_env_variable_name is None:
|
||||||
|
LOG.warning(
|
||||||
|
f"Unknown Comet Experiment Config name {key}, ignoring it"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
final_value = python_value_to_environ_value(value)
|
||||||
|
os.environ[config_env_variable_name] = final_value
|
||||||
|
|
||||||
|
# Enable comet if project name is present
|
||||||
|
if cfg.comet_project_name and len(cfg.comet_project_name) > 0:
|
||||||
|
cfg.use_comet = True
|
||||||
@@ -489,6 +489,19 @@ class WandbConfig(BaseModel):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class CometConfig(BaseModel):
|
||||||
|
"""Comet configuration subset"""
|
||||||
|
|
||||||
|
use_comet: Optional[bool] = None
|
||||||
|
comet_api_key: Optional[str] = None
|
||||||
|
comet_workspace: Optional[str] = None
|
||||||
|
comet_project_name: Optional[str] = None
|
||||||
|
comet_experiment_key: Optional[str] = None
|
||||||
|
comet_mode: Optional[str] = None
|
||||||
|
comet_online: Optional[bool] = None
|
||||||
|
comet_experiment_config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
class GradioConfig(BaseModel):
|
class GradioConfig(BaseModel):
|
||||||
"""Gradio configuration subset"""
|
"""Gradio configuration subset"""
|
||||||
|
|
||||||
@@ -509,6 +522,7 @@ class AxolotlInputConfig(
|
|||||||
HyperparametersConfig,
|
HyperparametersConfig,
|
||||||
WandbConfig,
|
WandbConfig,
|
||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
|
CometConfig,
|
||||||
LISAConfig,
|
LISAConfig,
|
||||||
GradioConfig,
|
GradioConfig,
|
||||||
RemappedParameters,
|
RemappedParameters,
|
||||||
@@ -966,6 +980,26 @@ class AxolotlInputConfig(
|
|||||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if data.get("do_bench_eval") and not (
|
||||||
|
data.get("evals_per_epoch") or data.get("eval_steps")
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"do_bench_eval requires evals_per_epoch or eval_steps to be set."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_test_datasets_bench(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("do_bench_eval")
|
||||||
|
and not data.get("test_datasets")
|
||||||
|
and not data.get("val_set_size")
|
||||||
|
):
|
||||||
|
LOG.warning(
|
||||||
|
"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
|
||||||
|
)
|
||||||
|
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
|||||||
@@ -242,6 +242,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
streaming=True,
|
streaming=True,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
|
revision=config_dataset.revision,
|
||||||
)
|
)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
||||||
@@ -346,6 +347,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
streaming=False,
|
streaming=False,
|
||||||
data_files=config_dataset.data_files,
|
data_files=config_dataset.data_files,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
|
revision=config_dataset.revision,
|
||||||
**load_ds_kwargs,
|
**load_ds_kwargs,
|
||||||
)
|
)
|
||||||
elif ds_from_cloud and remote_file_system:
|
elif ds_from_cloud and remote_file_system:
|
||||||
@@ -380,6 +382,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
repo_id=config_dataset.path,
|
repo_id=config_dataset.path,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
filename=config_dataset.data_files,
|
filename=config_dataset.data_files,
|
||||||
|
revision=config_dataset.revision,
|
||||||
)
|
)
|
||||||
elif isinstance(config_dataset.data_files, list):
|
elif isinstance(config_dataset.data_files, list):
|
||||||
fp = []
|
fp = []
|
||||||
@@ -389,6 +392,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
repo_id=config_dataset.path,
|
repo_id=config_dataset.path,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
filename=file,
|
filename=file,
|
||||||
|
revision=config_dataset.revision,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -433,8 +437,8 @@ def load_tokenized_prepared_datasets(
|
|||||||
config_dataset=config_dataset,
|
config_dataset=config_dataset,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
dataset=ds,
|
|
||||||
d_base_type=d_base_type,
|
d_base_type=d_base_type,
|
||||||
|
dataset=ds,
|
||||||
d_prompt_style=d_prompt_style,
|
d_prompt_style=d_prompt_style,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -267,6 +267,74 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
|
def test_load_hub_with_revision(self):
|
||||||
|
"""Verify that processing data from the hub works with a specific revision"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
"revision": "d05c1cb",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset, _ = load_tokenized_prepared_datasets(
|
||||||
|
self.tokenizer, cfg, prepared_path
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(dataset) == 2000
|
||||||
|
assert "input_ids" in dataset.features
|
||||||
|
assert "attention_mask" in dataset.features
|
||||||
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
|
def test_load_local_hub_with_revision(self):
|
||||||
|
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
|
||||||
|
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="mhenrichsen/alpaca_2k_test",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=tmp_ds_path,
|
||||||
|
revision="d05c1cb",
|
||||||
|
)
|
||||||
|
|
||||||
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"ds_type": "parquet",
|
||||||
|
"type": "alpaca",
|
||||||
|
"data_files": [
|
||||||
|
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
|
||||||
|
],
|
||||||
|
"revision": "d05c1cb",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset, _ = load_tokenized_prepared_datasets(
|
||||||
|
self.tokenizer, cfg, prepared_path
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(dataset) == 2000
|
||||||
|
assert "input_ids" in dataset.features
|
||||||
|
assert "attention_mask" in dataset.features
|
||||||
|
assert "labels" in dataset.features
|
||||||
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from typing import Optional
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from axolotl.utils import is_comet_available
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -1329,3 +1330,105 @@ class TestValidationWandb(BaseValidation):
|
|||||||
|
|
||||||
os.environ.pop("WANDB_PROJECT", None)
|
os.environ.pop("WANDB_PROJECT", None)
|
||||||
os.environ.pop("WANDB_DISABLED", None)
|
os.environ.pop("WANDB_DISABLED", None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed")
|
||||||
|
class TestValidationComet(BaseValidation):
|
||||||
|
"""
|
||||||
|
Validation test for comet
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_comet_sets_env(self, minimal_cfg):
|
||||||
|
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||||
|
|
||||||
|
comet_config = {
|
||||||
|
"comet_api_key": "foo",
|
||||||
|
"comet_workspace": "some_workspace",
|
||||||
|
"comet_project_name": "some_project",
|
||||||
|
"comet_experiment_key": "some_experiment_key",
|
||||||
|
"comet_mode": "get_or_create",
|
||||||
|
"comet_online": False,
|
||||||
|
"comet_experiment_config": {
|
||||||
|
"auto_histogram_activation_logging": False,
|
||||||
|
"auto_histogram_epoch_rate": 2,
|
||||||
|
"auto_histogram_gradient_logging": True,
|
||||||
|
"auto_histogram_tensorboard_logging": False,
|
||||||
|
"auto_histogram_weight_logging": True,
|
||||||
|
"auto_log_co2": False,
|
||||||
|
"auto_metric_logging": True,
|
||||||
|
"auto_metric_step_rate": 15,
|
||||||
|
"auto_output_logging": False,
|
||||||
|
"auto_param_logging": True,
|
||||||
|
"comet_disabled": False,
|
||||||
|
"display_summary_level": 2,
|
||||||
|
"distributed_node_identifier": "some_distributed_node_identifier",
|
||||||
|
"log_code": True,
|
||||||
|
"log_env_cpu": False,
|
||||||
|
"log_env_details": True,
|
||||||
|
"log_env_disk": False,
|
||||||
|
"log_env_gpu": True,
|
||||||
|
"log_env_host": False,
|
||||||
|
"log_env_network": True,
|
||||||
|
"log_git_metadata": False,
|
||||||
|
"log_git_patch": True,
|
||||||
|
"log_graph": False,
|
||||||
|
"name": "some_name",
|
||||||
|
"offline_directory": "some_offline_directory",
|
||||||
|
"parse_args": True,
|
||||||
|
"tags": ["tag1", "tag2"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg = DictDefault(comet_config) | minimal_cfg
|
||||||
|
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
|
||||||
|
setup_comet_env_vars(new_cfg)
|
||||||
|
|
||||||
|
comet_env = {
|
||||||
|
key: value for key, value in os.environ.items() if key.startswith("COMET_")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert (
|
||||||
|
len(comet_env)
|
||||||
|
== len(comet_config) + len(comet_config["comet_experiment_config"]) - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
assert comet_env == {
|
||||||
|
"COMET_API_KEY": "foo",
|
||||||
|
"COMET_AUTO_LOG_CLI_ARGUMENTS": "true",
|
||||||
|
"COMET_AUTO_LOG_CO2": "false",
|
||||||
|
"COMET_AUTO_LOG_CODE": "true",
|
||||||
|
"COMET_AUTO_LOG_DISABLE": "false",
|
||||||
|
"COMET_AUTO_LOG_ENV_CPU": "false",
|
||||||
|
"COMET_AUTO_LOG_ENV_DETAILS": "true",
|
||||||
|
"COMET_AUTO_LOG_ENV_DISK": "false",
|
||||||
|
"COMET_AUTO_LOG_ENV_GPU": "true",
|
||||||
|
"COMET_AUTO_LOG_ENV_HOST": "false",
|
||||||
|
"COMET_AUTO_LOG_ENV_NETWORK": "true",
|
||||||
|
"COMET_AUTO_LOG_GIT_METADATA": "false",
|
||||||
|
"COMET_AUTO_LOG_GIT_PATCH": "true",
|
||||||
|
"COMET_AUTO_LOG_GRAPH": "false",
|
||||||
|
"COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS": "false",
|
||||||
|
"COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE": "2",
|
||||||
|
"COMET_AUTO_LOG_HISTOGRAM_GRADIENTS": "true",
|
||||||
|
"COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD": "false",
|
||||||
|
"COMET_AUTO_LOG_HISTOGRAM_WEIGHTS": "true",
|
||||||
|
"COMET_AUTO_LOG_METRIC_STEP_RATE": "15",
|
||||||
|
"COMET_AUTO_LOG_METRICS": "true",
|
||||||
|
"COMET_AUTO_LOG_OUTPUT_LOGGER": "false",
|
||||||
|
"COMET_AUTO_LOG_PARAMETERS": "true",
|
||||||
|
"COMET_DISPLAY_SUMMARY_LEVEL": "2",
|
||||||
|
"COMET_DISTRIBUTED_NODE_IDENTIFIER": "some_distributed_node_identifier",
|
||||||
|
"COMET_EXPERIMENT_KEY": "some_experiment_key",
|
||||||
|
"COMET_OFFLINE_DIRECTORY": "some_offline_directory",
|
||||||
|
"COMET_PROJECT_NAME": "some_project",
|
||||||
|
"COMET_START_EXPERIMENT_NAME": "some_name",
|
||||||
|
"COMET_START_EXPERIMENT_TAGS": "tag1,tag2",
|
||||||
|
"COMET_START_MODE": "get_or_create",
|
||||||
|
"COMET_START_ONLINE": "false",
|
||||||
|
"COMET_WORKSPACE": "some_workspace",
|
||||||
|
}
|
||||||
|
|
||||||
|
for key in comet_env.keys():
|
||||||
|
os.environ.pop(key, None)
|
||||||
|
|||||||
Reference in New Issue
Block a user