Compare commits
2 Commits
rl-trainer
...
feat/wizar
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
22684ec98f | ||
|
|
6db60ac520 |
10
.github/workflows/main.yml
vendored
10
.github/workflows/main.yml
vendored
@@ -31,11 +31,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.0
|
pytorch: 2.7.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 128
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -99,11 +94,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.0
|
pytorch: 2.7.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 128
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
9
.github/workflows/tests.yml
vendored
9
.github/workflows/tests.yml
vendored
@@ -295,7 +295,6 @@ jobs:
|
|||||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||||
|
|
||||||
docker-e2e-tests-1st:
|
docker-e2e-tests-1st:
|
||||||
# Run this job first as a gate for running the remainder of the test matrix
|
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
@@ -342,8 +341,6 @@ jobs:
|
|||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 90
|
timeout-minutes: 90
|
||||||
# Only run the remainder of the matrix if the first e2e check passed;
|
|
||||||
# this is to save on wasted compute costs for known failures that get caught in the first run
|
|
||||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
@@ -368,12 +365,6 @@ jobs:
|
|||||||
pytorch: 2.7.0
|
pytorch: 2.7.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 128
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.0
|
|
||||||
num_gpus: 1
|
|
||||||
axolotl_extras:
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -139,8 +139,7 @@ quartodoc:
|
|||||||
- utils.optimizers.adopt
|
- utils.optimizers.adopt
|
||||||
- utils.data.pretraining
|
- utils.data.pretraining
|
||||||
- utils.data.sft
|
- utils.data.sft
|
||||||
- utils.gradient_checkpointing.offload_cpu
|
- utils.gradient_checkpointing.unsloth
|
||||||
- utils.gradient_checkpointing.offload_disk
|
|
||||||
- title: Schemas
|
- title: Schemas
|
||||||
desc: Pydantic data models for Axolotl config
|
desc: Pydantic data models for Axolotl config
|
||||||
contents:
|
contents:
|
||||||
|
|||||||
@@ -539,7 +539,7 @@ train_on_inputs: false
|
|||||||
# Note that training loss may have an oscillating pattern with this enabled.
|
# Note that training loss may have an oscillating pattern with this enabled.
|
||||||
group_by_length: false
|
group_by_length: false
|
||||||
|
|
||||||
# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
|
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
|
||||||
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||||
gradient_checkpointing: false
|
gradient_checkpointing: false
|
||||||
# additional kwargs to pass to the trainer for gradient checkpointing
|
# additional kwargs to pass to the trainer for gradient checkpointing
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ the `alpaca` dataset format, which has the following format:
|
|||||||
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
|
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
|
||||||
format them.
|
format them.
|
||||||
|
|
||||||
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca`
|
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
|
||||||
format):
|
format):
|
||||||
|
|
||||||
```json
|
```json
|
||||||
@@ -120,12 +120,6 @@ axolotl train my_training.yml
|
|||||||
|
|
||||||
## Common Tasks {#sec-common-tasks}
|
## Common Tasks {#sec-common-tasks}
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
|
|
||||||
The same yaml file is used for training, inference, and merging.
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
### Testing Your Model {#sec-testing}
|
### Testing Your Model {#sec-testing}
|
||||||
|
|
||||||
After training, test your model:
|
After training, test your model:
|
||||||
@@ -134,16 +128,6 @@ After training, test your model:
|
|||||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
|
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||||
```
|
```
|
||||||
|
|
||||||
More details can be found in [Inference](inference.qmd).
|
|
||||||
|
|
||||||
### Using a UI {#sec-ui}
|
|
||||||
|
|
||||||
Launch a Gradio interface:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
|
|
||||||
```
|
|
||||||
|
|
||||||
### Preprocessing Data {#sec-preprocessing}
|
### Preprocessing Data {#sec-preprocessing}
|
||||||
|
|
||||||
For large datasets, preprocess first:
|
For large datasets, preprocess first:
|
||||||
@@ -152,22 +136,14 @@ For large datasets, preprocess first:
|
|||||||
axolotl preprocess my_training.yml
|
axolotl preprocess my_training.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
Please make sure to set `dataset_prepared_path: ` in your config to set the path to save the prepared dataset.
|
### Using a UI {#sec-ui}
|
||||||
|
|
||||||
More details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd).
|
Launch a Gradio interface:
|
||||||
|
|
||||||
### Merging LoRA weights {#sec-merging-lora}
|
|
||||||
|
|
||||||
To merge the LoRA weights back into the base model, run:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
axolotl merge-lora my_training.yml --lora-model-dir="./outputs/lora-out"
|
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
|
||||||
```
|
```
|
||||||
|
|
||||||
The merged model will be saved in the `{output_dir}/merged` directory.
|
|
||||||
|
|
||||||
More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).
|
|
||||||
|
|
||||||
## Next Steps {#sec-next-steps}
|
## Next Steps {#sec-next-steps}
|
||||||
|
|
||||||
Now that you have the basics, you might want to:
|
Now that you have the basics, you might want to:
|
||||||
@@ -180,7 +156,6 @@ Now that you have the basics, you might want to:
|
|||||||
Check our other guides for details on these topics:
|
Check our other guides for details on these topics:
|
||||||
|
|
||||||
- [Configuration Guide](config.qmd) - Full configuration options
|
- [Configuration Guide](config.qmd) - Full configuration options
|
||||||
- [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources
|
|
||||||
- [Dataset Formats](dataset-formats) - Working with different data formats
|
- [Dataset Formats](dataset-formats) - Working with different data formats
|
||||||
- [Multi-GPU Training](multi-gpu.qmd)
|
- [Multi-GPU Training](multi-gpu.qmd)
|
||||||
- [Multi-Node Training](multi-node.qmd)
|
- [Multi-Node Training](multi-node.qmd)
|
||||||
|
|||||||
@@ -342,6 +342,13 @@ def delinearize_llama4(model: str, output: str) -> None:
|
|||||||
do_delinearize_llama4(model, output)
|
do_delinearize_llama4(model, output)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
def wizard():
|
||||||
|
from axolotl.cli.wizard import do_wizard
|
||||||
|
|
||||||
|
do_wizard()
|
||||||
|
|
||||||
|
|
||||||
cli.add_command(lm_eval)
|
cli.add_command(lm_eval)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
429
src/axolotl/cli/wizard.py
Normal file
429
src/axolotl/cli/wizard.py
Normal file
@@ -0,0 +1,429 @@
|
|||||||
|
"""Wizard for creating yaml configs."""
|
||||||
|
|
||||||
|
import click
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
from packaging import version
|
||||||
|
from transformers.training_args import OptimizerNames
|
||||||
|
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.models import load_model_config
|
||||||
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
|
|
||||||
|
def do_wizard():
|
||||||
|
print_axolotl_text_art()
|
||||||
|
|
||||||
|
# Ask where to save the config
|
||||||
|
cfg = DictDefault({})
|
||||||
|
config_path = click.prompt(
|
||||||
|
"Where do you want to save the config?", type=str, default="config.yaml"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ask base model
|
||||||
|
base_model = click.prompt("What base model do you want to use?", type=str)
|
||||||
|
cfg["base_model"] = base_model.strip()
|
||||||
|
|
||||||
|
# Ask whether want to enable Vision model
|
||||||
|
# TODO: check if model has vision layers instead of asking user
|
||||||
|
train_vision_model = click.confirm(
|
||||||
|
"If this model has vision layers, do you want to train them?", default=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if train_vision_model:
|
||||||
|
cfg["processor_type"] = "AutoProcessor"
|
||||||
|
cfg["skip_prepare_dataset"] = True
|
||||||
|
cfg["remove_unused_columns"] = False
|
||||||
|
cfg["sample_packing"] = False
|
||||||
|
|
||||||
|
# Ask whether they want to set any advanced model features (custom tokenizer, custom config, etc)
|
||||||
|
advanced_model_features = click.confirm(
|
||||||
|
"Do you want to set any advanced model features? (custom tokenizer, custom config, remote code etc)",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if advanced_model_features:
|
||||||
|
# Ask whether they want to use a custom config
|
||||||
|
base_model_config = click.prompt(
|
||||||
|
"What model config do you want to use? (leave blank for default)",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
if base_model_config:
|
||||||
|
cfg["base_model_config"] = base_model_config
|
||||||
|
|
||||||
|
# Ask whether they want to use a specific revision of the model
|
||||||
|
revision_of_model = click.prompt(
|
||||||
|
"What revision of the model do you want to use? (leave blank for default)",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
if revision_of_model:
|
||||||
|
cfg["revision_of_model"] = revision_of_model
|
||||||
|
|
||||||
|
# Ask whether they want to use a custom tokenizer
|
||||||
|
tokenizer_config = click.prompt(
|
||||||
|
"What tokenizer do you want to use? (leave blank for default)",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
if tokenizer_config:
|
||||||
|
cfg["tokenizer_config"] = tokenizer_config
|
||||||
|
|
||||||
|
# Ask whether they want to use remote code
|
||||||
|
trust_remote_code = click.confirm(
|
||||||
|
"Do you want to use remote code?", default=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if trust_remote_code:
|
||||||
|
cfg["trust_remote_code"] = trust_remote_code
|
||||||
|
|
||||||
|
# Whether to resize token embeddings
|
||||||
|
resize_token_embeddings_to_32x = click.confirm(
|
||||||
|
"Do you want to resize token embeddings to 32x?", default=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if resize_token_embeddings_to_32x:
|
||||||
|
cfg["resize_token_embeddings_to_32x"] = resize_token_embeddings_to_32x
|
||||||
|
|
||||||
|
# Whether to shrink embeddings to len(tokenizer)
|
||||||
|
shrink_embeddings = click.confirm(
|
||||||
|
"Do you want to shrink embeddings to len(tokenizer)?", default=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if shrink_embeddings:
|
||||||
|
cfg["shrink_embeddings"] = shrink_embeddings
|
||||||
|
|
||||||
|
# Whether to skip upcast embeddings
|
||||||
|
embeddings_skip_upcast = click.confirm(
|
||||||
|
"Do you want to skip upcast embeddings?", default=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if embeddings_skip_upcast:
|
||||||
|
cfg["embeddings_skip_upcast"] = embeddings_skip_upcast
|
||||||
|
|
||||||
|
# Whether to random init weights
|
||||||
|
random_init_weights = click.confirm(
|
||||||
|
"Do you want to random init weights?", default=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if random_init_weights:
|
||||||
|
cfg["random_init_weights"] = random_init_weights
|
||||||
|
|
||||||
|
# Get model type
|
||||||
|
config = load_model_config(cfg)
|
||||||
|
model_type = config.model_type
|
||||||
|
|
||||||
|
# Ask sequence length
|
||||||
|
sequence_length = click.prompt("What sequence length do you want to use?", type=int)
|
||||||
|
cfg["sequence_length"] = sequence_length
|
||||||
|
|
||||||
|
# Whether to turn on sample packing
|
||||||
|
if cfg["sample_packing"] is None:
|
||||||
|
cfg["sample_packing"] = click.confirm(
|
||||||
|
"Do you want to turn on sample packing? This will speed up training by packing multiple samples into a single batch.",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg["sample_packing"]:
|
||||||
|
cfg["pad_to_sequence_len"] = True
|
||||||
|
|
||||||
|
# Whether to turn off eval sample packing
|
||||||
|
no_eval_sample_packing = click.confirm(
|
||||||
|
"Do you want to turn off eval sample packing? This will slow down evaluation but is recommended if you are using a small validation set.",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if no_eval_sample_packing:
|
||||||
|
cfg["eval_sample_packing"] = False
|
||||||
|
|
||||||
|
# Hardware check
|
||||||
|
try:
|
||||||
|
is_ampere_or_newer = torch.cuda.get_device_capability()[0] >= 8
|
||||||
|
except RuntimeError:
|
||||||
|
is_ampere_or_newer = False
|
||||||
|
except AssertionError: # this is raised if no cuda is available
|
||||||
|
is_ampere_or_newer = False
|
||||||
|
|
||||||
|
# Get num gpus
|
||||||
|
try:
|
||||||
|
num_gpus = torch.cuda.device_count()
|
||||||
|
except RuntimeError:
|
||||||
|
num_gpus = 0
|
||||||
|
|
||||||
|
# Get torch version
|
||||||
|
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||||
|
|
||||||
|
is_torch_2_6_or_newer = version.parse(torch_version) >= version.parse("2.6.0")
|
||||||
|
|
||||||
|
# Whether to turn on attention
|
||||||
|
opt = ["xformers", "sdp"]
|
||||||
|
|
||||||
|
if is_ampere_or_newer:
|
||||||
|
opt.append("flash")
|
||||||
|
|
||||||
|
if is_torch_2_6_or_newer:
|
||||||
|
opt.append("flex")
|
||||||
|
|
||||||
|
if cfg["sample_packing"]:
|
||||||
|
if "flash" in opt:
|
||||||
|
default_opt = "flash"
|
||||||
|
elif "flex" in opt:
|
||||||
|
default_opt = "flex"
|
||||||
|
else:
|
||||||
|
default_opt = opt[0]
|
||||||
|
|
||||||
|
attention = click.prompt(
|
||||||
|
"Which attention backend do you want to use? Sample packing requires an attention backend to be set.",
|
||||||
|
type=click.Choice(opt),
|
||||||
|
default=default_opt,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# non-sample packing supports no attention and S2
|
||||||
|
opt.extend(["none", "s2"])
|
||||||
|
|
||||||
|
attention = click.prompt(
|
||||||
|
"Which attention backend do you want to use?",
|
||||||
|
type=click.Choice(opt),
|
||||||
|
default="none",
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention == "none":
|
||||||
|
attention = None
|
||||||
|
|
||||||
|
# TODO: if xformers, check if FA is installed
|
||||||
|
# TODO: flex doc mentioned requiring seq len to be divisible by 128. Unclear if limitation still exists
|
||||||
|
|
||||||
|
# TODO: requires #2489
|
||||||
|
cfg["attention"] = attention
|
||||||
|
|
||||||
|
# Whether to turn on gradient checkpointing
|
||||||
|
# TODO: need to wait for offload_disk PR to be merged
|
||||||
|
gradient_checkpointing = click.prompt(
|
||||||
|
"Which gradient checkpointing strategy do you want to use?",
|
||||||
|
type=click.Choice(["none", "true", "offload", "offload_disk"]),
|
||||||
|
default="true",
|
||||||
|
)
|
||||||
|
|
||||||
|
if gradient_checkpointing == "none":
|
||||||
|
gradient_checkpointing = False
|
||||||
|
elif gradient_checkpointing == "true":
|
||||||
|
gradient_checkpointing = True
|
||||||
|
|
||||||
|
# Ask whether to set use_reentrant
|
||||||
|
# TODO: get correct defaults based on SFT/RL mode and single/multigpu
|
||||||
|
# use_reentrant = click.confirm(
|
||||||
|
# "Do you want to set use_reentrant?",
|
||||||
|
# default=True,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# if use_reentrant:
|
||||||
|
# cfg["use_reentrant"] = use_reentrant
|
||||||
|
|
||||||
|
# Optimizer
|
||||||
|
cfg["optimizer"] = click.prompt(
|
||||||
|
"Which optimizer do you want to use?",
|
||||||
|
type=click.Choice((OptimizerNames | CustomSupportedOptimizers)),
|
||||||
|
default=OptimizerNames.ADAMW_TORCH_FUSED,
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg["lr_scheduler"] = click.prompt(
|
||||||
|
"Which learning rate scheduler do you want to use?",
|
||||||
|
type=click.Choice(
|
||||||
|
[
|
||||||
|
"cosine",
|
||||||
|
"one_cycle",
|
||||||
|
"rex",
|
||||||
|
"log_sweep",
|
||||||
|
"linear",
|
||||||
|
"cosine_with_restarts",
|
||||||
|
"polynomial",
|
||||||
|
"constant",
|
||||||
|
"constant_with_warmup",
|
||||||
|
"inverse_sqrt",
|
||||||
|
"reduce_lr_on_plateau",
|
||||||
|
"cosine_with_min_lr",
|
||||||
|
"warmup_stable_decay",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
default="cosine",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plugins
|
||||||
|
|
||||||
|
cfg["plugins"] = []
|
||||||
|
|
||||||
|
# Whether to turn on cut cross entropy
|
||||||
|
if is_ampere_or_newer:
|
||||||
|
# Note: This may error if users don't have CCE installed
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
|
||||||
|
CUT_CROSS_ENTROPY_MODEL_MAPPING,
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
|
||||||
|
cut_cross_entropy = click.confirm(
|
||||||
|
"Do you want to turn on cut cross entropy? This will save VRAM if the model has a large vocab size.",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cut_cross_entropy:
|
||||||
|
cfg["plugins"].append(
|
||||||
|
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin"
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg["cut_cross_entropy"] = True
|
||||||
|
|
||||||
|
use_liger_kernel = click.confirm(
|
||||||
|
"Do you want to use the liger kernel? This will speed up training and save VRAM.",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_liger_kernel:
|
||||||
|
cfg["plugins"].append("axolotl.integrations.liger.LigerPlugin")
|
||||||
|
|
||||||
|
cfg["liger_rope"] = click.confirm(
|
||||||
|
"Do you want to enable liger rope?",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg["liger_rms_norm"] = click.confirm(
|
||||||
|
"Do you want to enable liger rms norm?",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg["liger_glu_activation"] = click.confirm(
|
||||||
|
"Do you want to enable liger glu activation?",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg["liger_layer_norm"] = click.confirm(
|
||||||
|
"Do you want to enable liger layer norm?",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg["cut_cross_entropy"] is not True:
|
||||||
|
cfg["liger_fused_linear_cross_entropy"] = click.confirm(
|
||||||
|
"Do you want to enable liger fused linear cross entropy?",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: lora kernels (but they auto enable via validator already)
|
||||||
|
|
||||||
|
# TODO: is there incompat between torch compile and liger?
|
||||||
|
cfg["torch_compile"] = click.confirm(
|
||||||
|
"Do you want to enable torch compile?",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Multi-gpu
|
||||||
|
if num_gpus > 1:
|
||||||
|
# Ask whether to use DDP/Deepspeed/FSDP
|
||||||
|
multi_gpu_mode = click.prompt(
|
||||||
|
"Which multi-gpu mode do you want to use?",
|
||||||
|
type=click.Choice(["ddp", "deepspeed", "fsdp"]),
|
||||||
|
default="ddp",
|
||||||
|
)
|
||||||
|
|
||||||
|
if multi_gpu_mode == "deepspeed":
|
||||||
|
# Ask which deepspeed config to use
|
||||||
|
cfg["deepspeed"] = click.prompt(
|
||||||
|
"Which deepspeed config do you want to use? The higher the number, the more VRAM you will save, but the slower it will run.",
|
||||||
|
type=click.Choice(
|
||||||
|
[
|
||||||
|
"zero1.json",
|
||||||
|
"zero1_torch_compile.json",
|
||||||
|
"zero2.json",
|
||||||
|
"zero3.json",
|
||||||
|
"zero3_bf16.json",
|
||||||
|
"zero3_bf16_cpuoffload_all.json",
|
||||||
|
"zero3_bf16_cpuoffload_params.json",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
default="zero1.json",
|
||||||
|
)
|
||||||
|
elif multi_gpu_mode == "fsdp":
|
||||||
|
fsdp_version = click.prompt(
|
||||||
|
"Which fsdp version do you want to use?",
|
||||||
|
type=click.Choice([1, 2]),
|
||||||
|
default=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Handle FSDP config
|
||||||
|
|
||||||
|
if fsdp_version == 1:
|
||||||
|
cfg["fsdp"] = ["full_shard", "auto_wrap"]
|
||||||
|
|
||||||
|
# Ask which state dict type to use
|
||||||
|
fsdp_state_dict_type = click.prompt(
|
||||||
|
"Which fsdp state dict type do you want to use?",
|
||||||
|
type=click.Choice(["FULL_STATE_DICT", "SHARDED_STATE_DICT"]),
|
||||||
|
default="FULL_STATE_DICT",
|
||||||
|
)
|
||||||
|
|
||||||
|
fsdp_offload_params = click.confirm(
|
||||||
|
"Do you want to offload parameters?",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: can we load the model class and auto pull a default for this?
|
||||||
|
fsdp_transformer_layer_cls_to_wrap = click.prompt(
|
||||||
|
"Which transformer layer class to wrap? It is usually the Decoder layer class.",
|
||||||
|
type=str,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: add other options
|
||||||
|
|
||||||
|
cfg["fsdp_config"] = {
|
||||||
|
"state_dict_type": fsdp_state_dict_type,
|
||||||
|
"offload_params": fsdp_offload_params,
|
||||||
|
"transformer_layer_cls_to_wrap": fsdp_transformer_layer_cls_to_wrap,
|
||||||
|
}
|
||||||
|
|
||||||
|
elif fsdp_version == 2:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
# Training mode (sft or rl)
|
||||||
|
training_mode = click.prompt(
|
||||||
|
"Which training mode do you want to use?",
|
||||||
|
type=click.Choice(["sft", "rl"]),
|
||||||
|
default="sft",
|
||||||
|
)
|
||||||
|
|
||||||
|
if training_mode == "rl":
|
||||||
|
cfg["rl"] = click.prompt(
|
||||||
|
"Which rl mode do you want to use?",
|
||||||
|
type=click.Choice(["dpo", "ipo", "orpo", "kto", "grpo", "simpo"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: handle RL options
|
||||||
|
|
||||||
|
# Whether to use adapter
|
||||||
|
|
||||||
|
# Get batch/grad accu
|
||||||
|
|
||||||
|
# Get learning rate
|
||||||
|
|
||||||
|
# Get weight decay
|
||||||
|
|
||||||
|
# Get max grad norm
|
||||||
|
|
||||||
|
# Get num train epochs
|
||||||
|
|
||||||
|
# Get warmup ratio
|
||||||
|
|
||||||
|
# Get save ratio
|
||||||
|
|
||||||
|
# Get eval ratio
|
||||||
|
|
||||||
|
# Get dataset config
|
||||||
|
|
||||||
|
# Load metric tracker
|
||||||
|
|
||||||
|
# Save config to yaml
|
||||||
|
# TODO: improve output yaml formatting. Need to add comments to help separate sections
|
||||||
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
|
yaml.dump(cfg.to_dict(), f, sort_keys=False)
|
||||||
@@ -156,9 +156,6 @@ class AxolotlTrainer(
|
|||||||
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
||||||
and sample packing cases.
|
and sample packing cases.
|
||||||
|
|
||||||
Args:
|
|
||||||
eval_dataset: Evaluation dataset.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
depends on the passed training args.
|
depends on the passed training args.
|
||||||
@@ -240,6 +237,9 @@ class AxolotlTrainer(
|
|||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
|
|
||||||
# Return unprepared dataloader if using sequence parallelism
|
# Return unprepared dataloader if using sequence parallelism
|
||||||
|
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
||||||
|
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
||||||
|
# slice each batch along the sequence dimension).
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
|
|||||||
@@ -1,25 +1,33 @@
|
|||||||
"""DPO trainer for Axolotl"""
|
"""
|
||||||
|
DPO trainer for axolotl
|
||||||
|
"""
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
import random
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
import wandb
|
||||||
|
from accelerate import PartialState
|
||||||
|
from datasets import Dataset, IterableDataset
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import Sampler
|
from torch.utils.data import DataLoader
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
BaseImageProcessor,
|
||||||
|
FeatureExtractionMixin,
|
||||||
|
PreTrainedTokenizerBase,
|
||||||
|
ProcessorMixin,
|
||||||
Trainer,
|
Trainer,
|
||||||
)
|
)
|
||||||
|
from transformers.trainer_utils import EvalLoopOutput
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt
|
||||||
|
from trl.trainer.utils import log_table_to_comet_experiment
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import (
|
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||||
RngLoaderMixin,
|
|
||||||
SchedulerMixin,
|
|
||||||
SequenceParallelMixin,
|
|
||||||
)
|
|
||||||
from axolotl.core.trainers.utils import (
|
from axolotl.core.trainers.utils import (
|
||||||
sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
@@ -29,10 +37,10 @@ if is_sagemaker_mp_enabled():
|
|||||||
import smdistributed.modelparallel.torch as smp
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
|
||||||
class AxolotlDPOTrainer(
|
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||||
RngLoaderMixin, SchedulerMixin, SequenceParallelMixin, DPOTrainer
|
"""
|
||||||
):
|
Extend the base DPOTrainer for axolotl helpers
|
||||||
"""Extend the base DPOTrainer for axolotl helpers"""
|
"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "dpo"]
|
tag_names = ["axolotl", "dpo"]
|
||||||
|
|
||||||
@@ -87,6 +95,64 @@ class AxolotlDPOTrainer(
|
|||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
|
||||||
|
def _prepare_dataset(
|
||||||
|
self,
|
||||||
|
dataset: Union[Dataset, IterableDataset],
|
||||||
|
processing_class: Union[
|
||||||
|
PreTrainedTokenizerBase,
|
||||||
|
BaseImageProcessor,
|
||||||
|
FeatureExtractionMixin,
|
||||||
|
ProcessorMixin,
|
||||||
|
],
|
||||||
|
args: DPOConfig,
|
||||||
|
dataset_name: str,
|
||||||
|
) -> Union[Dataset, IterableDataset]:
|
||||||
|
# Build the kwargs for the `map` function
|
||||||
|
map_kwargs: Dict[str, Any] = {"writer_batch_size": 10}
|
||||||
|
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
|
||||||
|
map_kwargs["num_proc"] = args.dataset_num_proc
|
||||||
|
|
||||||
|
with PartialState().main_process_first():
|
||||||
|
# Extract prompt if needed
|
||||||
|
if isinstance(
|
||||||
|
dataset, Dataset
|
||||||
|
): # `IterableDataset.map` does not support `desc`
|
||||||
|
map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
|
||||||
|
dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
|
||||||
|
|
||||||
|
# Apply the chat template if needed
|
||||||
|
if isinstance(
|
||||||
|
dataset, Dataset
|
||||||
|
): # `IterableDataset.map` does not support `desc`
|
||||||
|
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
|
||||||
|
dataset = dataset.map(
|
||||||
|
maybe_apply_chat_template,
|
||||||
|
fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
|
||||||
|
**map_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tokenize the dataset
|
||||||
|
if isinstance(
|
||||||
|
dataset, Dataset
|
||||||
|
): # `IterableDataset.map` does not support `desc`
|
||||||
|
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
|
||||||
|
|
||||||
|
dataset = dataset.map(
|
||||||
|
self.tokenize_row if not self.is_vision_model else self.process_row,
|
||||||
|
remove_columns=["chosen", "rejected"],
|
||||||
|
fn_kwargs={
|
||||||
|
"processing_class": processing_class,
|
||||||
|
"max_prompt_length": args.max_prompt_length,
|
||||||
|
"max_completion_length": args.max_completion_length,
|
||||||
|
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
|
||||||
|
"add_special_tokens": False,
|
||||||
|
},
|
||||||
|
**map_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tokenize_row(
|
def tokenize_row(
|
||||||
features,
|
features,
|
||||||
@@ -127,48 +193,68 @@ class AxolotlDPOTrainer(
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Sampler | None:
|
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
|
||||||
|
def evaluation_loop(
|
||||||
|
self,
|
||||||
|
dataloader: DataLoader,
|
||||||
|
description: str,
|
||||||
|
prediction_loss_only: Optional[bool] = None,
|
||||||
|
ignore_keys: Optional[list[str]] = None,
|
||||||
|
metric_key_prefix: str = "eval",
|
||||||
|
) -> EvalLoopOutput:
|
||||||
"""
|
"""
|
||||||
Helper method to get the sampler for training. Handles cases for sequence
|
Overriding built-in evaluation loop to store metrics for each batch.
|
||||||
parallelism, sample packing, and curriculum sampling (sequential).
|
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
||||||
|
|
||||||
Returns:
|
Works both with or without labels.
|
||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
|
||||||
depends on the passed training args.
|
|
||||||
"""
|
"""
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
if dist.get_rank() == 0:
|
# Sample and save to game log if requested (for one batch to save time)
|
||||||
import ipdb
|
if self.generate_during_eval:
|
||||||
|
# Generate random indices within the range of the total number of samples
|
||||||
|
num_samples = len(dataloader.dataset)
|
||||||
|
random_indices = random.sample(
|
||||||
|
range(num_samples), k=self.args.eval_batch_size
|
||||||
|
)
|
||||||
|
|
||||||
ipdb.set_trace()
|
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
||||||
dist.barrier()
|
random_batch_dataset = dataloader.dataset.select(random_indices)
|
||||||
if dist.get_rank() == 1:
|
random_batch = self.data_collator(random_batch_dataset)
|
||||||
import ipdb
|
random_batch = self._prepare_inputs(random_batch)
|
||||||
|
|
||||||
ipdb.set_trace()
|
policy_output_decoded, ref_output_decoded = (
|
||||||
dist.barrier()
|
self.generate_from_model_and_ref(self.model, random_batch)
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
table = pd.DataFrame(
|
||||||
return self._sp_get_train_sampler(self.train_dataset)
|
columns=["Prompt", "Policy", "Ref Model"],
|
||||||
|
data=[
|
||||||
|
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
|
||||||
|
for prompt, pol, ref in zip(
|
||||||
|
random_batch_dataset["prompt"],
|
||||||
|
policy_output_decoded,
|
||||||
|
ref_output_decoded,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
if "wandb" in self.args.report_to and self.accelerator.is_main_process:
|
||||||
|
wandb.log({"game_log": wandb.Table(data=table)})
|
||||||
|
|
||||||
return super()._get_train_sampler()
|
if "comet_ml" in self.args.report_to:
|
||||||
|
log_table_to_comet_experiment(
|
||||||
|
name="game_log.csv",
|
||||||
|
table=table,
|
||||||
|
)
|
||||||
|
|
||||||
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
|
# Base evaluation
|
||||||
"""
|
initial_output = super( # pylint: disable=bad-super-call
|
||||||
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
DPOTrainer, self
|
||||||
and sample packing cases.
|
).evaluation_loop(
|
||||||
|
dataloader,
|
||||||
|
description,
|
||||||
|
prediction_loss_only,
|
||||||
|
ignore_keys,
|
||||||
|
metric_key_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
Args:
|
return initial_output
|
||||||
eval_dataset: Evaluation dataset.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
|
||||||
depends on the passed training args.
|
|
||||||
"""
|
|
||||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
|
||||||
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
return self._sp_get_eval_sampler(eval_dataset)
|
|
||||||
|
|
||||||
return super()._get_eval_sampler(eval_dataset)
|
|
||||||
|
|||||||
@@ -266,6 +266,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
|
|
||||||
# Return unprepared dataloader if using sequence parallelism
|
# Return unprepared dataloader if using sequence parallelism
|
||||||
|
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
||||||
|
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
||||||
|
# slice each batch along the sequence dimension).
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
|
|||||||
@@ -20,15 +20,25 @@ from cut_cross_entropy.transformers.utils import (
|
|||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.models.cohere.modeling_cohere import (
|
from transformers.models.cohere.modeling_cohere import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
COHERE_INPUTS_DOCSTRING,
|
||||||
KwargsForCausalLM,
|
KwargsForCausalLM,
|
||||||
)
|
)
|
||||||
from transformers.processing_utils import Unpack
|
from transformers.processing_utils import Unpack
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
def cce_forward(
|
def cce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor | None = None,
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
|||||||
@@ -17,15 +17,25 @@ from cut_cross_entropy.transformers.utils import (
|
|||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.models.gemma.modeling_gemma import (
|
from transformers.models.gemma.modeling_gemma import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
GEMMA_INPUTS_DOCSTRING,
|
||||||
KwargsForCausalLM,
|
KwargsForCausalLM,
|
||||||
)
|
)
|
||||||
from transformers.processing_utils import Unpack
|
from transformers.processing_utils import Unpack
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
def cce_forward(
|
def cce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor | None = None,
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
|||||||
@@ -20,11 +20,15 @@ from torch import nn
|
|||||||
from transformers.cache_utils import Cache, HybridCache
|
from transformers.cache_utils import Cache, HybridCache
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.models.gemma3.modeling_gemma3 import (
|
from transformers.models.gemma3.modeling_gemma3 import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
GEMMA3_INPUTS_DOCSTRING,
|
||||||
Gemma3CausalLMOutputWithPast,
|
Gemma3CausalLMOutputWithPast,
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
is_torchdynamo_compiling,
|
is_torchdynamo_compiling,
|
||||||
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
@@ -34,6 +38,10 @@ _PATCH_OPTS: PatchOptions | None = None
|
|||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
def cce_forward(
|
def cce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor | None = None,
|
input_ids: torch.LongTensor | None = None,
|
||||||
@@ -162,6 +170,10 @@ def cce_forward(
|
|||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
def cce_forward_multimodal(
|
def cce_forward_multimodal(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor | None = None,
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
|||||||
@@ -19,9 +19,15 @@ from transformers.modeling_outputs import (
|
|||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
)
|
)
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
LLAMA_INPUTS_DOCSTRING,
|
||||||
KwargsForCausalLM,
|
KwargsForCausalLM,
|
||||||
)
|
)
|
||||||
from transformers.processing_utils import Unpack
|
from transformers.processing_utils import Unpack
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
from transformers.utils.generic import can_return_tuple
|
from transformers.utils.generic import can_return_tuple
|
||||||
|
|
||||||
@@ -30,6 +36,10 @@ _PATCH_OPTS: PatchOptions | None = None
|
|||||||
|
|
||||||
@can_return_tuple
|
@can_return_tuple
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
def cce_forward(
|
def cce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
|||||||
@@ -16,12 +16,22 @@ from torch import nn
|
|||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.models.llama4.modeling_llama4 import (
|
from transformers.models.llama4.modeling_llama4 import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
LLAMA4_INPUTS_DOCSTRING,
|
||||||
Llama4CausalLMOutputWithPast,
|
Llama4CausalLMOutputWithPast,
|
||||||
)
|
)
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
def cce_forward(
|
def cce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor | None = None,
|
input_ids: torch.LongTensor | None = None,
|
||||||
@@ -150,6 +160,9 @@ def cce_forward(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
def cce_forward_multimodal(
|
def cce_forward_multimodal(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor | None = None, # type: ignore
|
input_ids: torch.LongTensor | None = None, # type: ignore
|
||||||
|
|||||||
@@ -19,11 +19,15 @@ from transformers.models.mistral3.modeling_mistral3 import (
|
|||||||
Mistral3CausalLMOutputWithPast,
|
Mistral3CausalLMOutputWithPast,
|
||||||
)
|
)
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
MISTRAL_INPUTS_DOCSTRING,
|
||||||
KwargsForCausalLM,
|
KwargsForCausalLM,
|
||||||
)
|
)
|
||||||
from transformers.processing_utils import Unpack
|
from transformers.processing_utils import Unpack
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
is_torchdynamo_compiling,
|
is_torchdynamo_compiling,
|
||||||
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
@@ -31,6 +35,10 @@ _PATCH_OPTS: PatchOptions | None = None
|
|||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
def cce_forward(
|
def cce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor | None = None,
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
|||||||
@@ -13,10 +13,16 @@ from cut_cross_entropy.transformers.utils import (
|
|||||||
apply_lce,
|
apply_lce,
|
||||||
)
|
)
|
||||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
|
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
QWEN2MOE_INPUTS_DOCSTRING,
|
||||||
MoeCausalLMOutputWithPast,
|
MoeCausalLMOutputWithPast,
|
||||||
MoeModelOutputWithPast,
|
MoeModelOutputWithPast,
|
||||||
load_balancing_loss_func,
|
load_balancing_loss_func,
|
||||||
)
|
)
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
from transformers.utils.generic import can_return_tuple
|
from transformers.utils.generic import can_return_tuple
|
||||||
|
|
||||||
@@ -25,6 +31,10 @@ _PATCH_OPTS: PatchOptions | None = None
|
|||||||
|
|
||||||
@can_return_tuple
|
@can_return_tuple
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
|||||||
@@ -14,12 +14,22 @@ from cut_cross_entropy.transformers.utils import (
|
|||||||
)
|
)
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
QWEN2_VL_INPUTS_DOCSTRING,
|
||||||
Qwen2VLCausalLMOutputWithPast,
|
Qwen2VLCausalLMOutputWithPast,
|
||||||
)
|
)
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
def cce_forward_multimodal(
|
def cce_forward_multimodal(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
|||||||
@@ -12,13 +12,20 @@ from cut_cross_entropy.transformers.utils import (
|
|||||||
TransformersModelT,
|
TransformersModelT,
|
||||||
apply_lce,
|
apply_lce,
|
||||||
)
|
)
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
|
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
QWEN3_MOE_INPUTS_DOCSTRING,
|
||||||
KwargsForCausalLM,
|
KwargsForCausalLM,
|
||||||
MoeCausalLMOutputWithPast,
|
MoeCausalLMOutputWithPast,
|
||||||
MoeModelOutputWithPast,
|
MoeModelOutputWithPast,
|
||||||
load_balancing_loss_func,
|
load_balancing_loss_func,
|
||||||
)
|
)
|
||||||
from transformers.processing_utils import Unpack
|
from transformers.processing_utils import Unpack
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
from transformers.utils.generic import can_return_tuple
|
from transformers.utils.generic import can_return_tuple
|
||||||
|
|
||||||
@@ -27,6 +34,10 @@ _PATCH_OPTS: PatchOptions | None = None
|
|||||||
|
|
||||||
@can_return_tuple
|
@can_return_tuple
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ from torch.nn import CrossEntropyLoss
|
|||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
||||||
|
|
||||||
|
# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
|
||||||
|
# @replace_return_docstrings(
|
||||||
|
# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
# )
|
||||||
def lce_forward(
|
def lce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
|
|||||||
@@ -13,11 +13,21 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
|
|||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
||||||
from transformers.models.jamba.modeling_jamba import (
|
from transformers.models.jamba.modeling_jamba import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
JAMBA_INPUTS_DOCSTRING,
|
||||||
HybridMambaAttentionDynamicCache,
|
HybridMambaAttentionDynamicCache,
|
||||||
load_balancing_loss_func,
|
load_balancing_loss_func,
|
||||||
)
|
)
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
def lce_forward(
|
def lce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
|
|||||||
@@ -7,16 +7,24 @@ from typing import Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
from transformers.models.gemma3.modeling_gemma3 import (
|
from transformers.models.gemma3.modeling_gemma3 import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
GEMMA3_INPUTS_DOCSTRING,
|
||||||
Gemma3CausalLMOutputWithPast,
|
Gemma3CausalLMOutputWithPast,
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
is_torchdynamo_compiling,
|
is_torchdynamo_compiling,
|
||||||
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
def new_forward(
|
def new_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Module for Axolotl trainer sequence parallelism manager and utilities"""
|
"""Module for Axolotl trainer sequence parallelism manager and utilities"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -33,7 +32,7 @@ def apply_sequence_parallelism(
|
|||||||
to only keep the last N tokens in the sequence during generation.
|
to only keep the last N tokens in the sequence during generation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch: Dictionary of model arguments (e.g., input_ids, attention_mask, etc.).
|
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
|
||||||
local_rank: Local rank in the sequence parallel group.
|
local_rank: Local rank in the sequence parallel group.
|
||||||
local_world_size: World size of the sequence parallel group.
|
local_world_size: World size of the sequence parallel group.
|
||||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||||
@@ -207,26 +206,12 @@ class SequenceParallelContextManager:
|
|||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
# Forward pre-hook to apply sequence parallelism
|
# Forward pre-hook to apply sequence parallelism
|
||||||
def sequence_parallel_pre_hook(_, args, kwargs):
|
def sequence_parallel_pre_hook(_, args, kwargs):
|
||||||
# Convert all args to kwargs using the model's forward function signature
|
# Apply sequence parallelism to kwargs and get original sequence length and padding info
|
||||||
updated_kwargs = kwargs.copy()
|
kwargs, self.original_seq_len, self.pad_len = (
|
||||||
|
self.apply_sequence_parallelism(batch=kwargs)
|
||||||
# Get parameter names from the model's forward function
|
|
||||||
forward_params = list(
|
|
||||||
inspect.signature(self.models[0].forward).parameters.keys()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Map args to their parameter names
|
return args, kwargs
|
||||||
for i, arg in enumerate(args):
|
|
||||||
if i < len(forward_params):
|
|
||||||
param_name = forward_params[i]
|
|
||||||
updated_kwargs[param_name] = arg
|
|
||||||
|
|
||||||
# Apply sequence parallelism to empty args and updated kwargs
|
|
||||||
updated_kwargs, self.original_seq_len, self.pad_len = (
|
|
||||||
self.apply_sequence_parallelism(updated_kwargs)
|
|
||||||
)
|
|
||||||
|
|
||||||
return (), updated_kwargs
|
|
||||||
|
|
||||||
# Forward post-hook to gather outputs
|
# Forward post-hook to gather outputs
|
||||||
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
|
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
|
||||||
|
|||||||
@@ -72,7 +72,6 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
|||||||
data_set = data_set.map(
|
data_set = data_set.map(
|
||||||
ds_transform_fn,
|
ds_transform_fn,
|
||||||
desc="Mapping RL Dataset",
|
desc="Mapping RL Dataset",
|
||||||
num_proc=cfg.dataset_processes,
|
|
||||||
**map_kwargs,
|
**map_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -484,7 +484,7 @@ def get_dataset_wrapper(
|
|||||||
}
|
}
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
|
f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -5,11 +5,8 @@ from functools import partial
|
|||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from axolotl.utils.gradient_checkpointing.offload_cpu import (
|
from axolotl.utils.gradient_checkpointing.unsloth import (
|
||||||
CPU_Offloaded_Gradient_Checkpointer,
|
Unsloth_Offloaded_Gradient_Checkpointer,
|
||||||
)
|
|
||||||
from axolotl.utils.gradient_checkpointing.offload_disk import (
|
|
||||||
Disco,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
transformers_version = version.parse(importlib.metadata.version("transformers"))
|
transformers_version = version.parse(importlib.metadata.version("transformers"))
|
||||||
@@ -29,31 +26,12 @@ def hf_grad_checkpoint_offload_wrapper(
|
|||||||
decoder_layer, *args, use_reentrant=None
|
decoder_layer, *args, use_reentrant=None
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
if uses_gc_layers(decoder_layer):
|
if uses_gc_layers(decoder_layer):
|
||||||
return CPU_Offloaded_Gradient_Checkpointer.apply(
|
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
||||||
decoder_layer,
|
decoder_layer,
|
||||||
*args,
|
*args,
|
||||||
)
|
)
|
||||||
|
|
||||||
return CPU_Offloaded_Gradient_Checkpointer.apply(
|
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
||||||
(
|
|
||||||
decoder_layer.func.__self__
|
|
||||||
if isinstance(decoder_layer, partial)
|
|
||||||
else decoder_layer.__self__
|
|
||||||
),
|
|
||||||
*args,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def hf_grad_checkpoint_disk_offload_wrapper(
|
|
||||||
decoder_layer, *args, use_reentrant=None
|
|
||||||
): # pylint: disable=unused-argument
|
|
||||||
if uses_gc_layers(decoder_layer):
|
|
||||||
return Disco.apply(
|
|
||||||
decoder_layer,
|
|
||||||
*args,
|
|
||||||
)
|
|
||||||
|
|
||||||
return Disco.apply(
|
|
||||||
(
|
(
|
||||||
decoder_layer.func.__self__
|
decoder_layer.func.__self__
|
||||||
if isinstance(decoder_layer, partial)
|
if isinstance(decoder_layer, partial)
|
||||||
|
|||||||
@@ -1,531 +0,0 @@
|
|||||||
"""
|
|
||||||
DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Copyright 2025 Axolotl AI. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import atexit
|
|
||||||
import concurrent.futures
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import queue
|
|
||||||
import shutil
|
|
||||||
import tempfile
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from collections import deque
|
|
||||||
from concurrent.futures import Future
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
|
||||||
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
|
||||||
|
|
||||||
# Setup logger
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DiskOffloadManager:
|
|
||||||
"""
|
|
||||||
Manages offloaded tensors and handles prefetching in a separate thread.
|
|
||||||
Includes synchronization to prevent race conditions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
prefetch_size: int = 3,
|
|
||||||
prefetch_to_gpu: bool = True,
|
|
||||||
save_workers: int = 4,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
prefetch_size: Maximum number of tensors to prefetch in the background.
|
|
||||||
prefetch_to_gpu: Whether to prefetch tensors directly to GPU memory.
|
|
||||||
save_workers: Maximum number of concurrent save operations.
|
|
||||||
"""
|
|
||||||
self.temp_dir = tempfile.mkdtemp(prefix="disco_")
|
|
||||||
|
|
||||||
# Track tensor paths and their status
|
|
||||||
self.tensor_paths: deque = deque() # Ordered history of tensor paths (LIFO)
|
|
||||||
self.file_locks: Dict[str, threading.Lock] = (
|
|
||||||
{}
|
|
||||||
) # Maps file_path -> threading.Lock()
|
|
||||||
# Maps file_path -> status ("saving", "ready", "prefetching", "loaded", "deleted")
|
|
||||||
self.file_status: Dict[str, str] = {}
|
|
||||||
|
|
||||||
self.max_prefetch = prefetch_size
|
|
||||||
self.prefetch_to_gpu = prefetch_to_gpu
|
|
||||||
|
|
||||||
# Thread synchronization
|
|
||||||
self.manager_lock = threading.RLock() # Used for thread-safe operations
|
|
||||||
|
|
||||||
# Prefetch queue and cache
|
|
||||||
self.prefetch_queue: queue.Queue = queue.Queue()
|
|
||||||
self.prefetch_cache: Dict[str, torch.Tensor] = {} # Maps file_path -> tensor
|
|
||||||
|
|
||||||
# Save queue and thread pool
|
|
||||||
self.save_queue: queue.Queue = queue.Queue()
|
|
||||||
self.save_pool = concurrent.futures.ThreadPoolExecutor(max_workers=save_workers)
|
|
||||||
self.save_futures: Dict[str, Future] = {}
|
|
||||||
self.save_semaphore = threading.Semaphore(
|
|
||||||
save_workers * 2
|
|
||||||
) # Limit concurrent save operations
|
|
||||||
|
|
||||||
# Start prefetch worker thread
|
|
||||||
self.stop_event = threading.Event()
|
|
||||||
# start multiple threads for prefetching
|
|
||||||
self.prefetch_worker_count = 2
|
|
||||||
self.prefetch_workers = []
|
|
||||||
for _ in range(self.prefetch_worker_count):
|
|
||||||
worker = threading.Thread(target=self._prefetch_worker, daemon=True)
|
|
||||||
worker.start()
|
|
||||||
self.prefetch_workers.append(worker)
|
|
||||||
|
|
||||||
# Start save worker thread
|
|
||||||
self.save_worker = threading.Thread(target=self._save_worker, daemon=True)
|
|
||||||
self.save_worker.start()
|
|
||||||
self.idx = 0
|
|
||||||
|
|
||||||
atexit.register(self.cleanup)
|
|
||||||
|
|
||||||
def _save_worker(self):
|
|
||||||
"""Background thread that processes the save queue"""
|
|
||||||
while not self.stop_event.is_set():
|
|
||||||
try:
|
|
||||||
save_item = self.save_queue.get(timeout=0.5)
|
|
||||||
if save_item is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
tensor, file_path = save_item
|
|
||||||
|
|
||||||
# Submit the save task to the thread pool
|
|
||||||
future = self.save_pool.submit(
|
|
||||||
self._save_tensor_to_disk, tensor, file_path
|
|
||||||
)
|
|
||||||
with self.manager_lock:
|
|
||||||
self.save_futures[file_path] = future
|
|
||||||
|
|
||||||
self.save_queue.task_done()
|
|
||||||
|
|
||||||
except queue.Empty:
|
|
||||||
time.sleep(0.01) # Small sleep to prevent CPU spinning
|
|
||||||
continue
|
|
||||||
|
|
||||||
def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str):
|
|
||||||
"""Actually save the tensor to disk"""
|
|
||||||
try:
|
|
||||||
# Save tensor to disk
|
|
||||||
cpu_tensor = tensor.detach().cpu()
|
|
||||||
torch.save(cpu_tensor, file_path)
|
|
||||||
del cpu_tensor
|
|
||||||
|
|
||||||
with self.manager_lock:
|
|
||||||
# Mark file as ready
|
|
||||||
self.file_status[file_path] = "ready"
|
|
||||||
|
|
||||||
# Release semaphore
|
|
||||||
self.save_semaphore.release()
|
|
||||||
|
|
||||||
return True
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
logger.error(f"Error saving tensor to {file_path}: {e}")
|
|
||||||
with self.manager_lock:
|
|
||||||
self.file_status[file_path] = "error"
|
|
||||||
|
|
||||||
# Release semaphore
|
|
||||||
self.save_semaphore.release()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _prefetch_worker(self):
|
|
||||||
"""Background thread that loads tensors from disk ahead of time"""
|
|
||||||
while not self.stop_event.is_set():
|
|
||||||
try:
|
|
||||||
file_path = self.prefetch_queue.get(timeout=0.5)
|
|
||||||
if file_path is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if file is available and not already in cache
|
|
||||||
with self.manager_lock:
|
|
||||||
if (
|
|
||||||
file_path not in self.file_status
|
|
||||||
or self.file_status[file_path] == "deleted"
|
|
||||||
):
|
|
||||||
self.prefetch_queue.task_done()
|
|
||||||
if file_path in self.prefetch_cache:
|
|
||||||
self.prefetch_queue.task_done()
|
|
||||||
continue
|
|
||||||
|
|
||||||
# If file is still being saved, wait for it
|
|
||||||
if (
|
|
||||||
self.file_status[file_path] == "saving"
|
|
||||||
and file_path in self.save_futures
|
|
||||||
):
|
|
||||||
# Re-queue this prefetch request with a little delay
|
|
||||||
self.prefetch_queue.task_done()
|
|
||||||
time.sleep(0.1)
|
|
||||||
self.prefetch_queue.put(file_path)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Mark file as being prefetched
|
|
||||||
self.file_status[file_path] = "prefetching"
|
|
||||||
|
|
||||||
# Load tensor from disk and store in cache
|
|
||||||
try:
|
|
||||||
if os.path.exists(file_path):
|
|
||||||
if self.prefetch_to_gpu:
|
|
||||||
tensor = torch.load(
|
|
||||||
file_path,
|
|
||||||
map_location=torch.device("cuda"),
|
|
||||||
weights_only=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
tensor = torch.load(file_path, weights_only=True)
|
|
||||||
|
|
||||||
with self.manager_lock:
|
|
||||||
self.prefetch_cache[file_path] = tensor
|
|
||||||
self.file_status[file_path] = "ready"
|
|
||||||
else:
|
|
||||||
with self.manager_lock:
|
|
||||||
if self.file_status.get(file_path) != "deleted":
|
|
||||||
logger.warning(
|
|
||||||
f"Prefetch error: File not found {file_path}"
|
|
||||||
)
|
|
||||||
self.file_status[file_path] = "missing"
|
|
||||||
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
with self.manager_lock:
|
|
||||||
if self.file_status.get(file_path) != "deleted":
|
|
||||||
logger.warning(f"Prefetch error for {file_path}: {e}")
|
|
||||||
self.file_status[file_path] = "error"
|
|
||||||
|
|
||||||
self.prefetch_queue.task_done()
|
|
||||||
|
|
||||||
except queue.Empty:
|
|
||||||
time.sleep(0.01) # Small sleep to prevent CPU spinning
|
|
||||||
continue
|
|
||||||
|
|
||||||
def save_tensor(self, tensor: torch.Tensor):
|
|
||||||
"""Save tensor to disk asynchronously and return file path with thread-safe operations"""
|
|
||||||
# Generate unique file path
|
|
||||||
self.idx += 1
|
|
||||||
file_path: str = os.path.join(
|
|
||||||
self.temp_dir, f"{self.idx:06d}-{uuid.uuid4()}.pt"
|
|
||||||
)
|
|
||||||
|
|
||||||
with self.manager_lock:
|
|
||||||
# Mark file as being saved
|
|
||||||
self.file_locks[file_path] = threading.Lock()
|
|
||||||
self.file_status[file_path] = "saving"
|
|
||||||
# Add to history
|
|
||||||
self.tensor_paths.append(file_path)
|
|
||||||
|
|
||||||
# Acquire semaphore to limit concurrent save operations
|
|
||||||
self.save_semaphore.acquire() # pylint: disable=consider-using-with
|
|
||||||
# Queue tensor for saving in background
|
|
||||||
self.save_queue.put((tensor.detach(), file_path))
|
|
||||||
|
|
||||||
return file_path
|
|
||||||
|
|
||||||
def wait_for_save(self, file_path, timeout=None) -> None:
|
|
||||||
"""Wait for a tensor to be saved to disk"""
|
|
||||||
start_time = time.time()
|
|
||||||
while timeout is None or time.time() - start_time < timeout:
|
|
||||||
with self.manager_lock:
|
|
||||||
if self.file_status.get(file_path) == "ready":
|
|
||||||
return
|
|
||||||
if self.file_status.get(file_path) in ["error", "missing", "deleted"]:
|
|
||||||
return
|
|
||||||
|
|
||||||
if file_path in self.save_futures:
|
|
||||||
future = self.save_futures[file_path]
|
|
||||||
if future.done():
|
|
||||||
return
|
|
||||||
|
|
||||||
# Small sleep to prevent CPU spinning
|
|
||||||
time.sleep(0.01)
|
|
||||||
|
|
||||||
# Timeout
|
|
||||||
logger.warning(f"Timeout waiting for tensor to be saved: {file_path}")
|
|
||||||
return
|
|
||||||
|
|
||||||
def load_tensor(self, file_path, target_device="cuda"):
|
|
||||||
"""Load tensor from disk or prefetch cache with proper synchronization"""
|
|
||||||
# Wait for tensor to be saved if it's still in progress
|
|
||||||
self.wait_for_save(file_path)
|
|
||||||
|
|
||||||
tensor = None
|
|
||||||
|
|
||||||
# Try to get from cache first
|
|
||||||
with self.manager_lock:
|
|
||||||
# Check if tensor is already in cache
|
|
||||||
if file_path in self.prefetch_cache:
|
|
||||||
tensor = self.prefetch_cache[file_path]
|
|
||||||
del self.prefetch_cache[file_path]
|
|
||||||
self.file_status[file_path] = "loaded"
|
|
||||||
|
|
||||||
if tensor is not None:
|
|
||||||
# Ensure tensor is on correct device
|
|
||||||
if target_device != "cpu" and tensor.device.type == "cpu":
|
|
||||||
tensor = tensor.to(target_device, non_blocking=True)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
# If not in cache, load directly from disk
|
|
||||||
try:
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
logger.error(f"File not found for loading: {file_path}")
|
|
||||||
raise FileNotFoundError(f"File not found: {file_path}")
|
|
||||||
|
|
||||||
tensor = torch.load(file_path, weights_only=True)
|
|
||||||
|
|
||||||
with self.manager_lock:
|
|
||||||
self.file_status[file_path] = "loaded"
|
|
||||||
|
|
||||||
if target_device != "cpu":
|
|
||||||
tensor = tensor.to(target_device, non_blocking=True)
|
|
||||||
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading tensor from {file_path}: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _safe_delete_file(self, file_path):
|
|
||||||
"""Safely delete a file with proper synchronization"""
|
|
||||||
with self.manager_lock:
|
|
||||||
# Make sure any save operation is completed
|
|
||||||
if file_path in self.save_futures:
|
|
||||||
future = self.save_futures[file_path]
|
|
||||||
try:
|
|
||||||
if not future.done():
|
|
||||||
future.cancel()
|
|
||||||
del self.save_futures[file_path]
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Error canceling save operation for {file_path}: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only delete if file exists and is not being prefetched
|
|
||||||
status = self.file_status.get(file_path)
|
|
||||||
if status in ["ready", "loaded", "error", "missing"]:
|
|
||||||
try:
|
|
||||||
if os.path.exists(file_path):
|
|
||||||
os.remove(file_path)
|
|
||||||
self.file_status[file_path] = "deleted"
|
|
||||||
return True
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
logger.warning(f"Error deleting file {file_path}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def trigger_prefetch(self, n=None):
|
|
||||||
"""Trigger prefetching of the next N tensors with proper synchronization"""
|
|
||||||
if n is None:
|
|
||||||
n = self.max_prefetch
|
|
||||||
|
|
||||||
prefetch_paths = []
|
|
||||||
with self.manager_lock:
|
|
||||||
# Find files that are ready to be prefetched (not already in cache or being prefetched)
|
|
||||||
for path in reversed(self.tensor_paths):
|
|
||||||
if (
|
|
||||||
path not in self.prefetch_cache
|
|
||||||
and self.file_status.get(path) == "ready"
|
|
||||||
):
|
|
||||||
prefetch_paths.append(path)
|
|
||||||
if len(prefetch_paths) >= n:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Queue files for prefetching
|
|
||||||
for path in prefetch_paths:
|
|
||||||
self.prefetch_queue.put(path)
|
|
||||||
|
|
||||||
def cleanup_tensor(self, file_path: str):
|
|
||||||
"""Clean up a specific tensor file after it's been used"""
|
|
||||||
with self.manager_lock:
|
|
||||||
if file_path in self.tensor_paths:
|
|
||||||
self.tensor_paths.remove(file_path)
|
|
||||||
|
|
||||||
# Remove from prefetch cache if present
|
|
||||||
if file_path in self.prefetch_cache:
|
|
||||||
del self.prefetch_cache[file_path]
|
|
||||||
|
|
||||||
# Remove from save futures if present
|
|
||||||
if file_path in self.save_futures:
|
|
||||||
future = self.save_futures[file_path]
|
|
||||||
if not future.done():
|
|
||||||
future.cancel()
|
|
||||||
del self.save_futures[file_path]
|
|
||||||
|
|
||||||
# Try to delete the file
|
|
||||||
self._safe_delete_file(file_path)
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""Clean up all temp files and stop prefetch thread with proper synchronization"""
|
|
||||||
self.stop_event.set()
|
|
||||||
|
|
||||||
# Cancel all pending save operations
|
|
||||||
with self.manager_lock:
|
|
||||||
for _, future in self.save_futures.items():
|
|
||||||
if not future.done():
|
|
||||||
future.cancel()
|
|
||||||
self.save_futures.clear()
|
|
||||||
|
|
||||||
# Drain the save queue
|
|
||||||
while not self.save_queue.empty():
|
|
||||||
try:
|
|
||||||
self.save_queue.get_nowait()
|
|
||||||
self.save_queue.task_done()
|
|
||||||
except queue.Empty:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Shutdown the save pool
|
|
||||||
self.save_pool.shutdown(wait=False)
|
|
||||||
|
|
||||||
# Join the save worker thread
|
|
||||||
if self.save_worker.is_alive():
|
|
||||||
self.save_worker.join(timeout=2.0)
|
|
||||||
|
|
||||||
# Join the prefetch worker threads
|
|
||||||
for thread in self.prefetch_workers:
|
|
||||||
if thread.is_alive():
|
|
||||||
thread.join(timeout=2.0)
|
|
||||||
|
|
||||||
# Clear cache and remove all temporary files
|
|
||||||
with self.manager_lock:
|
|
||||||
self.prefetch_cache.clear()
|
|
||||||
paths_to_delete = list(self.tensor_paths)
|
|
||||||
self.tensor_paths.clear()
|
|
||||||
|
|
||||||
# Delete all temporary files
|
|
||||||
for path in paths_to_delete:
|
|
||||||
self._safe_delete_file(path)
|
|
||||||
|
|
||||||
# Remove temp directory
|
|
||||||
try:
|
|
||||||
if os.path.exists(self.temp_dir):
|
|
||||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
logger.warning(f"Error removing temporary directory {self.temp_dir}: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class Disco(torch.autograd.Function):
|
|
||||||
"""
|
|
||||||
Disco: DIsk-based Storage and Checkpointing with Optimized prefetching
|
|
||||||
Advanced disk-based gradient checkpointer with prefetching.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Shared manager instance across all checkpointing operations
|
|
||||||
_manager = None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4):
|
|
||||||
"""Get or create the offload manager"""
|
|
||||||
if Disco._manager is None:
|
|
||||||
Disco._manager = DiskOffloadManager(
|
|
||||||
prefetch_size=prefetch_size,
|
|
||||||
prefetch_to_gpu=prefetch_to_gpu,
|
|
||||||
save_workers=save_workers,
|
|
||||||
)
|
|
||||||
return Disco._manager
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@torch_cuda_amp_custom_fwd
|
|
||||||
def forward(
|
|
||||||
ctx,
|
|
||||||
forward_function,
|
|
||||||
hidden_states,
|
|
||||||
*args,
|
|
||||||
prefetch_size=1,
|
|
||||||
prefetch_to_gpu=True,
|
|
||||||
save_workers=4,
|
|
||||||
):
|
|
||||||
"""Forward pass that offloads activations to disk asynchronously"""
|
|
||||||
# Get or create the manager
|
|
||||||
manager = Disco.get_instance(
|
|
||||||
prefetch_size=prefetch_size,
|
|
||||||
prefetch_to_gpu=prefetch_to_gpu,
|
|
||||||
save_workers=save_workers,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save tensor to disk asynchronously
|
|
||||||
file_path = manager.save_tensor(hidden_states)
|
|
||||||
|
|
||||||
# Run forward pass immediately without waiting for save to complete
|
|
||||||
with torch.no_grad():
|
|
||||||
output = forward_function(hidden_states, *args)
|
|
||||||
|
|
||||||
# Store what we need for backward
|
|
||||||
ctx.save_for_backward(torch.tensor([0])) # Dummy tensor
|
|
||||||
ctx.file_path = file_path
|
|
||||||
ctx.forward_function = forward_function
|
|
||||||
ctx.args = args
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@torch_cuda_amp_custom_bwd
|
|
||||||
def backward(ctx, *grad_outputs):
|
|
||||||
"""Backward pass that loads activations from disk with prefetching"""
|
|
||||||
# Get the manager
|
|
||||||
manager = Disco._manager
|
|
||||||
|
|
||||||
# Trigger prefetching for future tensors
|
|
||||||
# This happens at the start of backward, so should have time to complete
|
|
||||||
manager.trigger_prefetch()
|
|
||||||
|
|
||||||
# Load hidden states from disk or prefetch cache
|
|
||||||
file_path = ctx.file_path
|
|
||||||
try:
|
|
||||||
# Ensure the file is saved before we try to load it
|
|
||||||
manager.wait_for_save(file_path)
|
|
||||||
|
|
||||||
hidden_states = manager.load_tensor(file_path)
|
|
||||||
hidden_states.requires_grad = True
|
|
||||||
|
|
||||||
# Compute gradients
|
|
||||||
with torch.enable_grad():
|
|
||||||
output = ctx.forward_function(hidden_states, *ctx.args)
|
|
||||||
|
|
||||||
# Handle tuple outputs properly
|
|
||||||
if isinstance(output, tuple):
|
|
||||||
if len(grad_outputs) == len(output):
|
|
||||||
torch.autograd.backward(output, grad_outputs)
|
|
||||||
else:
|
|
||||||
torch.autograd.backward(output, grad_outputs[0])
|
|
||||||
else:
|
|
||||||
torch.autograd.backward(output, grad_outputs[0])
|
|
||||||
|
|
||||||
# Clean up the file after we're done with it
|
|
||||||
manager.cleanup_tensor(file_path)
|
|
||||||
|
|
||||||
return (
|
|
||||||
(
|
|
||||||
None, # forward_function
|
|
||||||
hidden_states.grad, # hidden_states grad
|
|
||||||
)
|
|
||||||
+ (None,) * len(ctx.args) # for each arg
|
|
||||||
+ (
|
|
||||||
None, # prefetch_size
|
|
||||||
None, # prefetch_to_gpu
|
|
||||||
None, # save_workers
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in backward pass: {e}")
|
|
||||||
# Clean up the file even on error
|
|
||||||
manager.cleanup_tensor(file_path)
|
|
||||||
raise
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
"""CPU offloaded checkpointing"""
|
"""Unsloth checkpointing"""
|
||||||
|
|
||||||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||||
#
|
#
|
||||||
@@ -26,7 +26,7 @@ else:
|
|||||||
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||||
|
|
||||||
|
|
||||||
class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||||
torch.autograd.Function
|
torch.autograd.Function
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -70,10 +70,7 @@ from axolotl.utils.distributed import (
|
|||||||
is_local_main_process,
|
is_local_main_process,
|
||||||
is_main_process,
|
is_main_process,
|
||||||
)
|
)
|
||||||
from axolotl.utils.gradient_checkpointing import (
|
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
||||||
hf_grad_checkpoint_disk_offload_wrapper,
|
|
||||||
hf_grad_checkpoint_offload_wrapper,
|
|
||||||
)
|
|
||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
from axolotl.utils.schemas.enums import RLType
|
from axolotl.utils.schemas.enums import RLType
|
||||||
@@ -623,10 +620,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
||||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
||||||
if self.cfg.gradient_checkpointing == "offload_disk":
|
|
||||||
transformers.modeling_utils.checkpoint = (
|
|
||||||
hf_grad_checkpoint_disk_offload_wrapper
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
self.patch_attention()
|
self.patch_attention()
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ class AxolotlInputConfig(
|
|||||||
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
||||||
shrink_embeddings: bool | None = None
|
shrink_embeddings: bool | None = None
|
||||||
embeddings_skip_upcast: bool | None = None
|
embeddings_skip_upcast: bool | None = None
|
||||||
|
random_init_weights: bool | None = None
|
||||||
|
|
||||||
rl: RLType | None = None
|
rl: RLType | None = None
|
||||||
trl: TRLConfig | None = Field(
|
trl: TRLConfig | None = Field(
|
||||||
@@ -178,7 +179,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
# torch_dtype: torch.dtype | None
|
# torch_dtype: torch.dtype | None
|
||||||
|
|
||||||
gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field(
|
gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
|
||||||
default=False
|
default=False
|
||||||
)
|
)
|
||||||
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
||||||
|
|||||||
@@ -26,15 +26,10 @@ class TestActivationCheckpointing:
|
|||||||
E2E tests for activation checkpointing
|
E2E tests for activation checkpointing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"gradient_checkpointing",
|
|
||||||
["offload", "offload_disk"],
|
|
||||||
)
|
|
||||||
def test_activation_checkpointing_offload(
|
def test_activation_checkpointing_offload(
|
||||||
self,
|
self,
|
||||||
temp_dir,
|
temp_dir,
|
||||||
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
|
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
|
||||||
gradient_checkpointing,
|
|
||||||
):
|
):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -69,7 +64,7 @@ class TestActivationCheckpointing:
|
|||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"bf16": True,
|
"bf16": True,
|
||||||
"save_safetensors": True,
|
"save_safetensors": True,
|
||||||
"gradient_checkpointing": gradient_checkpointing,
|
"gradient_checkpointing": "offload",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user