Compare commits
4 Commits
v0.8.1
...
pre-commit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
156fede4f7 | ||
|
|
dcbbd7af79 | ||
|
|
21bac7ce1a | ||
|
|
aaa4571826 |
49
.github/workflows/precommit-autoupdate.yml
vendored
Normal file
49
.github/workflows/precommit-autoupdate.yml
vendored
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
name: Pre-commit auto-update
|
||||||
|
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
- cron: '0 0 * * 0' # Run weekly
|
||||||
|
workflow_dispatch: # Manual kickoff
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
auto-update:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
pull-requests: write
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
|
||||||
|
- name: Update pre-commit hooks
|
||||||
|
id: update
|
||||||
|
run: |
|
||||||
|
pip install pre-commit
|
||||||
|
pre-commit autoupdate
|
||||||
|
if [[ -n $(git status --porcelain) ]]; then
|
||||||
|
echo "changes=true" >> $GITHUB_OUTPUT
|
||||||
|
git diff .pre-commit-config.yaml > pre-commit-update.diff
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Create Pull Request
|
||||||
|
if: steps.update.outputs.changes == 'true'
|
||||||
|
uses: peter-evans/create-pull-request@v6
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
branch: update/pre-commit-hooks
|
||||||
|
delete-branch: true
|
||||||
|
title: "chore: update pre-commit hooks"
|
||||||
|
commit-message: "chore: update pre-commit hooks"
|
||||||
|
body: |
|
||||||
|
Automated PR to update pre-commit hooks to their latest versions.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Changes:</summary>
|
||||||
|
|
||||||
|
```diff
|
||||||
|
${{ steps.update.outputs.diff }}
|
||||||
|
```
|
||||||
|
</details>
|
||||||
@@ -3,7 +3,7 @@ default_language_version:
|
|||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.4.0
|
rev: v5.0.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
@@ -11,23 +11,23 @@ repos:
|
|||||||
- id: no-commit-to-branch
|
- id: no-commit-to-branch
|
||||||
args: ['--branch', 'main']
|
args: ['--branch', 'main']
|
||||||
- repo: https://github.com/psf/black
|
- repo: https://github.com/psf/black
|
||||||
rev: 23.3.0
|
rev: 25.1.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
- repo: https://github.com/pycqa/isort
|
- repo: https://github.com/pycqa/isort
|
||||||
rev: 5.12.0
|
rev: 6.0.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
- repo: https://github.com/PyCQA/flake8
|
- repo: https://github.com/PyCQA/flake8
|
||||||
rev: 6.1.0
|
rev: 7.1.2
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
- repo: https://github.com/pylint-dev/pylint
|
- repo: https://github.com/pylint-dev/pylint
|
||||||
rev: c8c96d20cde3552a79858c7456bb1483bf83d633
|
rev: v3.3.6
|
||||||
hooks:
|
hooks:
|
||||||
- id: pylint
|
- id: pylint
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v1.3.0
|
rev: v1.15.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
@@ -36,7 +36,7 @@ repos:
|
|||||||
'pydantic>=2.5.3',
|
'pydantic>=2.5.3',
|
||||||
]
|
]
|
||||||
- repo: https://github.com/PyCQA/bandit
|
- repo: https://github.com/PyCQA/bandit
|
||||||
rev: 1.7.5
|
rev: 1.8.3
|
||||||
hooks:
|
hooks:
|
||||||
- id: bandit
|
- id: bandit
|
||||||
args: [
|
args: [
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
modal application to run axolotl gpu tests in Modal
|
modal application to run axolotl gpu tests in Modal
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Modal app to run axolotl GPU tests"""
|
"""Modal app to run axolotl GPU tests"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
helper script to parse chat datasets into a usable yaml
|
helper script to parse chat datasets into a usable yaml
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import yaml
|
import yaml
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Script to output the correct installation command for cut-cross-entropy."""
|
"""Script to output the correct installation command for cut-cross-entropy."""
|
||||||
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@@ -17,12 +18,12 @@ if v < V("2.4.0"):
|
|||||||
|
|
||||||
cce_spec = importlib.util.find_spec("cut_cross_entropy")
|
cce_spec = importlib.util.find_spec("cut_cross_entropy")
|
||||||
|
|
||||||
uninstall_prefix = ""
|
UNINSTALL_PREFIX = ""
|
||||||
if cce_spec:
|
if cce_spec:
|
||||||
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
|
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
|
||||||
uninstall_prefix = "pip uninstall -y cut-cross-entropy && "
|
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
|
||||||
|
|
||||||
print(
|
print(
|
||||||
uninstall_prefix
|
UNINSTALL_PREFIX
|
||||||
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
|
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
launch axolotl in supported cloud platforms
|
launch axolotl in supported cloud platforms
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
base class for cloud platforms from cli
|
base class for cloud platforms from cli
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Modal Cloud support from CLI
|
Modal Cloud support from CLI
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Click CLI definitions for various axolotl commands."""
|
"""Click CLI definitions for various axolotl commands."""
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import dataclasses
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import typing
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import NoneType
|
from types import NoneType
|
||||||
@@ -24,7 +23,7 @@ configure_logging()
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def strip_optional_type(field_type: type | typing._SpecialForm | None):
|
def strip_optional_type(field_type: type | str | None):
|
||||||
"""
|
"""
|
||||||
Extracts the non-`None` type from an `Optional` / `Union` type.
|
Extracts the non-`None` type from an `Optional` / `Union` type.
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
|
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
|
||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
ChatML transformation functions for MessageContents
|
ChatML transformation functions for MessageContents
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..messages import MessageContents, Messages
|
from ..messages import MessageContents, Messages
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Llama 3.x chat formatting functions for MessageContents
|
Llama 3.x chat formatting functions for MessageContents
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..messages import MessageContents, Messages
|
from ..messages import MessageContents, Messages
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
shared functions for format transforms
|
shared functions for format transforms
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from axolotl.core.chat.messages import MessageContents, Messages
|
from axolotl.core.chat.messages import MessageContents, Messages
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
internal message representations of chat messages
|
internal message representations of chat messages
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, List, Optional, Union
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
chat dataset module
|
chat dataset module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Optional, Union
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
|
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Mapping, Union
|
from typing import Any, Mapping, Union
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -332,9 +332,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs = {}
|
training_arguments_kwargs = {}
|
||||||
|
|
||||||
if self.cfg.include_tokens_per_second is not None:
|
if self.cfg.include_tokens_per_second is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["include_tokens_per_second"] = (
|
||||||
"include_tokens_per_second"
|
self.cfg.include_tokens_per_second
|
||||||
] = self.cfg.include_tokens_per_second
|
)
|
||||||
|
|
||||||
if self.cfg.bf16 == "full":
|
if self.cfg.bf16 == "full":
|
||||||
training_arguments_kwargs["bf16_full_eval"] = True
|
training_arguments_kwargs["bf16_full_eval"] = True
|
||||||
@@ -351,13 +351,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["seed"] = self.cfg.seed
|
training_arguments_kwargs["seed"] = self.cfg.seed
|
||||||
|
|
||||||
if self.cfg.gradient_checkpointing:
|
if self.cfg.gradient_checkpointing:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["gradient_checkpointing"] = (
|
||||||
"gradient_checkpointing"
|
self.cfg.gradient_checkpointing
|
||||||
] = self.cfg.gradient_checkpointing
|
)
|
||||||
if self.cfg.gradient_checkpointing_kwargs is not None:
|
if self.cfg.gradient_checkpointing_kwargs is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["gradient_checkpointing_kwargs"] = (
|
||||||
"gradient_checkpointing_kwargs"
|
self.cfg.gradient_checkpointing_kwargs
|
||||||
] = self.cfg.gradient_checkpointing_kwargs
|
)
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||||
if self.cfg.fsdp_config:
|
if self.cfg.fsdp_config:
|
||||||
@@ -373,9 +373,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
||||||
|
|
||||||
if self.cfg.lr_quadratic_warmup is not None:
|
if self.cfg.lr_quadratic_warmup is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["lr_quadratic_warmup"] = (
|
||||||
"lr_quadratic_warmup"
|
self.cfg.lr_quadratic_warmup
|
||||||
] = self.cfg.lr_quadratic_warmup
|
)
|
||||||
|
|
||||||
if self.cfg.adam_beta1:
|
if self.cfg.adam_beta1:
|
||||||
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
|
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
|
||||||
@@ -399,28 +399,28 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||||
|
|
||||||
if self.cfg.dataloader_pin_memory is not None:
|
if self.cfg.dataloader_pin_memory is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["dataloader_pin_memory"] = (
|
||||||
"dataloader_pin_memory"
|
self.cfg.dataloader_pin_memory
|
||||||
] = self.cfg.dataloader_pin_memory
|
)
|
||||||
if self.cfg.dataloader_num_workers is not None:
|
if self.cfg.dataloader_num_workers is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["dataloader_num_workers"] = (
|
||||||
"dataloader_num_workers"
|
self.cfg.dataloader_num_workers
|
||||||
] = self.cfg.dataloader_num_workers
|
)
|
||||||
if self.cfg.dataloader_prefetch_factor is not None:
|
if self.cfg.dataloader_prefetch_factor is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["dataloader_prefetch_factor"] = (
|
||||||
"dataloader_prefetch_factor"
|
self.cfg.dataloader_prefetch_factor
|
||||||
] = self.cfg.dataloader_prefetch_factor
|
)
|
||||||
if self.cfg.dataloader_drop_last is not None:
|
if self.cfg.dataloader_drop_last is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["dataloader_drop_last"] = (
|
||||||
"dataloader_drop_last"
|
self.cfg.dataloader_drop_last
|
||||||
] = self.cfg.dataloader_drop_last
|
)
|
||||||
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
|
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
|
||||||
training_arguments_kwargs["dataloader_drop_last"] = True
|
training_arguments_kwargs["dataloader_drop_last"] = True
|
||||||
|
|
||||||
if self.cfg.remove_unused_columns is not None:
|
if self.cfg.remove_unused_columns is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["remove_unused_columns"] = (
|
||||||
"remove_unused_columns"
|
self.cfg.remove_unused_columns
|
||||||
] = self.cfg.remove_unused_columns
|
)
|
||||||
|
|
||||||
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
||||||
# no eval set, so don't eval
|
# no eval set, so don't eval
|
||||||
@@ -452,9 +452,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.do_causal_lm_eval:
|
if self.cfg.do_causal_lm_eval:
|
||||||
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
|
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
|
||||||
if self.cfg.metric_for_best_model:
|
if self.cfg.metric_for_best_model:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["metric_for_best_model"] = (
|
||||||
"metric_for_best_model"
|
self.cfg.metric_for_best_model
|
||||||
] = self.cfg.metric_for_best_model
|
)
|
||||||
if self.cfg.greater_is_better:
|
if self.cfg.greater_is_better:
|
||||||
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
|
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
|
||||||
|
|
||||||
@@ -467,13 +467,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
|
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
|
||||||
if self.cfg.torch_compile_backend:
|
if self.cfg.torch_compile_backend:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["torch_compile_backend"] = (
|
||||||
"torch_compile_backend"
|
self.cfg.torch_compile_backend
|
||||||
] = self.cfg.torch_compile_backend
|
)
|
||||||
if self.cfg.torch_compile_mode:
|
if self.cfg.torch_compile_mode:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["torch_compile_mode"] = (
|
||||||
"torch_compile_mode"
|
self.cfg.torch_compile_mode
|
||||||
] = self.cfg.torch_compile_mode
|
)
|
||||||
|
|
||||||
# DDP Config
|
# DDP Config
|
||||||
if self.cfg.ddp_timeout:
|
if self.cfg.ddp_timeout:
|
||||||
@@ -482,32 +482,32 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.ddp_bucket_cap_mb:
|
if self.cfg.ddp_bucket_cap_mb:
|
||||||
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
|
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
|
||||||
if self.cfg.ddp_broadcast_buffers is not None:
|
if self.cfg.ddp_broadcast_buffers is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["ddp_broadcast_buffers"] = (
|
||||||
"ddp_broadcast_buffers"
|
self.cfg.ddp_broadcast_buffers
|
||||||
] = self.cfg.ddp_broadcast_buffers
|
)
|
||||||
|
|
||||||
# these are all the "standard" kwargs that are def used
|
# these are all the "standard" kwargs that are def used
|
||||||
training_arguments_kwargs["max_steps"] = (
|
training_arguments_kwargs["max_steps"] = (
|
||||||
total_num_steps if self.cfg.max_steps else -1
|
total_num_steps if self.cfg.max_steps else -1
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["per_device_train_batch_size"] = (
|
||||||
"per_device_train_batch_size"
|
self.cfg.micro_batch_size
|
||||||
] = self.cfg.micro_batch_size
|
)
|
||||||
if self.cfg.eval_batch_size:
|
if self.cfg.eval_batch_size:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["per_device_eval_batch_size"] = (
|
||||||
"per_device_eval_batch_size"
|
self.cfg.eval_batch_size
|
||||||
] = self.cfg.eval_batch_size
|
)
|
||||||
if self.cfg.auto_find_batch_size is not None:
|
if self.cfg.auto_find_batch_size is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["auto_find_batch_size"] = (
|
||||||
"auto_find_batch_size"
|
self.cfg.auto_find_batch_size
|
||||||
] = self.cfg.auto_find_batch_size
|
)
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["gradient_accumulation_steps"] = (
|
||||||
"gradient_accumulation_steps"
|
self.cfg.gradient_accumulation_steps
|
||||||
] = self.cfg.gradient_accumulation_steps
|
)
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["eval_accumulation_steps"] = (
|
||||||
"eval_accumulation_steps"
|
self.cfg.gradient_accumulation_steps
|
||||||
] = self.cfg.gradient_accumulation_steps
|
)
|
||||||
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||||
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
|
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
|
||||||
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
|
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
|
||||||
@@ -554,9 +554,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
|
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
|
||||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["alternate_lr_scheduler_type"] = (
|
||||||
"alternate_lr_scheduler_type"
|
self.cfg.lr_scheduler
|
||||||
] = self.cfg.lr_scheduler
|
)
|
||||||
else:
|
else:
|
||||||
training_arguments_kwargs["lr_scheduler_type"] = (
|
training_arguments_kwargs["lr_scheduler_type"] = (
|
||||||
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
|
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
|
||||||
@@ -565,9 +565,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
|
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["cosine_constant_lr_ratio"] = (
|
||||||
"cosine_constant_lr_ratio"
|
self.cfg.cosine_constant_lr_ratio
|
||||||
] = self.cfg.cosine_constant_lr_ratio
|
)
|
||||||
training_arguments_kwargs["weight_decay"] = (
|
training_arguments_kwargs["weight_decay"] = (
|
||||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||||
)
|
)
|
||||||
@@ -580,40 +580,40 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.eval_sample_packing
|
self.cfg.eval_sample_packing
|
||||||
)
|
)
|
||||||
if self.cfg.sample_packing_bin_size is not None:
|
if self.cfg.sample_packing_bin_size is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["sample_packing_bin_size"] = (
|
||||||
"sample_packing_bin_size"
|
self.cfg.sample_packing_bin_size
|
||||||
] = self.cfg.sample_packing_bin_size
|
)
|
||||||
if self.cfg.sample_packing_group_size is not None:
|
if self.cfg.sample_packing_group_size is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["sample_packing_group_size"] = (
|
||||||
"sample_packing_group_size"
|
self.cfg.sample_packing_group_size
|
||||||
] = self.cfg.sample_packing_group_size
|
)
|
||||||
if self.cfg.sample_packing_eff_est:
|
if self.cfg.sample_packing_eff_est:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["sample_packing_efficiency"] = (
|
||||||
"sample_packing_efficiency"
|
self.cfg.sample_packing_eff_est
|
||||||
] = self.cfg.sample_packing_eff_est
|
)
|
||||||
|
|
||||||
if self.cfg.relora_steps:
|
if self.cfg.relora_steps:
|
||||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["relora_warmup_steps"] = (
|
||||||
"relora_warmup_steps"
|
self.cfg.relora_warmup_steps
|
||||||
] = self.cfg.relora_warmup_steps
|
)
|
||||||
if self.cfg.relora_anneal_steps:
|
if self.cfg.relora_anneal_steps:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["relora_anneal_steps"] = (
|
||||||
"relora_anneal_steps"
|
self.cfg.relora_anneal_steps
|
||||||
] = self.cfg.relora_anneal_steps
|
)
|
||||||
if self.cfg.relora_prune_ratio:
|
if self.cfg.relora_prune_ratio:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["relora_prune_ratio"] = (
|
||||||
"relora_prune_ratio"
|
self.cfg.relora_prune_ratio
|
||||||
] = self.cfg.relora_prune_ratio
|
)
|
||||||
|
|
||||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||||
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
|
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["lisa_step_interval"] = (
|
||||||
"lisa_step_interval"
|
self.cfg.lisa_step_interval
|
||||||
] = self.cfg.lisa_step_interval
|
)
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["lisa_layers_attribute"] = (
|
||||||
"lisa_layers_attribute"
|
self.cfg.lisa_layers_attribute
|
||||||
] = self.cfg.lisa_layers_attribute
|
)
|
||||||
|
|
||||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||||
training_arguments_kwargs
|
training_arguments_kwargs
|
||||||
@@ -627,9 +627,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.neftune_noise_alpha is not None:
|
if self.cfg.neftune_noise_alpha is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["neftune_noise_alpha"] = (
|
||||||
"neftune_noise_alpha"
|
self.cfg.neftune_noise_alpha
|
||||||
] = self.cfg.neftune_noise_alpha
|
)
|
||||||
|
|
||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
@@ -731,23 +731,23 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
importlib.import_module("torchdistx")
|
importlib.import_module("torchdistx")
|
||||||
|
|
||||||
if self.cfg.optim_target_modules:
|
if self.cfg.optim_target_modules:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["optim_target_modules"] = (
|
||||||
"optim_target_modules"
|
self.cfg.optim_target_modules
|
||||||
] = self.cfg.optim_target_modules
|
)
|
||||||
|
|
||||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||||
|
|
||||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["loraplus_lr_embedding"] = (
|
||||||
"loraplus_lr_embedding"
|
self.cfg.loraplus_lr_embedding
|
||||||
] = self.cfg.loraplus_lr_embedding
|
)
|
||||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||||
|
|
||||||
if self.cfg.accelerator_config:
|
if self.cfg.accelerator_config:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["accelerator_config"] = (
|
||||||
"accelerator_config"
|
self.cfg.accelerator_config
|
||||||
] = self.cfg.accelerator_config
|
)
|
||||||
|
|
||||||
if self.cfg.kd_ce_alpha is not None:
|
if self.cfg.kd_ce_alpha is not None:
|
||||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||||
@@ -756,13 +756,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.kd_temperature is not None:
|
if self.cfg.kd_temperature is not None:
|
||||||
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
||||||
if self.cfg.kd_zscore_base_temp is not None:
|
if self.cfg.kd_zscore_base_temp is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["kd_zscore_base_temp"] = (
|
||||||
"kd_zscore_base_temp"
|
self.cfg.kd_zscore_base_temp
|
||||||
] = self.cfg.kd_zscore_base_temp
|
)
|
||||||
if self.cfg.kd_top_k_before_softmax is not None:
|
if self.cfg.kd_top_k_before_softmax is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["kd_top_k_before_softmax"] = (
|
||||||
"kd_top_k_before_softmax"
|
self.cfg.kd_top_k_before_softmax
|
||||||
] = self.cfg.kd_top_k_before_softmax
|
)
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
@@ -972,32 +972,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||||
)
|
)
|
||||||
if self.cfg.remove_unused_columns is not None:
|
if self.cfg.remove_unused_columns is not None:
|
||||||
training_args_kwargs[
|
training_args_kwargs["remove_unused_columns"] = (
|
||||||
"remove_unused_columns"
|
self.cfg.remove_unused_columns
|
||||||
] = self.cfg.remove_unused_columns
|
)
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["remove_unused_columns"] = False
|
training_args_kwargs["remove_unused_columns"] = False
|
||||||
|
|
||||||
if self.cfg.dataloader_pin_memory is not None:
|
if self.cfg.dataloader_pin_memory is not None:
|
||||||
training_args_kwargs[
|
training_args_kwargs["dataloader_pin_memory"] = (
|
||||||
"dataloader_pin_memory"
|
self.cfg.dataloader_pin_memory
|
||||||
] = self.cfg.dataloader_pin_memory
|
)
|
||||||
if self.cfg.dataloader_num_workers is not None:
|
if self.cfg.dataloader_num_workers is not None:
|
||||||
training_args_kwargs[
|
training_args_kwargs["dataloader_num_workers"] = (
|
||||||
"dataloader_num_workers"
|
self.cfg.dataloader_num_workers
|
||||||
] = self.cfg.dataloader_num_workers
|
)
|
||||||
if self.cfg.dataloader_prefetch_factor is not None:
|
if self.cfg.dataloader_prefetch_factor is not None:
|
||||||
training_args_kwargs[
|
training_args_kwargs["dataloader_prefetch_factor"] = (
|
||||||
"dataloader_prefetch_factor"
|
self.cfg.dataloader_prefetch_factor
|
||||||
] = self.cfg.dataloader_prefetch_factor
|
)
|
||||||
if self.cfg.gradient_checkpointing:
|
if self.cfg.gradient_checkpointing:
|
||||||
training_args_kwargs[
|
training_args_kwargs["gradient_checkpointing"] = (
|
||||||
"gradient_checkpointing"
|
self.cfg.gradient_checkpointing
|
||||||
] = self.cfg.gradient_checkpointing
|
)
|
||||||
if self.cfg.gradient_checkpointing_kwargs is not None:
|
if self.cfg.gradient_checkpointing_kwargs is not None:
|
||||||
training_args_kwargs[
|
training_args_kwargs["gradient_checkpointing_kwargs"] = (
|
||||||
"gradient_checkpointing_kwargs"
|
self.cfg.gradient_checkpointing_kwargs
|
||||||
] = self.cfg.gradient_checkpointing_kwargs
|
)
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["gradient_checkpointing_kwargs"] = {
|
training_args_kwargs["gradient_checkpointing_kwargs"] = {
|
||||||
"use_reentrant": False
|
"use_reentrant": False
|
||||||
@@ -1071,9 +1071,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.dpo_use_weighting is not None:
|
if self.cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||||
if self.cfg.dpo_use_logits_to_keep is not None:
|
if self.cfg.dpo_use_logits_to_keep is not None:
|
||||||
training_args_kwargs[
|
training_args_kwargs["use_logits_to_keep"] = (
|
||||||
"use_logits_to_keep"
|
self.cfg.dpo_use_logits_to_keep
|
||||||
] = self.cfg.dpo_use_logits_to_keep
|
)
|
||||||
|
|
||||||
for blocklist_key in blocklist_args_kwargs:
|
for blocklist_key in blocklist_args_kwargs:
|
||||||
if blocklist_key in training_args_kwargs:
|
if blocklist_key in training_args_kwargs:
|
||||||
@@ -1108,9 +1108,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
dpo_trainer_kwargs["peft_config"] = self.peft_config
|
dpo_trainer_kwargs["peft_config"] = self.peft_config
|
||||||
if self.cfg.precompute_ref_log_probs is not None:
|
if self.cfg.precompute_ref_log_probs is not None:
|
||||||
dpo_trainer_kwargs[
|
dpo_trainer_kwargs["precompute_ref_log_probs"] = (
|
||||||
"precompute_ref_log_probs"
|
self.cfg.precompute_ref_log_probs
|
||||||
] = self.cfg.precompute_ref_log_probs
|
)
|
||||||
if self.cfg.rl == "grpo":
|
if self.cfg.rl == "grpo":
|
||||||
trainer_cls = GRPOStrategy.get_trainer_class()
|
trainer_cls = GRPOStrategy.get_trainer_class()
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
|
|||||||
@@ -462,9 +462,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
"pin_memory": self.args.dataloader_pin_memory,
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
}
|
}
|
||||||
if self.args.dataloader_prefetch_factor:
|
if self.args.dataloader_prefetch_factor:
|
||||||
dataloader_params[
|
dataloader_params["prefetch_factor"] = (
|
||||||
"prefetch_factor"
|
self.args.dataloader_prefetch_factor
|
||||||
] = self.args.dataloader_prefetch_factor
|
)
|
||||||
|
|
||||||
sampler = self._get_train_sampler()
|
sampler = self._get_train_sampler()
|
||||||
if isinstance(sampler, BatchSampler):
|
if isinstance(sampler, BatchSampler):
|
||||||
@@ -509,9 +509,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
"pin_memory": self.args.dataloader_pin_memory,
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
}
|
}
|
||||||
if self.args.dataloader_prefetch_factor:
|
if self.args.dataloader_prefetch_factor:
|
||||||
dataloader_params[
|
dataloader_params["prefetch_factor"] = (
|
||||||
"prefetch_factor"
|
self.args.dataloader_prefetch_factor
|
||||||
] = self.args.dataloader_prefetch_factor
|
)
|
||||||
|
|
||||||
if isinstance(eval_sampler, BatchSampler):
|
if isinstance(eval_sampler, BatchSampler):
|
||||||
dataloader_params["batch_sampler"] = eval_sampler
|
dataloader_params["batch_sampler"] = eval_sampler
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
DPO Specific Strategy for training
|
DPO Specific Strategy for training
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Axolotl specific DPO args
|
Axolotl specific DPO args
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from trl import DPOConfig
|
from trl import DPOConfig
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
DPO trainer for axolotl
|
DPO trainer for axolotl
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, Union
|
from typing import Any, Dict, Union
|
||||||
|
|||||||
@@ -45,9 +45,9 @@ class GRPOStrategy:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if trl.vllm_gpu_memory_utilization:
|
if trl.vllm_gpu_memory_utilization:
|
||||||
grpo_args_kwargs[
|
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||||
"vllm_gpu_memory_utilization"
|
trl.vllm_gpu_memory_utilization
|
||||||
] = trl.vllm_gpu_memory_utilization
|
)
|
||||||
|
|
||||||
if trl.vllm_max_model_len:
|
if trl.vllm_max_model_len:
|
||||||
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
|
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
|
||||||
@@ -86,9 +86,9 @@ class GRPOStrategy:
|
|||||||
def set_trainer_kwargs(cls, cfg):
|
def set_trainer_kwargs(cls, cfg):
|
||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
if cfg.trl and cfg.trl.reward_processing_classes:
|
if cfg.trl and cfg.trl.reward_processing_classes:
|
||||||
trainer_kwargs[
|
trainer_kwargs["reward_processing_classes"] = (
|
||||||
"reward_processing_classes"
|
cfg.trl.reward_processing_classes
|
||||||
] = cfg.trl.reward_processing_classes
|
)
|
||||||
return trainer_kwargs
|
return trainer_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Axolotl Specific Training Args
|
Axolotl Specific Training Args
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from trl import GRPOConfig
|
from trl import GRPOConfig
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Axolotl GRPO trainer
|
Axolotl GRPO trainer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from accelerate.utils import is_peft_model
|
from accelerate.utils import is_peft_model
|
||||||
from accelerate.utils.other import is_compiled_module
|
from accelerate.utils.other import is_compiled_module
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
module for TRL PPO training
|
module for TRL PPO training
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from trl import PPOTrainer
|
from trl import PPOTrainer
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
extra axolotl specific training args
|
extra axolotl specific training args
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Grokfast plugin for Axolotl
|
Grokfast plugin for Axolotl
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from transformers.trainer_callback import TrainerCallback
|
from transformers.trainer_callback import TrainerCallback
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
config args for grokfast plugin
|
config args for grokfast plugin
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|||||||
@@ -26,12 +26,12 @@ class KDArgs(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
kd_trainer: Optional[bool] = None # whether to use KD trainer
|
kd_trainer: Optional[bool] = None # whether to use KD trainer
|
||||||
kd_ce_alpha: Optional[
|
kd_ce_alpha: Optional[float] = (
|
||||||
float
|
None # loss coefficient for cross-entropy loss during KD
|
||||||
] = None # loss coefficient for cross-entropy loss during KD
|
)
|
||||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
||||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
||||||
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
||||||
kd_top_k_before_softmax: Optional[
|
kd_top_k_before_softmax: Optional[bool] = (
|
||||||
bool
|
None # whether to sample top k before softmax during KD
|
||||||
] = None # whether to sample top k before softmax during KD
|
)
|
||||||
|
|||||||
@@ -55,9 +55,9 @@ class LigerPlugin(BasePlugin):
|
|||||||
if "cross_entropy" in liger_fn_sig.parameters:
|
if "cross_entropy" in liger_fn_sig.parameters:
|
||||||
kwargs["cross_entropy"] = cfg.liger_cross_entropy
|
kwargs["cross_entropy"] = cfg.liger_cross_entropy
|
||||||
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
|
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
|
||||||
kwargs[
|
kwargs["fused_linear_cross_entropy"] = (
|
||||||
"fused_linear_cross_entropy"
|
cfg.liger_fused_linear_cross_entropy
|
||||||
] = cfg.liger_fused_linear_cross_entropy
|
)
|
||||||
if "rms_norm" in liger_fn_sig.parameters:
|
if "rms_norm" in liger_fn_sig.parameters:
|
||||||
kwargs["rms_norm"] = cfg.liger_rms_norm
|
kwargs["rms_norm"] = cfg.liger_rms_norm
|
||||||
if "layer_norm" in liger_fn_sig.parameters:
|
if "layer_norm" in liger_fn_sig.parameters:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
|
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Jamba model with LigerFusedLinearCrossEntropyLoss
|
Jamba model with LigerFusedLinearCrossEntropyLoss
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Module for the Plugin for LM Eval Harness
|
Module for the Plugin for LM Eval Harness
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import subprocess # nosec
|
import subprocess # nosec
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Module for handling lm eval harness input arguments.
|
Module for handling lm eval harness input arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
axolotl CLI for running lm_eval tasks
|
axolotl CLI for running lm_eval tasks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import subprocess # nosec
|
import subprocess # nosec
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
|
|||||||
|
|
||||||
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
|
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
|
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ See "LoRA: Low-Rank Adaptation of Large Language Models"
|
|||||||
|
|
||||||
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
|
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
|
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Dequantization utilities for `bitsandbytes` integration."""
|
"""Dequantization utilities for `bitsandbytes` integration."""
|
||||||
|
|
||||||
# pylint: disable=invalid-name,global-statement
|
# pylint: disable=invalid-name,global-statement
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
|
|||||||
|
|
||||||
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
|
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
HF Transformers MambaConfig
|
HF Transformers MambaConfig
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Monkeypatch for Vision Llama for FA2 support
|
Monkeypatch for Vision Llama for FA2 support
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
@@ -220,10 +221,10 @@ def patch_mllama():
|
|||||||
True
|
True
|
||||||
)
|
)
|
||||||
MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
|
MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
|
||||||
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
|
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES["flash_attention_2"] = (
|
||||||
"flash_attention_2"
|
MllamaTextCrossFlashAttention2
|
||||||
] = MllamaTextCrossFlashAttention2
|
)
|
||||||
# fallback to SDPA
|
# fallback to SDPA
|
||||||
MLLAMA_VISION_ATTENTION_CLASSES[
|
MLLAMA_VISION_ATTENTION_CLASSES["flash_attention_2"] = (
|
||||||
"flash_attention_2"
|
MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
|
||||||
] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
|
)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
|
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -12,7 +12,9 @@ import transformers
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaAttention,
|
||||||
|
)
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
||||||
)
|
)
|
||||||
@@ -490,9 +492,11 @@ def flashattn_forward(
|
|||||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||||
# the attention_mask should be the same as the key_padding_mask
|
# the attention_mask should be the same as the key_padding_mask
|
||||||
key_padding_mask=attention_mask,
|
key_padding_mask=attention_mask,
|
||||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
query_padding_mask=(
|
||||||
if attention_mask is not None
|
attention_mask[:, -query_states.size(1) :]
|
||||||
else None,
|
if attention_mask is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||||
qkv_unpad,
|
qkv_unpad,
|
||||||
@@ -531,9 +535,11 @@ def flashattn_forward(
|
|||||||
value_states,
|
value_states,
|
||||||
kvpacked=True,
|
kvpacked=True,
|
||||||
key_padding_mask=attention_mask,
|
key_padding_mask=attention_mask,
|
||||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
query_padding_mask=(
|
||||||
if attention_mask is not None
|
attention_mask[:, -query_states.size(1) :]
|
||||||
else None,
|
if attention_mask is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if q_unpad.dtype != kv_unpad.dtype:
|
if q_unpad.dtype != kv_unpad.dtype:
|
||||||
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
|
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Flash attention monkey patch for mistral model"""
|
"""Flash attention monkey patch for mistral model"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -21,7 +22,10 @@ from transformers.models.mistral.modeling_mistral import (
|
|||||||
from transformers.models.mistral.modeling_mistral import (
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||||
)
|
)
|
||||||
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
repeat_kv,
|
||||||
|
)
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
@@ -243,9 +247,11 @@ def flashattn_forward(
|
|||||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||||
# the attention_mask should be the same as the key_padding_mask
|
# the attention_mask should be the same as the key_padding_mask
|
||||||
key_padding_mask=attention_mask,
|
key_padding_mask=attention_mask,
|
||||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
query_padding_mask=(
|
||||||
if attention_mask is not None
|
attention_mask[:, -query_states.size(1) :]
|
||||||
else None,
|
if attention_mask is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||||
qkv_unpad,
|
qkv_unpad,
|
||||||
@@ -286,9 +292,11 @@ def flashattn_forward(
|
|||||||
value_states,
|
value_states,
|
||||||
kvpacked=True,
|
kvpacked=True,
|
||||||
key_padding_mask=attention_mask,
|
key_padding_mask=attention_mask,
|
||||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
query_padding_mask=(
|
||||||
if attention_mask is not None
|
attention_mask[:, -query_states.size(1) :]
|
||||||
else None,
|
if attention_mask is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if q_unpad.dtype != kv_unpad.dtype:
|
if q_unpad.dtype != kv_unpad.dtype:
|
||||||
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Patches to support multipack for mixtral
|
Patches to support multipack for mixtral
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
|
"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
|
||||||
|
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -411,7 +412,10 @@ def merge_and_save(
|
|||||||
if shard_path.endswith(".safetensors"):
|
if shard_path.endswith(".safetensors"):
|
||||||
in_tensors = st.load_file(str(Path(model_src) / shard_path))
|
in_tensors = st.load_file(str(Path(model_src) / shard_path))
|
||||||
else:
|
else:
|
||||||
in_tensors = torch.load(Path(model_src) / shard_path)
|
in_tensors = torch.load(
|
||||||
|
Path(model_src) / shard_path,
|
||||||
|
weights_only=True, # to prevent arbitrary code execution
|
||||||
|
)
|
||||||
if "state_dict" in in_tensors:
|
if "state_dict" in in_tensors:
|
||||||
in_tensors = in_tensors["state_dict"]
|
in_tensors = in_tensors["state_dict"]
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
""" PyTorch StableLM Epoch model. """
|
"""PyTorch StableLM Epoch model."""
|
||||||
import importlib
|
import importlib
|
||||||
import math
|
import math
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
fix for FSDP optimizer save in trainer w 4.47.0
|
fix for FSDP optimizer save in trainer w 4.47.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Shared utils for the monkeypatches
|
Shared utils for the monkeypatches
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Fused MLP layer for incrementally improved training efficiency
|
Fused MLP layer for incrementally improved training efficiency
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers.models.llama.modeling_llama import LlamaMLP
|
from transformers.models.llama.modeling_llama import LlamaMLP
|
||||||
from xformers.ops import SwiGLU
|
from xformers.ops import SwiGLU
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Prompt strategies loader for alpaca instruction datasets with system prompts
|
Prompt strategies loader for alpaca instruction datasets with system prompts
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Generator, Tuple, Union
|
from typing import Generator, Tuple, Union
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Basic completion text
|
Basic completion text
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, Generator, Optional, Tuple
|
from typing import Any, Dict, Generator, Optional, Tuple
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module containing the classes for Context QA Prompt Tokenization Strategies"""
|
"""Module containing the classes for Context QA Prompt Tokenization Strategies"""
|
||||||
|
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
module for DPO style dataset transform strategies
|
module for DPO style dataset transform strategies
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from ..base import load as load_base
|
from ..base import load as load_base
|
||||||
|
|||||||
@@ -33,9 +33,9 @@ def default(
|
|||||||
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
)
|
||||||
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
|
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
|
||||||
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
|
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
|
||||||
return sample
|
return sample
|
||||||
@@ -52,9 +52,9 @@ def argilla_chat(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def transform_fn(sample):
|
def transform_fn(sample):
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
)
|
||||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
|
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
|
||||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
|
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
|
||||||
return sample
|
return sample
|
||||||
@@ -78,9 +78,9 @@ def icr(
|
|||||||
f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
|
f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
] = f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
|
)
|
||||||
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
|
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
|
||||||
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
|
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
|
||||||
return sample
|
return sample
|
||||||
@@ -100,9 +100,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
|
|||||||
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
)
|
||||||
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
|
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
|
||||||
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
|
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
|
||||||
return sample
|
return sample
|
||||||
@@ -120,9 +120,9 @@ def prompt_pairs(
|
|||||||
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
)
|
||||||
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
|
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
|
||||||
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
|
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
|
||||||
return sample
|
return sample
|
||||||
@@ -142,9 +142,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
|
|||||||
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
)
|
||||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
|
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
|
||||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
|
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
|
||||||
return sample
|
return sample
|
||||||
|
|||||||
@@ -34,9 +34,9 @@ def default(
|
|||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
)
|
||||||
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
|
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
|
||||||
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
|
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
|
||||||
return sample
|
return sample
|
||||||
@@ -53,9 +53,9 @@ def argilla_chat(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def transform_fn(sample):
|
def transform_fn(sample):
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
)
|
||||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
|
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
|
||||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
|
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
|
||||||
return sample
|
return sample
|
||||||
@@ -79,9 +79,9 @@ def icr(
|
|||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
)
|
||||||
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||||
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||||
return sample
|
return sample
|
||||||
@@ -101,9 +101,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
|
|||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
)
|
||||||
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||||
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||||
return sample
|
return sample
|
||||||
@@ -121,9 +121,9 @@ def prompt_pairs(
|
|||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
)
|
||||||
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||||
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||||
return sample
|
return sample
|
||||||
@@ -143,9 +143,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
|
|||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
)
|
||||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
|
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
|
||||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
|
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
|
||||||
return sample
|
return sample
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module for plain input/output prompt pairs"""
|
"""Module for plain input/output prompt pairs"""
|
||||||
|
|
||||||
from typing import Generator, Tuple
|
from typing import Generator, Tuple
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module for inspect jinja templates for the variables they use"""
|
"""Module for inspect jinja templates for the variables they use"""
|
||||||
|
|
||||||
from typing import Dict, Optional, Set, TypedDict, Union
|
from typing import Dict, Optional, Set, TypedDict, Union
|
||||||
|
|
||||||
from jinja2 import Environment, meta, nodes
|
from jinja2 import Environment, meta, nodes
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
KTO strategies for chatml
|
KTO strategies for chatml
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
|
||||||
@@ -15,9 +16,9 @@ def argilla(
|
|||||||
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
)
|
||||||
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
@@ -33,9 +34,9 @@ def argilla_chat(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def transform_fn(sample):
|
def transform_fn(sample):
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
)
|
||||||
sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
|
sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
@@ -55,9 +56,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
|
|||||||
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
)
|
||||||
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
@@ -74,9 +75,9 @@ def prompt_pairs(
|
|||||||
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
)
|
||||||
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
@@ -96,9 +97,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
|
|||||||
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
)
|
||||||
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
KTO strategies for llama-3 chat template
|
KTO strategies for llama-3 chat template
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
|
||||||
@@ -15,9 +16,9 @@ def argilla(
|
|||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
)
|
||||||
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
@@ -33,9 +34,9 @@ def argilla_chat(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def transform_fn(sample):
|
def transform_fn(sample):
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
)
|
||||||
sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>"
|
sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>"
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
@@ -55,9 +56,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
|
|||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
)
|
||||||
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
@@ -74,9 +75,9 @@ def prompt_pairs(
|
|||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
)
|
||||||
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
@@ -96,9 +97,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
|
|||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample["prompt"] = (
|
||||||
"prompt"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
)
|
||||||
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
User-defined KTO strategies
|
User-defined KTO strategies
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Chat dataset wrapping strategy for new internal messages representations
|
Chat dataset wrapping strategy for new internal messages representations
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
|
||||||
from axolotl.core.datasets.chat import TokenizedChatDataset
|
from axolotl.core.datasets.chat import TokenizedChatDataset
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ this one specifies the system prompt with "### System:".
|
|||||||
|
|
||||||
Not suited/tested for multiple-turn conversations without further adjustments.
|
Not suited/tested for multiple-turn conversations without further adjustments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Generator, Union
|
from typing import Generator, Union
|
||||||
|
|
||||||
from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy
|
from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""chatml prompt tokenization strategy for ORPO"""
|
"""chatml prompt tokenization strategy for ORPO"""
|
||||||
|
|
||||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""pretraining prompt strategies"""
|
"""pretraining prompt strategies"""
|
||||||
|
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
|
||||||
from transformers import BatchEncoding
|
from transformers import BatchEncoding
|
||||||
|
|||||||
@@ -406,9 +406,7 @@ def handle_untrained_tokens_fix(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def setup_model_and_trainer(
|
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
||||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
|
||||||
) -> tuple[
|
|
||||||
HFRLTrainerBuilder | HFCausalTrainerBuilder,
|
HFRLTrainerBuilder | HFCausalTrainerBuilder,
|
||||||
PeftModel | PreTrainedModel,
|
PeftModel | PreTrainedModel,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
|
|||||||
@@ -40,6 +40,6 @@ def set_pytorch_cuda_alloc_conf():
|
|||||||
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
|
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
|
||||||
if torch_major == 2 and torch_minor >= 2:
|
if torch_major == 2 and torch_minor >= 2:
|
||||||
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
|
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
|
||||||
os.environ[
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
|
||||||
"PYTORCH_CUDA_ALLOC_CONF"
|
"expandable_segments:True,roundup_power2_divisions:16"
|
||||||
] = "expandable_segments:True,roundup_power2_divisions:16"
|
)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Benchmarking and measurement utilities"""
|
"""Benchmarking and measurement utilities"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -343,9 +343,9 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
bench_refs.extend(combined_bench_names[bench_name]["refs"])
|
bench_refs.extend(combined_bench_names[bench_name]["refs"])
|
||||||
bench_preds.extend(combined_bench_names[bench_name]["preds"])
|
bench_preds.extend(combined_bench_names[bench_name]["preds"])
|
||||||
if not pd.isna(bench_score):
|
if not pd.isna(bench_score):
|
||||||
results[
|
results[f"{bench_split}_bench_accuracy_{bench_name}"] = (
|
||||||
f"{bench_split}_bench_accuracy_{bench_name}"
|
bench_score
|
||||||
] = bench_score
|
)
|
||||||
bench_scores.append(bench_score)
|
bench_scores.append(bench_score)
|
||||||
else:
|
else:
|
||||||
results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0
|
results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""MLFlow module for trainer callbacks"""
|
"""MLFlow module for trainer callbacks"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""callback to calculate perplexity as an evaluation metric."""
|
"""callback to calculate perplexity as an evaluation metric."""
|
||||||
|
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
HF Trainer callback for creating pytorch profiling snapshots
|
HF Trainer callback for creating pytorch profiling snapshots
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pickle import dump # nosec B403
|
from pickle import dump # nosec B403
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
This module provides functionality for selecting chat templates based on user choices.
|
This module provides functionality for selecting chat templates based on user choices.
|
||||||
These templates are used for formatting messages in a conversation.
|
These templates are used for formatting messages in a conversation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
shared axolotl collators for multipack, mamba, multimodal
|
shared axolotl collators for multipack, mamba, multimodal
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .batching import ( # noqa: F401
|
from .batching import ( # noqa: F401
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
basic shared collator constants
|
basic shared collator constants
|
||||||
"""
|
"""
|
||||||
|
|
||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
collators for Mamba
|
collators for Mamba
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Sequence
|
from typing import Dict, Sequence
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,11 @@ from axolotl.utils.config.models.input.v0_4_1 import (
|
|||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import DPODataset, KTODataset, SFTDataset
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
|
DPODataset,
|
||||||
|
KTODataset,
|
||||||
|
SFTDataset,
|
||||||
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model_config
|
from axolotl.utils.models import load_model_config
|
||||||
|
|
||||||
|
|||||||
@@ -200,12 +200,12 @@ class SFTDataset(BaseModel):
|
|||||||
field_human: Optional[str] = None
|
field_human: Optional[str] = None
|
||||||
field_model: Optional[str] = None
|
field_model: Optional[str] = None
|
||||||
field_messages: Optional[str] = None
|
field_messages: Optional[str] = None
|
||||||
message_field_role: Optional[
|
message_field_role: Optional[str] = (
|
||||||
str
|
None # deprecated, use message_property_mappings
|
||||||
] = None # deprecated, use message_property_mappings
|
)
|
||||||
message_field_content: Optional[
|
message_field_content: Optional[str] = (
|
||||||
str
|
None # deprecated, use message_property_mappings
|
||||||
] = None # deprecated, use message_property_mappings
|
)
|
||||||
message_property_mappings: Optional[Dict[str, str]] = None
|
message_property_mappings: Optional[Dict[str, str]] = None
|
||||||
message_field_training: Optional[str] = None
|
message_field_training: Optional[str] = None
|
||||||
message_field_training_detail: Optional[str] = None
|
message_field_training_detail: Optional[str] = None
|
||||||
@@ -505,9 +505,9 @@ class HyperparametersConfig(BaseModel):
|
|||||||
embedding_lr: Optional[float] = None
|
embedding_lr: Optional[float] = None
|
||||||
embedding_lr_scale: Optional[float] = None
|
embedding_lr_scale: Optional[float] = None
|
||||||
weight_decay: Optional[float] = 0.0
|
weight_decay: Optional[float] = 0.0
|
||||||
optimizer: Optional[
|
optimizer: Optional[Union[OptimizerNames, CustomSupportedOptimizers]] = (
|
||||||
Union[OptimizerNames, CustomSupportedOptimizers]
|
OptimizerNames.ADAMW_TORCH_FUSED
|
||||||
] = OptimizerNames.ADAMW_TORCH_FUSED
|
)
|
||||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
||||||
@@ -699,9 +699,9 @@ class AxolotlInputConfig(
|
|||||||
reward_model: Optional[bool] = None
|
reward_model: Optional[bool] = None
|
||||||
process_reward_model: Optional[bool] = None
|
process_reward_model: Optional[bool] = None
|
||||||
num_labels: Optional[int] = None
|
num_labels: Optional[int] = None
|
||||||
dpo_use_weighting: Optional[
|
dpo_use_weighting: Optional[bool] = (
|
||||||
bool
|
None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
)
|
||||||
dpo_use_logits_to_keep: Optional[bool] = None
|
dpo_use_logits_to_keep: Optional[bool] = None
|
||||||
|
|
||||||
datasets: Optional[
|
datasets: Optional[
|
||||||
@@ -780,9 +780,9 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
# torch_dtype: Optional[torch.dtype]
|
# torch_dtype: Optional[torch.dtype]
|
||||||
|
|
||||||
gradient_checkpointing: Optional[
|
gradient_checkpointing: Optional[Union[Literal["unsloth", "offload"], bool]] = (
|
||||||
Union[Literal["unsloth", "offload"], bool]
|
Field(default=False)
|
||||||
] = Field(default=False)
|
)
|
||||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
unfrozen_parameters: Optional[List[str]] = None
|
unfrozen_parameters: Optional[List[str]] = None
|
||||||
@@ -894,9 +894,9 @@ class AxolotlInputConfig(
|
|||||||
kto_undesirable_weight: Optional[float] = None
|
kto_undesirable_weight: Optional[float] = None
|
||||||
rl_beta: Optional[float] = None
|
rl_beta: Optional[float] = None
|
||||||
|
|
||||||
max_memory: Optional[
|
max_memory: Optional[Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]] = (
|
||||||
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
None
|
||||||
] = None
|
)
|
||||||
gpu_memory_limit: Optional[Union[int, str]] = None
|
gpu_memory_limit: Optional[Union[int, str]] = None
|
||||||
low_cpu_mem_usage: Optional[bool] = None
|
low_cpu_mem_usage: Optional[bool] = None
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""module for gpu capabilities"""
|
"""module for gpu capabilities"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Data processing modules
|
Data processing modules
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from axolotl.utils.data.pretraining import ( # noqa: F401
|
from axolotl.utils.data.pretraining import ( # noqa: F401
|
||||||
encode_pretraining,
|
encode_pretraining,
|
||||||
wrap_pretraining_dataset,
|
wrap_pretraining_dataset,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
utility helpers for distributed checks
|
utility helpers for distributed checks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import pickle # nosec
|
import pickle # nosec
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
utils to get GPU info for the current environment
|
utils to get GPU info for the current environment
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from accelerate.utils.environment import (
|
from accelerate.utils.environment import (
|
||||||
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
|
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
|
||||||
)
|
)
|
||||||
from accelerate.utils.environment import get_gpu_info
|
from accelerate.utils.environment import (
|
||||||
|
get_gpu_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_cuda_p2p_ib_support():
|
def check_cuda_p2p_ib_support():
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
module to freeze/unfreeze parameters by name
|
module to freeze/unfreeze parameters by name
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Callable, List, Tuple, Union
|
from typing import Callable, List, Tuple, Union
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""custom checkpointing utils"""
|
"""custom checkpointing utils"""
|
||||||
|
|
||||||
from axolotl.utils.gradient_checkpointing.unsloth import (
|
from axolotl.utils.gradient_checkpointing.unsloth import (
|
||||||
Unsloth_Offloaded_Gradient_Checkpointer,
|
Unsloth_Offloaded_Gradient_Checkpointer,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
module to handle loading model on cpu/meta device for FSDP
|
module to handle loading model on cpu/meta device for FSDP
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional, Type, Union
|
from typing import List, Optional, Type, Union
|
||||||
@@ -45,13 +46,13 @@ def _replace_linear(
|
|||||||
|
|
||||||
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
|
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
|
||||||
if issubclass(linear_replacement, Linear4bit):
|
if issubclass(linear_replacement, Linear4bit):
|
||||||
model._modules[ # pylint: disable=protected-access
|
model._modules[name] = ( # pylint: disable=protected-access
|
||||||
name
|
linear_replacement(
|
||||||
] = linear_replacement(
|
module.in_features,
|
||||||
module.in_features,
|
module.out_features,
|
||||||
module.out_features,
|
module.bias is not None,
|
||||||
module.bias is not None,
|
**kwargs,
|
||||||
**kwargs,
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -741,9 +741,9 @@ class ModelLoader:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self.cfg.gptq_disable_exllama is not None:
|
if self.cfg.gptq_disable_exllama is not None:
|
||||||
self.model_config.quantization_config[
|
self.model_config.quantization_config["disable_exllama"] = (
|
||||||
"disable_exllama"
|
self.cfg.gptq_disable_exllama
|
||||||
] = self.cfg.gptq_disable_exllama
|
)
|
||||||
self.model_kwargs["quantization_config"] = GPTQConfig(
|
self.model_kwargs["quantization_config"] = GPTQConfig(
|
||||||
**self.model_config.quantization_config
|
**self.model_config.quantization_config
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ Copied from https://github.com/iShohei220/adopt
|
|||||||
ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024)
|
ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024)
|
||||||
Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka
|
Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# mypy: ignore-errors
|
# mypy: ignore-errors
|
||||||
# pylint: skip-file
|
# pylint: skip-file
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
axolotl samplers module
|
axolotl samplers module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .multipack import MultipackBatchSampler # noqa: F401
|
from .multipack import MultipackBatchSampler # noqa: F401
|
||||||
from .utils import get_dataset_lengths # noqa: F401
|
from .utils import get_dataset_lengths # noqa: F401
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
helper util to calculate dataset lengths
|
helper util to calculate dataset lengths
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module for custom LRScheduler class"""
|
"""Module for custom LRScheduler class"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
|||||||
@@ -538,9 +538,9 @@ def setup_fsdp_envs(cfg):
|
|||||||
if cfg.fsdp_config.fsdp_auto_wrap_policy:
|
if cfg.fsdp_config.fsdp_auto_wrap_policy:
|
||||||
os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.fsdp_auto_wrap_policy
|
os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.fsdp_auto_wrap_policy
|
||||||
if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:
|
if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:
|
||||||
os.environ[
|
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = (
|
||||||
"FSDP_TRANSFORMER_CLS_TO_WRAP"
|
cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
||||||
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
)
|
||||||
|
|
||||||
|
|
||||||
def prepare_optim_env(cfg):
|
def prepare_optim_env(cfg):
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
dynamic requirements for axolotl
|
dynamic requirements for axolotl
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import platform
|
import platform
|
||||||
import re
|
import re
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
|
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
unit tests for generating sweep configurations
|
unit tests for generating sweep configurations
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from axolotl.cli.main import generate_sweep_configs
|
from axolotl.cli.main import generate_sweep_configs
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""pytest tests for axolotl CLI utils."""
|
"""pytest tests for axolotl CLI utils."""
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
shared pytest fixtures
|
shared pytest fixtures
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import importlib
|
import importlib
|
||||||
import shutil
|
import shutil
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Tests for the chat messages module
|
Tests for the chat messages module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user