Compare commits
2 Commits
feature/en
...
mm3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cdd8be7097 | ||
|
|
08143c7b0d |
8
.github/workflows/base.yml
vendored
8
.github/workflows/base.yml
vendored
@@ -28,13 +28,7 @@ jobs:
|
|||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
- cuda: "124"
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.4.1
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
4
.github/workflows/main.yml
vendored
4
.github/workflows/main.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
@@ -84,7 +84,7 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
4
.github/workflows/nightlies.yml
vendored
4
.github/workflows/nightlies.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
@@ -83,7 +83,7 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
4
.github/workflows/tests-nightly.yml
vendored
4
.github/workflows/tests-nightly.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.10", "3.11"]
|
python_version: ["3.10", "3.11"]
|
||||||
pytorch_version: ["2.3.1", "2.4.1"]
|
pytorch_version: ["2.3.1", "2.4.0"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -91,7 +91,7 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
|
|||||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -36,7 +36,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.10", "3.11"]
|
python_version: ["3.10", "3.11"]
|
||||||
pytorch_version: ["2.3.1", "2.4.1"]
|
pytorch_version: ["2.3.1", "2.4.0"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -94,7 +94,7 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
[settings]
|
[settings]
|
||||||
profile=black
|
profile=black
|
||||||
known_third_party=wandb,comet_ml
|
known_third_party=wandb
|
||||||
|
|||||||
18
README.md
18
README.md
@@ -14,7 +14,7 @@ Features:
|
|||||||
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
|
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
|
||||||
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
|
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
|
||||||
- Easily run with Docker locally or on the cloud
|
- Easily run with Docker locally or on the cloud
|
||||||
- Log results and optionally checkpoints to wandb, mlflow or Comet
|
- Log results and optionally checkpoints to wandb or mlflow
|
||||||
- And more!
|
- And more!
|
||||||
|
|
||||||
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
|
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
|
||||||
@@ -515,22 +515,6 @@ wandb_name:
|
|||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
```
|
```
|
||||||
|
|
||||||
##### Comet Logging
|
|
||||||
|
|
||||||
Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
|
|
||||||
|
|
||||||
- wandb options
|
|
||||||
```yaml
|
|
||||||
use_comet:
|
|
||||||
comet_api_key:
|
|
||||||
comet_workspace:
|
|
||||||
comet_project_name:
|
|
||||||
comet_experiment_key:
|
|
||||||
comet_mode:
|
|
||||||
comet_online:
|
|
||||||
comet_experiment_config:
|
|
||||||
```
|
|
||||||
|
|
||||||
##### Special Tokens
|
##### Special Tokens
|
||||||
|
|
||||||
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
|
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
|
||||||
|
|||||||
@@ -90,7 +90,6 @@ datasets:
|
|||||||
shards: # Optional[int] number of shards to split data into
|
shards: # Optional[int] number of shards to split data into
|
||||||
name: # Optional[str] name of dataset configuration to load
|
name: # Optional[str] name of dataset configuration to load
|
||||||
train_on_split: train # Optional[str] name of dataset split to load from
|
train_on_split: train # Optional[str] name of dataset split to load from
|
||||||
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
|
||||||
|
|
||||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
@@ -268,18 +267,6 @@ mlflow_tracking_uri: # URI to mlflow
|
|||||||
mlflow_experiment_name: # Your experiment name
|
mlflow_experiment_name: # Your experiment name
|
||||||
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
||||||
|
|
||||||
# Comet configuration if you're using it
|
|
||||||
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
|
|
||||||
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
|
|
||||||
use_comet: # Enable or disable Comet integration.
|
|
||||||
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
|
|
||||||
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
|
|
||||||
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
|
|
||||||
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
|
|
||||||
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
|
|
||||||
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
|
|
||||||
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
|
|
||||||
|
|
||||||
# Where to save the full-finetuned model to
|
# Where to save the full-finetuned model to
|
||||||
output_dir: ./completed-model
|
output_dir: ./completed-model
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ flash-attn==2.6.3
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
xformers==0.0.28.post1
|
xformers==0.0.27
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
colorama
|
colorama
|
||||||
@@ -46,9 +46,3 @@ gcsfs>=2024.5.0
|
|||||||
trl==0.9.6
|
trl==0.9.6
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|
||||||
# lm eval harness
|
|
||||||
lm_eval==0.4.4
|
|
||||||
langdetect==1.0.9
|
|
||||||
immutabledict==4.2.0
|
|
||||||
antlr4-python3-runtime==4.13.2
|
|
||||||
|
|||||||
7
setup.py
7
setup.py
@@ -49,17 +49,10 @@ def parse_requirements():
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid version format")
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
if (major, minor) >= (2, 4):
|
|
||||||
if patch == 0:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
_install_requires.append("xformers>=0.0.27")
|
|
||||||
if (major, minor) >= (2, 3):
|
if (major, minor) >= (2, 3):
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.26.post1")
|
_install_requires.append("xformers>=0.0.26.post1")
|
||||||
else:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
_install_requires.append("xformers>=0.0.27")
|
|
||||||
elif (major, minor) >= (2, 2):
|
elif (major, minor) >= (2, 2):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.25.post1")
|
_install_requires.append("xformers>=0.0.25.post1")
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ from axolotl.integrations.base import PluginManager
|
|||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
normalize_config,
|
normalize_config,
|
||||||
@@ -55,22 +54,8 @@ LOG = logging.getLogger("axolotl.scripts")
|
|||||||
|
|
||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
|
|
||||||
AXOLOTL_LOGO = """
|
|
||||||
#@@ #@@ @@# @@#
|
|
||||||
@@ @@ @@ @@ =@@# @@ #@ =@@#.
|
|
||||||
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
|
|
||||||
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
|
|
||||||
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
|
|
||||||
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
|
|
||||||
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
|
|
||||||
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
|
|
||||||
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
|
|
||||||
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
|
|
||||||
@@@@ @@@@@@@@@@@@@@@@
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
def print_axolotl_text_art(suffix=None):
|
||||||
def print_legacy_axolotl_text_art(suffix=None):
|
|
||||||
font = "nancyj"
|
font = "nancyj"
|
||||||
ascii_text = " axolotl"
|
ascii_text = " axolotl"
|
||||||
if suffix:
|
if suffix:
|
||||||
@@ -83,13 +68,6 @@ def print_legacy_axolotl_text_art(suffix=None):
|
|||||||
print_dep_versions()
|
print_dep_versions()
|
||||||
|
|
||||||
|
|
||||||
def print_axolotl_text_art(
|
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
if is_main_process():
|
|
||||||
print(AXOLOTL_LOGO)
|
|
||||||
|
|
||||||
|
|
||||||
def print_dep_versions():
|
def print_dep_versions():
|
||||||
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
|
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
|
||||||
max_len = max(len(pkg) for pkg in packages)
|
max_len = max(len(pkg) for pkg in packages)
|
||||||
@@ -443,8 +421,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
setup_mlflow_env_vars(cfg)
|
setup_mlflow_env_vars(cfg)
|
||||||
|
|
||||||
setup_comet_env_vars(cfg)
|
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,11 +3,13 @@ CLI to run training on a model
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
check_accelerate_default_config,
|
check_accelerate_default_config,
|
||||||
@@ -18,7 +20,6 @@ from axolotl.cli import (
|
|||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.integrations.base import PluginManager
|
|
||||||
from axolotl.prompt_strategies.sharegpt import (
|
from axolotl.prompt_strategies.sharegpt import (
|
||||||
register_chatml_template,
|
register_chatml_template,
|
||||||
register_llama3_template,
|
register_llama3_template,
|
||||||
@@ -38,7 +39,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
return do_train(parsed_cfg, parsed_cli_args)
|
return do_train(parsed_cfg, parsed_cli_args)
|
||||||
|
|
||||||
|
|
||||||
def do_train(cfg, cli_args) -> None:
|
def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
@@ -63,13 +64,7 @@ def do_train(cfg, cli_args) -> None:
|
|||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
plugin_manager = PluginManager.get_instance()
|
|
||||||
|
|
||||||
del model
|
|
||||||
del tokenizer
|
|
||||||
|
|
||||||
plugin_manager.post_train_unload(cfg)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ from trl.trainer.utils import pad_to_length
|
|||||||
|
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_mlflow_available
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
@@ -1111,12 +1111,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
if self.cfg.use_comet and is_comet_available():
|
|
||||||
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
|
||||||
|
|
||||||
callbacks.append(
|
|
||||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
|
||||||
)
|
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
@@ -1185,11 +1179,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer, self.tokenizer, "mlflow"
|
trainer, self.tokenizer, "mlflow"
|
||||||
)
|
)
|
||||||
callbacks.append(LogPredictionCallback(self.cfg))
|
callbacks.append(LogPredictionCallback(self.cfg))
|
||||||
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
|
|
||||||
LogPredictionCallback = log_prediction_callback_factory(
|
|
||||||
trainer, self.tokenizer, "comet_ml"
|
|
||||||
)
|
|
||||||
callbacks.append(LogPredictionCallback(self.cfg))
|
|
||||||
|
|
||||||
if self.cfg.do_bench_eval:
|
if self.cfg.do_bench_eval:
|
||||||
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
||||||
@@ -1441,8 +1430,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
report_to.append("mlflow")
|
report_to.append("mlflow")
|
||||||
if self.cfg.use_tensorboard:
|
if self.cfg.use_tensorboard:
|
||||||
report_to.append("tensorboard")
|
report_to.append("tensorboard")
|
||||||
if self.cfg.use_comet:
|
|
||||||
report_to.append("comet_ml")
|
|
||||||
|
|
||||||
training_arguments_kwargs["report_to"] = report_to
|
training_arguments_kwargs["report_to"] = report_to
|
||||||
training_arguments_kwargs["run_name"] = (
|
training_arguments_kwargs["run_name"] = (
|
||||||
|
|||||||
@@ -159,29 +159,6 @@ class BasePlugin:
|
|||||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def post_train(self, cfg, model):
|
|
||||||
"""
|
|
||||||
Performs actions after training is complete.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
cfg (dict): The axolotl configuration
|
|
||||||
model (object): The loaded model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
|
|
||||||
def post_train_unload(self, cfg):
|
|
||||||
"""
|
|
||||||
Performs actions after training is complete and the model is unloaded.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
cfg (dict): The configuration for the plugin.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def load_plugin(plugin_name: str) -> BasePlugin:
|
def load_plugin(plugin_name: str) -> BasePlugin:
|
||||||
"""
|
"""
|
||||||
@@ -404,17 +381,3 @@ class PluginManager:
|
|||||||
for plugin in self.plugins:
|
for plugin in self.plugins:
|
||||||
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def post_train_unload(self, cfg):
|
|
||||||
"""
|
|
||||||
Calls the post_train_unload method of all registered plugins.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
cfg (dict): The configuration for the plugins.
|
|
||||||
model (object): The loaded model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
for plugin in self.plugins:
|
|
||||||
plugin.post_train_unload(cfg)
|
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
# LM Eval Harness
|
|
||||||
|
|
||||||
### Usage
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.lm_eval.LMEvalPlugin
|
|
||||||
|
|
||||||
lm_eval_tasks:
|
|
||||||
- gsm8k
|
|
||||||
- hellaswag
|
|
||||||
- arc_easy
|
|
||||||
```
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
"""
|
|
||||||
Module for the Plugin for LM Eval Harness
|
|
||||||
"""
|
|
||||||
import subprocess # nosec
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
|
||||||
|
|
||||||
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
|
|
||||||
|
|
||||||
|
|
||||||
class LMEvalPlugin(BasePlugin):
|
|
||||||
"""
|
|
||||||
Plugin for LM Evaluation Harness integraton with Axolotl.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_input_args(self):
|
|
||||||
return "axolotl.integrations.lm_eval.LMEvalArgs"
|
|
||||||
|
|
||||||
def post_train_unload(self, cfg):
|
|
||||||
tasks = ",".join(cfg.lm_eval_tasks)
|
|
||||||
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
|
|
||||||
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
|
|
||||||
output_path = cfg.output_dir
|
|
||||||
output_path += "" if cfg.output_dir.endswith("/") else "/"
|
|
||||||
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
subprocess.run( # nosec
|
|
||||||
[
|
|
||||||
"lm_eval",
|
|
||||||
"--model",
|
|
||||||
"hf",
|
|
||||||
"--model_args",
|
|
||||||
f"pretrained={cfg.output_dir}{fa2}{dtype}",
|
|
||||||
"--tasks",
|
|
||||||
tasks,
|
|
||||||
"--batch_size",
|
|
||||||
str(cfg.lm_eval_batch_size),
|
|
||||||
"--output_path",
|
|
||||||
output_path,
|
|
||||||
],
|
|
||||||
check=True,
|
|
||||||
)
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
"""
|
|
||||||
Module for handling lm eval harness input arguments.
|
|
||||||
"""
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class LMEvalArgs(BaseModel):
|
|
||||||
"""
|
|
||||||
Input args for lm eval harness
|
|
||||||
"""
|
|
||||||
|
|
||||||
lm_eval_tasks: List[str] = []
|
|
||||||
lm_eval_batch_size: Optional[int] = 8
|
|
||||||
@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
|
|||||||
def reset_optimizer(
|
def reset_optimizer(
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
*,
|
*,
|
||||||
reset_params: List[str], # where str is the key to a torch.nn.Parameter
|
reset_params: list[str], # where str is the key to a torch.nn.Parameter
|
||||||
optimizer_state_keys: List[str],
|
optimizer_state_keys: list[str],
|
||||||
prune_ratio: float = 0.9,
|
prune_ratio: float = 0.9,
|
||||||
):
|
):
|
||||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
||||||
|
|||||||
@@ -1,12 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
Basic utils for Axolotl
|
Basic utils for Axolotl
|
||||||
"""
|
"""
|
||||||
import importlib.util
|
import importlib
|
||||||
|
|
||||||
|
|
||||||
def is_mlflow_available():
|
def is_mlflow_available():
|
||||||
return importlib.util.find_spec("mlflow") is not None
|
return importlib.util.find_spec("mlflow") is not None
|
||||||
|
|
||||||
|
|
||||||
def is_comet_available():
|
|
||||||
return importlib.util.find_spec("comet_ml") is not None
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||||
|
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_mlflow_available
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
@@ -462,7 +462,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
|||||||
references=[[r] for r in references],
|
references=[[r] for r in references],
|
||||||
predictions=predictions,
|
predictions=predictions,
|
||||||
)
|
)
|
||||||
scores["eval_" + metric_name] = score
|
scores[metric_name] = score
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
def predict_with_generate():
|
def predict_with_generate():
|
||||||
@@ -747,15 +747,6 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
|
|||||||
artifact_file="PredictionsVsGroundTruth.json",
|
artifact_file="PredictionsVsGroundTruth.json",
|
||||||
tracking_uri=tracking_uri,
|
tracking_uri=tracking_uri,
|
||||||
)
|
)
|
||||||
elif logger == "comet_ml" and is_comet_available():
|
|
||||||
import comet_ml
|
|
||||||
|
|
||||||
experiment = comet_ml.get_running_experiment()
|
|
||||||
if experiment:
|
|
||||||
experiment.log_table(
|
|
||||||
f"{name} - Predictions vs Ground Truth.csv",
|
|
||||||
pd.DataFrame(table_data),
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
log_table_from_dataloader("Eval", eval_dataloader)
|
log_table_from_dataloader("Eval", eval_dataloader)
|
||||||
|
|||||||
@@ -1,43 +0,0 @@
|
|||||||
"""Comet module for trainer callbacks"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import comet_ml
|
|
||||||
from transformers import TrainerCallback, TrainerControl, TrainerState
|
|
||||||
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.callbacks")
|
|
||||||
|
|
||||||
|
|
||||||
class SaveAxolotlConfigtoCometCallback(TrainerCallback):
|
|
||||||
"""Callback to save axolotl config to comet"""
|
|
||||||
|
|
||||||
def __init__(self, axolotl_config_path):
|
|
||||||
self.axolotl_config_path = axolotl_config_path
|
|
||||||
|
|
||||||
def on_train_begin(
|
|
||||||
self,
|
|
||||||
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
|
|
||||||
state: TrainerState, # pylint: disable=unused-argument
|
|
||||||
control: TrainerControl,
|
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
if is_main_process():
|
|
||||||
try:
|
|
||||||
comet_experiment = comet_ml.start(source="axolotl")
|
|
||||||
comet_experiment.log_other("Created from", "axolotl")
|
|
||||||
comet_experiment.log_asset(
|
|
||||||
self.axolotl_config_path,
|
|
||||||
file_name="axolotl-config",
|
|
||||||
)
|
|
||||||
LOG.info(
|
|
||||||
"The Axolotl config has been saved to the Comet Experiment under assets."
|
|
||||||
)
|
|
||||||
except (FileNotFoundError, ConnectionError) as err:
|
|
||||||
LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
|
|
||||||
return control
|
|
||||||
@@ -5,9 +5,7 @@ These templates are used for formatting messages in a conversation.
|
|||||||
|
|
||||||
CHAT_TEMPLATES = {
|
CHAT_TEMPLATES = {
|
||||||
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||||
"mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1...
|
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||||
"mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large...
|
|
||||||
"mistral_v3_tekken": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3-Tekken: Nemo, Pixtral...
|
|
||||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||||
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
chat_template: Optional[str] = None
|
chat_template: Optional[str] = None
|
||||||
packing: bool = False
|
packing: bool = False
|
||||||
|
sequence_length: Optional[int] = None
|
||||||
max_images: int = -1
|
max_images: int = -1
|
||||||
padding: Union[bool, str, PaddingStrategy] = True
|
padding: Union[bool, str, PaddingStrategy] = True
|
||||||
pad_to_multiple_of: Optional[int] = None
|
pad_to_multiple_of: Optional[int] = None
|
||||||
@@ -32,11 +33,112 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
|
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
# Handle dict or lists with proper padding and conversion to tensor.
|
# Handle dict or lists with proper padding and conversion to tensor.
|
||||||
|
if self.packing:
|
||||||
|
return self.__class__.process_rows_packing(
|
||||||
|
examples,
|
||||||
|
self.processor,
|
||||||
|
self.chat_template,
|
||||||
|
self.max_images,
|
||||||
|
self.sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
return self.__class__.process_rows(
|
return self.__class__.process_rows(
|
||||||
examples, self.processor, self.chat_template, self.max_images
|
examples, self.processor, self.chat_template, self.max_images
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_rows_packing(
|
||||||
|
examples,
|
||||||
|
processor,
|
||||||
|
chat_template,
|
||||||
|
max_images,
|
||||||
|
sequence_length,
|
||||||
|
length_only=False,
|
||||||
|
):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Perform sample packing within a batch
|
||||||
|
|
||||||
|
if processor.tokenizer.sep_token is None:
|
||||||
|
sep_token = "[SEP]"
|
||||||
|
processor.tokenizer.add_tokens([sep_token])
|
||||||
|
processor.tokenizer.sep_token = sep_token
|
||||||
|
sep_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||||
|
processor.tokenizer.sep_token
|
||||||
|
)
|
||||||
|
pad_token_id = processor.tokenizer.pad_token_id
|
||||||
|
|
||||||
|
texts = [
|
||||||
|
processor.apply_chat_template(
|
||||||
|
example["messages"], chat_template=chat_template, tokenize=False
|
||||||
|
)
|
||||||
|
for example in examples
|
||||||
|
]
|
||||||
|
images = [example["images"] for example in examples]
|
||||||
|
|
||||||
|
if max_images > 0:
|
||||||
|
images = [img_batch[:max_images] for img_batch in images]
|
||||||
|
|
||||||
|
batch = processor(text=texts, images=images, padding=False)
|
||||||
|
|
||||||
|
n_sequence = len(examples)
|
||||||
|
n_seq_in_batch = 0
|
||||||
|
pack_len = 0
|
||||||
|
features_pack = {}
|
||||||
|
packed = {}
|
||||||
|
features = list[batch.keys()]
|
||||||
|
for feature in features:
|
||||||
|
features_pack[feature] = []
|
||||||
|
packed[feature] = []
|
||||||
|
features.remove("input_ids")
|
||||||
|
|
||||||
|
for seq_in_batch_id in range(n_sequence):
|
||||||
|
next_seq_len = len(batch["input_ids"][seq_in_batch_id])
|
||||||
|
if not pack_len + next_seq_len + 1 < sequence_length:
|
||||||
|
n_seq_in_batch += 1
|
||||||
|
pack_len += next_seq_len + 1
|
||||||
|
features_pack["input_ids"] += batch["input_ids"][seq_in_batch_id] + [
|
||||||
|
sep_token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
"""
|
||||||
|
Do something with attention mask and cross-attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
for feature in features:
|
||||||
|
features_pack[feature] += batch[feature][seq_in_batch_id]
|
||||||
|
|
||||||
|
else:
|
||||||
|
for _ in range(sequence_length - pack_len):
|
||||||
|
features_pack["input_ids"] += [pad_token_id]
|
||||||
|
|
||||||
|
packed["input_ids"].append(
|
||||||
|
torch.tensor(features_pack["input_ids"].copy())
|
||||||
|
)
|
||||||
|
|
||||||
|
for feature in features:
|
||||||
|
packed[feature].append(torch.tensor(features_pack[feature].copy()))
|
||||||
|
features_pack[feature] = []
|
||||||
|
|
||||||
|
pack_len = 0
|
||||||
|
|
||||||
|
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||||
|
processor.image_token
|
||||||
|
)
|
||||||
|
labels = [pack.clone() for pack in packed["input_ids"]]
|
||||||
|
for label_id, label in enumerate(labels):
|
||||||
|
labels[label_id][label == processor.tokenizer.pad_token_id] = -100 #
|
||||||
|
# Ignore the image token index in the loss computation (model specific)
|
||||||
|
|
||||||
|
labels[label_id][label == image_token_id] = -100
|
||||||
|
packed["labels"] = labels
|
||||||
|
|
||||||
|
if length_only:
|
||||||
|
return {
|
||||||
|
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
|
||||||
|
}
|
||||||
|
return packed
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def process_rows(examples, processor, chat_template, max_images, length_only=False):
|
def process_rows(examples, processor, chat_template, max_images, length_only=False):
|
||||||
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
||||||
|
|||||||
@@ -1,93 +0,0 @@
|
|||||||
"""Module for wandb utilities"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.comet_")
|
|
||||||
|
|
||||||
COMET_ENV_MAPPING_OVERRIDE = {
|
|
||||||
"comet_mode": "COMET_START_MODE",
|
|
||||||
"comet_online": "COMET_START_ONLINE",
|
|
||||||
}
|
|
||||||
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
|
|
||||||
"auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
|
|
||||||
"auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
|
|
||||||
"auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
|
|
||||||
"auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
|
|
||||||
"auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
|
|
||||||
"auto_log_co2": "COMET_AUTO_LOG_CO2",
|
|
||||||
"auto_metric_logging": "COMET_AUTO_LOG_METRICS",
|
|
||||||
"auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
|
|
||||||
"auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
|
|
||||||
"auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
|
|
||||||
"comet_disabled": "COMET_AUTO_LOG_DISABLE",
|
|
||||||
"display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
|
|
||||||
"distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
|
|
||||||
"log_code": "COMET_AUTO_LOG_CODE",
|
|
||||||
"log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
|
|
||||||
"log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
|
|
||||||
"log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
|
|
||||||
"log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
|
|
||||||
"log_env_host": "COMET_AUTO_LOG_ENV_HOST",
|
|
||||||
"log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
|
|
||||||
"log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
|
|
||||||
"log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
|
|
||||||
"log_graph": "COMET_AUTO_LOG_GRAPH",
|
|
||||||
"name": "COMET_START_EXPERIMENT_NAME",
|
|
||||||
"offline_directory": "COMET_OFFLINE_DIRECTORY",
|
|
||||||
"parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
|
|
||||||
"tags": "COMET_START_EXPERIMENT_TAGS",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def python_value_to_environ_value(python_value):
|
|
||||||
if isinstance(python_value, bool):
|
|
||||||
if python_value is True:
|
|
||||||
return "true"
|
|
||||||
|
|
||||||
return "false"
|
|
||||||
|
|
||||||
if isinstance(python_value, int):
|
|
||||||
return str(python_value)
|
|
||||||
|
|
||||||
if isinstance(python_value, list): # Comet only have one list of string parameter
|
|
||||||
return ",".join(map(str, python_value))
|
|
||||||
|
|
||||||
return python_value
|
|
||||||
|
|
||||||
|
|
||||||
def setup_comet_env_vars(cfg: DictDefault):
|
|
||||||
# TODO, we need to convert Axolotl configuration to environment variables
|
|
||||||
# as Transformers integration are call first and would create an
|
|
||||||
# Experiment first
|
|
||||||
|
|
||||||
for key in cfg.keys():
|
|
||||||
if key.startswith("comet_") and key != "comet_experiment_config":
|
|
||||||
value = cfg.get(key, "")
|
|
||||||
|
|
||||||
if value is not None and value != "":
|
|
||||||
env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
|
|
||||||
final_value = python_value_to_environ_value(value)
|
|
||||||
os.environ[env_variable_name] = final_value
|
|
||||||
|
|
||||||
if cfg.comet_experiment_config:
|
|
||||||
for key, value in cfg.comet_experiment_config.items():
|
|
||||||
if value is not None and value != "":
|
|
||||||
config_env_variable_name = (
|
|
||||||
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
|
|
||||||
)
|
|
||||||
|
|
||||||
if config_env_variable_name is None:
|
|
||||||
LOG.warning(
|
|
||||||
f"Unknown Comet Experiment Config name {key}, ignoring it"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
final_value = python_value_to_environ_value(value)
|
|
||||||
os.environ[config_env_variable_name] = final_value
|
|
||||||
|
|
||||||
# Enable comet if project name is present
|
|
||||||
if cfg.comet_project_name and len(cfg.comet_project_name) > 0:
|
|
||||||
cfg.use_comet = True
|
|
||||||
@@ -489,19 +489,6 @@ class WandbConfig(BaseModel):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class CometConfig(BaseModel):
|
|
||||||
"""Comet configuration subset"""
|
|
||||||
|
|
||||||
use_comet: Optional[bool] = None
|
|
||||||
comet_api_key: Optional[str] = None
|
|
||||||
comet_workspace: Optional[str] = None
|
|
||||||
comet_project_name: Optional[str] = None
|
|
||||||
comet_experiment_key: Optional[str] = None
|
|
||||||
comet_mode: Optional[str] = None
|
|
||||||
comet_online: Optional[bool] = None
|
|
||||||
comet_experiment_config: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GradioConfig(BaseModel):
|
class GradioConfig(BaseModel):
|
||||||
"""Gradio configuration subset"""
|
"""Gradio configuration subset"""
|
||||||
|
|
||||||
@@ -522,7 +509,6 @@ class AxolotlInputConfig(
|
|||||||
HyperparametersConfig,
|
HyperparametersConfig,
|
||||||
WandbConfig,
|
WandbConfig,
|
||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
CometConfig,
|
|
||||||
LISAConfig,
|
LISAConfig,
|
||||||
GradioConfig,
|
GradioConfig,
|
||||||
RemappedParameters,
|
RemappedParameters,
|
||||||
@@ -980,26 +966,6 @@ class AxolotlInputConfig(
|
|||||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||||
)
|
)
|
||||||
|
|
||||||
if data.get("do_bench_eval") and not (
|
|
||||||
data.get("evals_per_epoch") or data.get("eval_steps")
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"do_bench_eval requires evals_per_epoch or eval_steps to be set."
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_test_datasets_bench(cls, data):
|
|
||||||
if (
|
|
||||||
data.get("do_bench_eval")
|
|
||||||
and not data.get("test_datasets")
|
|
||||||
and not data.get("val_set_size")
|
|
||||||
):
|
|
||||||
LOG.warning(
|
|
||||||
"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
|
|
||||||
)
|
|
||||||
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
|||||||
@@ -242,7 +242,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
streaming=True,
|
streaming=True,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
revision=config_dataset.revision,
|
|
||||||
)
|
)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
||||||
@@ -347,7 +346,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
streaming=False,
|
streaming=False,
|
||||||
data_files=config_dataset.data_files,
|
data_files=config_dataset.data_files,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
revision=config_dataset.revision,
|
|
||||||
**load_ds_kwargs,
|
**load_ds_kwargs,
|
||||||
)
|
)
|
||||||
elif ds_from_cloud and remote_file_system:
|
elif ds_from_cloud and remote_file_system:
|
||||||
@@ -382,7 +380,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
repo_id=config_dataset.path,
|
repo_id=config_dataset.path,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
filename=config_dataset.data_files,
|
filename=config_dataset.data_files,
|
||||||
revision=config_dataset.revision,
|
|
||||||
)
|
)
|
||||||
elif isinstance(config_dataset.data_files, list):
|
elif isinstance(config_dataset.data_files, list):
|
||||||
fp = []
|
fp = []
|
||||||
@@ -392,7 +389,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
repo_id=config_dataset.path,
|
repo_id=config_dataset.path,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
filename=file,
|
filename=file,
|
||||||
revision=config_dataset.revision,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -437,8 +433,8 @@ def load_tokenized_prepared_datasets(
|
|||||||
config_dataset=config_dataset,
|
config_dataset=config_dataset,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
d_base_type=d_base_type,
|
|
||||||
dataset=ds,
|
dataset=ds,
|
||||||
|
d_base_type=d_base_type,
|
||||||
d_prompt_style=d_prompt_style,
|
d_prompt_style=d_prompt_style,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -267,74 +267,6 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
def test_load_hub_with_revision(self):
|
|
||||||
"""Verify that processing data from the hub works with a specific revision"""
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
"revision": "d05c1cb",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
|
||||||
assert "input_ids" in dataset.features
|
|
||||||
assert "attention_mask" in dataset.features
|
|
||||||
assert "labels" in dataset.features
|
|
||||||
|
|
||||||
def test_load_local_hub_with_revision(self):
|
|
||||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
|
|
||||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
snapshot_download(
|
|
||||||
repo_id="mhenrichsen/alpaca_2k_test",
|
|
||||||
repo_type="dataset",
|
|
||||||
local_dir=tmp_ds_path,
|
|
||||||
revision="d05c1cb",
|
|
||||||
)
|
|
||||||
|
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"ds_type": "parquet",
|
|
||||||
"type": "alpaca",
|
|
||||||
"data_files": [
|
|
||||||
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
|
|
||||||
],
|
|
||||||
"revision": "d05c1cb",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
|
||||||
assert "input_ids" in dataset.features
|
|
||||||
assert "attention_mask" in dataset.features
|
|
||||||
assert "labels" in dataset.features
|
|
||||||
shutil.rmtree(tmp_ds_path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from typing import Optional
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from axolotl.utils import is_comet_available
|
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -1330,105 +1329,3 @@ class TestValidationWandb(BaseValidation):
|
|||||||
|
|
||||||
os.environ.pop("WANDB_PROJECT", None)
|
os.environ.pop("WANDB_PROJECT", None)
|
||||||
os.environ.pop("WANDB_DISABLED", None)
|
os.environ.pop("WANDB_DISABLED", None)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed")
|
|
||||||
class TestValidationComet(BaseValidation):
|
|
||||||
"""
|
|
||||||
Validation test for comet
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_comet_sets_env(self, minimal_cfg):
|
|
||||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
|
||||||
|
|
||||||
comet_config = {
|
|
||||||
"comet_api_key": "foo",
|
|
||||||
"comet_workspace": "some_workspace",
|
|
||||||
"comet_project_name": "some_project",
|
|
||||||
"comet_experiment_key": "some_experiment_key",
|
|
||||||
"comet_mode": "get_or_create",
|
|
||||||
"comet_online": False,
|
|
||||||
"comet_experiment_config": {
|
|
||||||
"auto_histogram_activation_logging": False,
|
|
||||||
"auto_histogram_epoch_rate": 2,
|
|
||||||
"auto_histogram_gradient_logging": True,
|
|
||||||
"auto_histogram_tensorboard_logging": False,
|
|
||||||
"auto_histogram_weight_logging": True,
|
|
||||||
"auto_log_co2": False,
|
|
||||||
"auto_metric_logging": True,
|
|
||||||
"auto_metric_step_rate": 15,
|
|
||||||
"auto_output_logging": False,
|
|
||||||
"auto_param_logging": True,
|
|
||||||
"comet_disabled": False,
|
|
||||||
"display_summary_level": 2,
|
|
||||||
"distributed_node_identifier": "some_distributed_node_identifier",
|
|
||||||
"log_code": True,
|
|
||||||
"log_env_cpu": False,
|
|
||||||
"log_env_details": True,
|
|
||||||
"log_env_disk": False,
|
|
||||||
"log_env_gpu": True,
|
|
||||||
"log_env_host": False,
|
|
||||||
"log_env_network": True,
|
|
||||||
"log_git_metadata": False,
|
|
||||||
"log_git_patch": True,
|
|
||||||
"log_graph": False,
|
|
||||||
"name": "some_name",
|
|
||||||
"offline_directory": "some_offline_directory",
|
|
||||||
"parse_args": True,
|
|
||||||
"tags": ["tag1", "tag2"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg = DictDefault(comet_config) | minimal_cfg
|
|
||||||
|
|
||||||
new_cfg = validate_config(cfg)
|
|
||||||
|
|
||||||
setup_comet_env_vars(new_cfg)
|
|
||||||
|
|
||||||
comet_env = {
|
|
||||||
key: value for key, value in os.environ.items() if key.startswith("COMET_")
|
|
||||||
}
|
|
||||||
|
|
||||||
assert (
|
|
||||||
len(comet_env)
|
|
||||||
== len(comet_config) + len(comet_config["comet_experiment_config"]) - 1
|
|
||||||
)
|
|
||||||
|
|
||||||
assert comet_env == {
|
|
||||||
"COMET_API_KEY": "foo",
|
|
||||||
"COMET_AUTO_LOG_CLI_ARGUMENTS": "true",
|
|
||||||
"COMET_AUTO_LOG_CO2": "false",
|
|
||||||
"COMET_AUTO_LOG_CODE": "true",
|
|
||||||
"COMET_AUTO_LOG_DISABLE": "false",
|
|
||||||
"COMET_AUTO_LOG_ENV_CPU": "false",
|
|
||||||
"COMET_AUTO_LOG_ENV_DETAILS": "true",
|
|
||||||
"COMET_AUTO_LOG_ENV_DISK": "false",
|
|
||||||
"COMET_AUTO_LOG_ENV_GPU": "true",
|
|
||||||
"COMET_AUTO_LOG_ENV_HOST": "false",
|
|
||||||
"COMET_AUTO_LOG_ENV_NETWORK": "true",
|
|
||||||
"COMET_AUTO_LOG_GIT_METADATA": "false",
|
|
||||||
"COMET_AUTO_LOG_GIT_PATCH": "true",
|
|
||||||
"COMET_AUTO_LOG_GRAPH": "false",
|
|
||||||
"COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS": "false",
|
|
||||||
"COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE": "2",
|
|
||||||
"COMET_AUTO_LOG_HISTOGRAM_GRADIENTS": "true",
|
|
||||||
"COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD": "false",
|
|
||||||
"COMET_AUTO_LOG_HISTOGRAM_WEIGHTS": "true",
|
|
||||||
"COMET_AUTO_LOG_METRIC_STEP_RATE": "15",
|
|
||||||
"COMET_AUTO_LOG_METRICS": "true",
|
|
||||||
"COMET_AUTO_LOG_OUTPUT_LOGGER": "false",
|
|
||||||
"COMET_AUTO_LOG_PARAMETERS": "true",
|
|
||||||
"COMET_DISPLAY_SUMMARY_LEVEL": "2",
|
|
||||||
"COMET_DISTRIBUTED_NODE_IDENTIFIER": "some_distributed_node_identifier",
|
|
||||||
"COMET_EXPERIMENT_KEY": "some_experiment_key",
|
|
||||||
"COMET_OFFLINE_DIRECTORY": "some_offline_directory",
|
|
||||||
"COMET_PROJECT_NAME": "some_project",
|
|
||||||
"COMET_START_EXPERIMENT_NAME": "some_name",
|
|
||||||
"COMET_START_EXPERIMENT_TAGS": "tag1,tag2",
|
|
||||||
"COMET_START_MODE": "get_or_create",
|
|
||||||
"COMET_START_ONLINE": "false",
|
|
||||||
"COMET_WORKSPACE": "some_workspace",
|
|
||||||
}
|
|
||||||
|
|
||||||
for key in comet_env.keys():
|
|
||||||
os.environ.pop(key, None)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user