Merge branch 'main' into fix/orpo_feature_parity
This commit is contained in:
10
.github/workflows/main.yml
vendored
10
.github/workflows/main.yml
vendored
@@ -31,6 +31,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.0
|
pytorch: 2.7.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 128
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.7.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -94,6 +99,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.0
|
pytorch: 2.7.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 128
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.7.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
9
.github/workflows/tests.yml
vendored
9
.github/workflows/tests.yml
vendored
@@ -295,6 +295,7 @@ jobs:
|
|||||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||||
|
|
||||||
docker-e2e-tests-1st:
|
docker-e2e-tests-1st:
|
||||||
|
# Run this job first as a gate for running the remainder of the test matrix
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
@@ -341,6 +342,8 @@ jobs:
|
|||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 90
|
timeout-minutes: 90
|
||||||
|
# Only run the remainder of the matrix if the first e2e check passed;
|
||||||
|
# this is to save on wasted compute costs for known failures that get caught in the first run
|
||||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
@@ -365,6 +368,12 @@ jobs:
|
|||||||
pytorch: 2.7.0
|
pytorch: 2.7.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 128
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.7.0
|
||||||
|
num_gpus: 1
|
||||||
|
axolotl_extras:
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -62,7 +62,6 @@ quartodoc:
|
|||||||
- core.trainers.mixins.optimizer
|
- core.trainers.mixins.optimizer
|
||||||
- core.trainers.mixins.rng_state_loader
|
- core.trainers.mixins.rng_state_loader
|
||||||
- core.trainers.mixins.scheduler
|
- core.trainers.mixins.scheduler
|
||||||
- core.trainers.mixins.sequence_parallel
|
|
||||||
- title: Context Managers
|
- title: Context Managers
|
||||||
desc: Context managers for altering trainer behaviors
|
desc: Context managers for altering trainer behaviors
|
||||||
contents:
|
contents:
|
||||||
@@ -141,7 +140,8 @@ quartodoc:
|
|||||||
- utils.optimizers.adopt
|
- utils.optimizers.adopt
|
||||||
- utils.data.pretraining
|
- utils.data.pretraining
|
||||||
- utils.data.sft
|
- utils.data.sft
|
||||||
- utils.gradient_checkpointing.unsloth
|
- utils.gradient_checkpointing.offload_cpu
|
||||||
|
- utils.gradient_checkpointing.offload_disk
|
||||||
- title: Schemas
|
- title: Schemas
|
||||||
desc: Pydantic data models for Axolotl config
|
desc: Pydantic data models for Axolotl config
|
||||||
contents:
|
contents:
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
|
|||||||
image=cicd_image,
|
image=cicd_image,
|
||||||
gpu=GPU_CONFIG,
|
gpu=GPU_CONFIG,
|
||||||
timeout=90 * 60,
|
timeout=90 * 60,
|
||||||
cpu=8.0,
|
cpu=16.0,
|
||||||
memory=131072 * N_GPUS,
|
memory=131072 * N_GPUS,
|
||||||
volumes=VOLUME_CONFIG,
|
volumes=VOLUME_CONFIG,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -540,7 +540,7 @@ train_on_inputs: false
|
|||||||
# Note that training loss may have an oscillating pattern with this enabled.
|
# Note that training loss may have an oscillating pattern with this enabled.
|
||||||
group_by_length: false
|
group_by_length: false
|
||||||
|
|
||||||
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
|
# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
|
||||||
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||||
gradient_checkpointing: false
|
gradient_checkpointing: false
|
||||||
# additional kwargs to pass to the trainer for gradient checkpointing
|
# additional kwargs to pass to the trainer for gradient checkpointing
|
||||||
@@ -634,7 +634,9 @@ weight_decay:
|
|||||||
# adamw hyperparams
|
# adamw hyperparams
|
||||||
adam_beta1:
|
adam_beta1:
|
||||||
adam_beta2:
|
adam_beta2:
|
||||||
|
adam_beta3: # only used for CAME Optimizer
|
||||||
adam_epsilon:
|
adam_epsilon:
|
||||||
|
adam_epsilon2: # only used for CAME Optimizer
|
||||||
# Gradient clipping max norm
|
# Gradient clipping max norm
|
||||||
max_grad_norm:
|
max_grad_norm:
|
||||||
|
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ the `alpaca` dataset format, which has the following format:
|
|||||||
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
|
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
|
||||||
format them.
|
format them.
|
||||||
|
|
||||||
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
|
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca`
|
||||||
format):
|
format):
|
||||||
|
|
||||||
```json
|
```json
|
||||||
@@ -120,6 +120,12 @@ axolotl train my_training.yml
|
|||||||
|
|
||||||
## Common Tasks {#sec-common-tasks}
|
## Common Tasks {#sec-common-tasks}
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
|
||||||
|
The same yaml file is used for training, inference, and merging.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
### Testing Your Model {#sec-testing}
|
### Testing Your Model {#sec-testing}
|
||||||
|
|
||||||
After training, test your model:
|
After training, test your model:
|
||||||
@@ -128,6 +134,16 @@ After training, test your model:
|
|||||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
|
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
More details can be found in [Inference](inference.qmd).
|
||||||
|
|
||||||
|
### Using a UI {#sec-ui}
|
||||||
|
|
||||||
|
Launch a Gradio interface:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
|
||||||
|
```
|
||||||
|
|
||||||
### Preprocessing Data {#sec-preprocessing}
|
### Preprocessing Data {#sec-preprocessing}
|
||||||
|
|
||||||
For large datasets, preprocess first:
|
For large datasets, preprocess first:
|
||||||
@@ -136,14 +152,22 @@ For large datasets, preprocess first:
|
|||||||
axolotl preprocess my_training.yml
|
axolotl preprocess my_training.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
### Using a UI {#sec-ui}
|
Please make sure to set `dataset_prepared_path: ` in your config to set the path to save the prepared dataset.
|
||||||
|
|
||||||
Launch a Gradio interface:
|
More details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd).
|
||||||
|
|
||||||
|
### Merging LoRA weights {#sec-merging-lora}
|
||||||
|
|
||||||
|
To merge the LoRA weights back into the base model, run:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
|
axolotl merge-lora my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The merged model will be saved in the `{output_dir}/merged` directory.
|
||||||
|
|
||||||
|
More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).
|
||||||
|
|
||||||
## Next Steps {#sec-next-steps}
|
## Next Steps {#sec-next-steps}
|
||||||
|
|
||||||
Now that you have the basics, you might want to:
|
Now that you have the basics, you might want to:
|
||||||
@@ -156,6 +180,7 @@ Now that you have the basics, you might want to:
|
|||||||
Check our other guides for details on these topics:
|
Check our other guides for details on these topics:
|
||||||
|
|
||||||
- [Configuration Guide](config.qmd) - Full configuration options
|
- [Configuration Guide](config.qmd) - Full configuration options
|
||||||
|
- [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources
|
||||||
- [Dataset Formats](dataset-formats) - Working with different data formats
|
- [Dataset Formats](dataset-formats) - Working with different data formats
|
||||||
- [Multi-GPU Training](multi-gpu.qmd)
|
- [Multi-GPU Training](multi-gpu.qmd)
|
||||||
- [Multi-Node Training](multi-node.qmd)
|
- [Multi-Node Training](multi-node.qmd)
|
||||||
|
|||||||
@@ -87,20 +87,7 @@ We support sequence parallelism (SP) via the
|
|||||||
allows one to split up sequences across GPUs, which is useful in the event that a
|
allows one to split up sequences across GPUs, which is useful in the event that a
|
||||||
single sequence causes OOM errors during model training.
|
single sequence causes OOM errors during model training.
|
||||||
|
|
||||||
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
|
See our [dedicated guide](sequence_parallelism.qmd) for more information.
|
||||||
or from source with `pip install .[ring-flash-attn]`.
|
|
||||||
|
|
||||||
Your Axolotl YAML config should contain the following lines:
|
|
||||||
|
|
||||||
```{.yaml}
|
|
||||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
|
||||||
flash_attention: true # Required with sequence parallelism
|
|
||||||
|
|
||||||
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
|
|
||||||
heads_k_stride: 1
|
|
||||||
```
|
|
||||||
|
|
||||||
See our [dedicated guide](sequence_parallelism.qmd) for more details.
|
|
||||||
|
|
||||||
### FSDP + QLoRA {#sec-fsdp-qlora}
|
### FSDP + QLoRA {#sec-fsdp-qlora}
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ When sequence parallelism is enabled:
|
|||||||
|
|
||||||
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
|
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
|
||||||
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
|
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
|
||||||
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
|
3. Position IDs are adjusted to maintain proper relative positions
|
||||||
4. The trainer uses special ring communication patterns for attention operations
|
4. The trainer uses special ring communication patterns for attention operations
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
@@ -67,9 +67,11 @@ sequence_len: 8192
|
|||||||
...
|
...
|
||||||
|
|
||||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||||
flash_attention: true # Required with sequence parallelism
|
|
||||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
|
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||||
|
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
|
||||||
|
ring_attn_func:
|
||||||
|
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ base_model: Qwen/Qwen2.5-0.5B
|
|||||||
# Automatically upload checkpoint and final model to HF
|
# Automatically upload checkpoint and final model to HF
|
||||||
# hub_model_id: username/custom_model_name
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
|
||||||
chat_template: qwen_25
|
chat_template: qwen_25
|
||||||
rl: dpo
|
rl: dpo
|
||||||
datasets:
|
datasets:
|
||||||
|
|||||||
@@ -465,8 +465,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
"save_only_model",
|
"save_only_model",
|
||||||
"include_tokens_per_second",
|
"include_tokens_per_second",
|
||||||
"weight_decay",
|
"weight_decay",
|
||||||
"sequence_parallel_degree",
|
|
||||||
"ring_attn_func",
|
|
||||||
"seed",
|
"seed",
|
||||||
]:
|
]:
|
||||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||||
|
|||||||
@@ -185,7 +185,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
trainer_kwargs["peft_config"] = self.peft_config
|
if self.cfg.rl is not RLType.GRPO:
|
||||||
|
trainer_kwargs["peft_config"] = self.peft_config
|
||||||
if self.cfg.precompute_ref_log_probs is not None:
|
if self.cfg.precompute_ref_log_probs is not None:
|
||||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||||
self.cfg.precompute_ref_log_probs
|
self.cfg.precompute_ref_log_probs
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from axolotl.core.trainers.mixins import (
|
|||||||
OptimizerMixin,
|
OptimizerMixin,
|
||||||
RngLoaderMixin,
|
RngLoaderMixin,
|
||||||
SchedulerMixin,
|
SchedulerMixin,
|
||||||
SequenceParallelMixin,
|
|
||||||
)
|
)
|
||||||
from axolotl.core.trainers.utils import (
|
from axolotl.core.trainers.utils import (
|
||||||
sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
@@ -40,9 +39,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(
|
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||||
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
|
|
||||||
):
|
|
||||||
"""Extend the base Trainer for axolotl helpers"""
|
"""Extend the base Trainer for axolotl helpers"""
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
@@ -68,10 +65,6 @@ class AxolotlTrainer(
|
|||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
# Initialize sequence parallelism if enabled
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
self._setup_sequence_parallel()
|
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True, dataloader=None):
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
if self.args.torch_compile:
|
if self.args.torch_compile:
|
||||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||||
@@ -122,8 +115,8 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
def _get_train_sampler(self) -> Sampler | None:
|
def _get_train_sampler(self) -> Sampler | None:
|
||||||
"""
|
"""
|
||||||
Helper method to get the sampler for training. Handles cases for sequence
|
Helper method to get the sampler for training. Handles cases for sample packing
|
||||||
parallelism, sample packing, and curriculum sampling (sequential).
|
and curriculum sampling (sequential).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
@@ -132,9 +125,7 @@ class AxolotlTrainer(
|
|||||||
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
||||||
|
|
||||||
# Determine the base sampler first
|
# Determine the base sampler first
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.curriculum_sampling:
|
||||||
base_sampler = self._sp_get_train_sampler(self.train_dataset)
|
|
||||||
elif self.args.curriculum_sampling:
|
|
||||||
base_sampler = SequentialSampler(self.train_dataset)
|
base_sampler = SequentialSampler(self.train_dataset)
|
||||||
elif use_sample_packing:
|
elif use_sample_packing:
|
||||||
base_sampler = RandomSampler(self.train_dataset)
|
base_sampler = RandomSampler(self.train_dataset)
|
||||||
@@ -153,8 +144,7 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
|
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
|
||||||
"""
|
"""
|
||||||
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
Helper method to get the sampler for evaluation. Handles sample packing case.
|
||||||
and sample packing cases.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
@@ -168,9 +158,7 @@ class AxolotlTrainer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Determine the base sampler
|
# Determine the base sampler
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if use_multipack:
|
||||||
base_sampler = self._sp_get_eval_sampler(eval_dataset)
|
|
||||||
elif use_multipack:
|
|
||||||
base_sampler = SequentialSampler(eval_dataset)
|
base_sampler = SequentialSampler(eval_dataset)
|
||||||
else:
|
else:
|
||||||
return super()._get_eval_sampler(eval_dataset)
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
@@ -236,14 +224,6 @@ class AxolotlTrainer(
|
|||||||
):
|
):
|
||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
|
|
||||||
# Return unprepared dataloader if using sequence parallelism
|
|
||||||
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
|
||||||
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
|
||||||
# slice each batch along the sequence dimension).
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
# Otherwise prepare with accelerator
|
|
||||||
return self.accelerator.prepare_data_loader(dataloader)
|
return self.accelerator.prepare_data_loader(dataloader)
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
@@ -287,12 +267,7 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
# Handle sample packing or sequence parallelism
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
if (
|
|
||||||
self.args.sample_packing
|
|
||||||
and self.args.eval_sample_packing is not False
|
|
||||||
or self.args.sequence_parallel_degree > 1
|
|
||||||
):
|
|
||||||
# Get appropriate data collator
|
# Get appropriate data collator
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.eval_data_collator
|
self.eval_data_collator
|
||||||
@@ -302,17 +277,6 @@ class AxolotlTrainer(
|
|||||||
if "length" in eval_dataset.column_names:
|
if "length" in eval_dataset.column_names:
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
|
|
||||||
# Handle dataset preprocessing for SP
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
if isinstance(eval_dataset, datasets.Dataset):
|
|
||||||
eval_dataset = self._remove_unused_columns(
|
|
||||||
eval_dataset, description="evaluation"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.data_collator, description="evaluation"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
|
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
|
||||||
batch_size = (
|
batch_size = (
|
||||||
self.args.eval_batch_size
|
self.args.eval_batch_size
|
||||||
|
|||||||
@@ -1,31 +1,15 @@
|
|||||||
"""
|
"""DPO trainer for axolotl"""
|
||||||
DPO trainer for axolotl
|
|
||||||
"""
|
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import random
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
import torch
|
||||||
import wandb
|
|
||||||
from accelerate import PartialState
|
|
||||||
from datasets import Dataset, IterableDataset
|
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from transformers import Trainer
|
||||||
from transformers import (
|
|
||||||
BaseImageProcessor,
|
|
||||||
FeatureExtractionMixin,
|
|
||||||
PreTrainedTokenizerBase,
|
|
||||||
ProcessorMixin,
|
|
||||||
Trainer,
|
|
||||||
)
|
|
||||||
from transformers.trainer_utils import EvalLoopOutput
|
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt
|
from trl import DPOTrainer
|
||||||
from trl.trainer.utils import log_table_to_comet_experiment
|
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||||
from axolotl.core.trainers.utils import (
|
from axolotl.core.trainers.utils import (
|
||||||
@@ -38,9 +22,7 @@ if is_sagemaker_mp_enabled():
|
|||||||
|
|
||||||
|
|
||||||
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||||
"""
|
"""Extend the base DPOTrainer for axolotl helpers."""
|
||||||
Extend the base DPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "dpo"]
|
tag_names = ["axolotl", "dpo"]
|
||||||
|
|
||||||
@@ -85,8 +67,9 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
|||||||
@wraps(DPOTrainer.push_to_hub)
|
@wraps(DPOTrainer.push_to_hub)
|
||||||
def push_to_hub(self, *args, **kwargs) -> str:
|
def push_to_hub(self, *args, **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub`
|
||||||
|
for more details.
|
||||||
"""
|
"""
|
||||||
kwargs = sanitize_kwargs_for_ds_tagging(
|
kwargs = sanitize_kwargs_for_ds_tagging(
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
@@ -95,64 +78,6 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
|||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
|
|
||||||
def _prepare_dataset(
|
|
||||||
self,
|
|
||||||
dataset: Union[Dataset, IterableDataset],
|
|
||||||
processing_class: Union[
|
|
||||||
PreTrainedTokenizerBase,
|
|
||||||
BaseImageProcessor,
|
|
||||||
FeatureExtractionMixin,
|
|
||||||
ProcessorMixin,
|
|
||||||
],
|
|
||||||
args: DPOConfig,
|
|
||||||
dataset_name: str,
|
|
||||||
) -> Union[Dataset, IterableDataset]:
|
|
||||||
# Build the kwargs for the `map` function
|
|
||||||
map_kwargs: Dict[str, Any] = {"writer_batch_size": 10}
|
|
||||||
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
|
|
||||||
map_kwargs["num_proc"] = args.dataset_num_proc
|
|
||||||
|
|
||||||
with PartialState().main_process_first():
|
|
||||||
# Extract prompt if needed
|
|
||||||
if isinstance(
|
|
||||||
dataset, Dataset
|
|
||||||
): # `IterableDataset.map` does not support `desc`
|
|
||||||
map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
|
|
||||||
dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
|
|
||||||
|
|
||||||
# Apply the chat template if needed
|
|
||||||
if isinstance(
|
|
||||||
dataset, Dataset
|
|
||||||
): # `IterableDataset.map` does not support `desc`
|
|
||||||
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
|
|
||||||
dataset = dataset.map(
|
|
||||||
maybe_apply_chat_template,
|
|
||||||
fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
|
|
||||||
**map_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Tokenize the dataset
|
|
||||||
if isinstance(
|
|
||||||
dataset, Dataset
|
|
||||||
): # `IterableDataset.map` does not support `desc`
|
|
||||||
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
|
|
||||||
|
|
||||||
dataset = dataset.map(
|
|
||||||
self.tokenize_row if not self.is_vision_model else self.process_row,
|
|
||||||
remove_columns=["chosen", "rejected"],
|
|
||||||
fn_kwargs={
|
|
||||||
"processing_class": processing_class,
|
|
||||||
"max_prompt_length": args.max_prompt_length,
|
|
||||||
"max_completion_length": args.max_completion_length,
|
|
||||||
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
|
|
||||||
"add_special_tokens": False,
|
|
||||||
},
|
|
||||||
**map_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return dataset
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tokenize_row(
|
def tokenize_row(
|
||||||
features,
|
features,
|
||||||
@@ -192,69 +117,3 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
|
|
||||||
def evaluation_loop(
|
|
||||||
self,
|
|
||||||
dataloader: DataLoader,
|
|
||||||
description: str,
|
|
||||||
prediction_loss_only: Optional[bool] = None,
|
|
||||||
ignore_keys: Optional[list[str]] = None,
|
|
||||||
metric_key_prefix: str = "eval",
|
|
||||||
) -> EvalLoopOutput:
|
|
||||||
"""
|
|
||||||
Overriding built-in evaluation loop to store metrics for each batch.
|
|
||||||
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
|
||||||
|
|
||||||
Works both with or without labels.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Sample and save to game log if requested (for one batch to save time)
|
|
||||||
if self.generate_during_eval:
|
|
||||||
# Generate random indices within the range of the total number of samples
|
|
||||||
num_samples = len(dataloader.dataset)
|
|
||||||
random_indices = random.sample(
|
|
||||||
range(num_samples), k=self.args.eval_batch_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
|
||||||
random_batch_dataset = dataloader.dataset.select(random_indices)
|
|
||||||
random_batch = self.data_collator(random_batch_dataset)
|
|
||||||
random_batch = self._prepare_inputs(random_batch)
|
|
||||||
|
|
||||||
policy_output_decoded, ref_output_decoded = (
|
|
||||||
self.generate_from_model_and_ref(self.model, random_batch)
|
|
||||||
)
|
|
||||||
|
|
||||||
table = pd.DataFrame(
|
|
||||||
columns=["Prompt", "Policy", "Ref Model"],
|
|
||||||
data=[
|
|
||||||
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
|
|
||||||
for prompt, pol, ref in zip(
|
|
||||||
random_batch_dataset["prompt"],
|
|
||||||
policy_output_decoded,
|
|
||||||
ref_output_decoded,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
if "wandb" in self.args.report_to and self.accelerator.is_main_process:
|
|
||||||
wandb.log({"game_log": wandb.Table(data=table)})
|
|
||||||
|
|
||||||
if "comet_ml" in self.args.report_to:
|
|
||||||
log_table_to_comet_experiment(
|
|
||||||
name="game_log.csv",
|
|
||||||
table=table,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Base evaluation
|
|
||||||
initial_output = super( # pylint: disable=bad-super-call
|
|
||||||
DPOTrainer, self
|
|
||||||
).evaluation_loop(
|
|
||||||
dataloader,
|
|
||||||
description,
|
|
||||||
prediction_loss_only,
|
|
||||||
ignore_keys,
|
|
||||||
metric_key_prefix,
|
|
||||||
)
|
|
||||||
|
|
||||||
return initial_output
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import nullcontext
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
@@ -14,7 +13,7 @@ from accelerate.utils import (
|
|||||||
broadcast_object_list,
|
broadcast_object_list,
|
||||||
gather,
|
gather,
|
||||||
gather_object,
|
gather_object,
|
||||||
is_peft_model,
|
is_peft_available,
|
||||||
)
|
)
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -30,15 +29,13 @@ from transformers import (
|
|||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import seed_worker
|
||||||
from transformers.utils import is_peft_available
|
|
||||||
from trl import GRPOTrainer
|
from trl import GRPOTrainer
|
||||||
from trl.data_utils import (
|
from trl.data_utils import (
|
||||||
apply_chat_template,
|
apply_chat_template,
|
||||||
is_conversational,
|
is_conversational,
|
||||||
maybe_apply_chat_template,
|
maybe_apply_chat_template,
|
||||||
)
|
)
|
||||||
from trl.extras.profiling import profiling_context, profiling_decorator
|
from trl.extras.profiling import profiling_context
|
||||||
from trl.import_utils import is_deepspeed_available
|
|
||||||
from trl.models import unwrap_model_for_generation
|
from trl.models import unwrap_model_for_generation
|
||||||
from trl.trainer.grpo_config import GRPOConfig
|
from trl.trainer.grpo_config import GRPOConfig
|
||||||
from trl.trainer.grpo_trainer import RewardFunc, nanstd
|
from trl.trainer.grpo_trainer import RewardFunc, nanstd
|
||||||
@@ -46,68 +43,18 @@ from trl.trainer.utils import pad
|
|||||||
|
|
||||||
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
|
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
|
||||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||||
from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
|
from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group
|
||||||
|
|
||||||
if is_peft_available():
|
if is_peft_available():
|
||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
from peft import PeftConfig
|
from peft import PeftConfig
|
||||||
|
|
||||||
if is_deepspeed_available():
|
|
||||||
import deepspeed
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||||
"""Extend the base GRPOTrainer for axolotl helpers"""
|
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||||
|
|
||||||
_tag_names = ["trl", "grpo", "axolotl"]
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
|
|
||||||
@profiling_decorator
|
|
||||||
def _move_model_to_vllm(self):
|
|
||||||
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
|
|
||||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
|
||||||
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
|
|
||||||
gather_if_zero3 = (
|
|
||||||
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_peft_model(self.model):
|
|
||||||
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
|
|
||||||
# adapters in a sharded manner is not supported.
|
|
||||||
with gather_if_zero3(list(self.model.parameters())):
|
|
||||||
self.model.merge_adapter()
|
|
||||||
|
|
||||||
# Update vLLM weights while parameters are gathered
|
|
||||||
for name, param in self.model.named_parameters():
|
|
||||||
# When using PEFT, we need to recover the original parameter name and discard some parameters
|
|
||||||
name = (
|
|
||||||
name.removeprefix("base_model.model.")
|
|
||||||
.removeprefix("base_model.model.")
|
|
||||||
.replace(".base_layer", "")
|
|
||||||
)
|
|
||||||
if self.model.prefix in name:
|
|
||||||
continue
|
|
||||||
# When module to save, remove its prefix and discard the original module
|
|
||||||
if "original_module" in name:
|
|
||||||
continue
|
|
||||||
name = name.replace("modules_to_save.default.", "")
|
|
||||||
|
|
||||||
if self.accelerator.is_main_process:
|
|
||||||
self.vllm_client.update_named_param(name, param.data)
|
|
||||||
|
|
||||||
# Unmerge adapters while parameters are still gathered
|
|
||||||
self.model.unmerge_adapter()
|
|
||||||
# Parameters will automatically be repartitioned when exiting the context
|
|
||||||
else:
|
|
||||||
# For non-PEFT models, simply gather and update each parameter individually.
|
|
||||||
for name, param in self.model.named_parameters():
|
|
||||||
with gather_if_zero3([param]):
|
|
||||||
if self.accelerator.is_main_process:
|
|
||||||
self.vllm_client.update_named_param(name, param.data)
|
|
||||||
|
|
||||||
# Reset cache on main process
|
|
||||||
if self.accelerator.is_main_process:
|
|
||||||
self.vllm_client.reset_prefix_cache()
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||||
|
|||||||
@@ -6,4 +6,3 @@
|
|||||||
from .optimizer import OptimizerMixin
|
from .optimizer import OptimizerMixin
|
||||||
from .rng_state_loader import RngLoaderMixin
|
from .rng_state_loader import RngLoaderMixin
|
||||||
from .scheduler import SchedulerMixin
|
from .scheduler import SchedulerMixin
|
||||||
from .sequence_parallel import SequenceParallelMixin
|
|
||||||
|
|||||||
@@ -1,87 +0,0 @@
|
|||||||
"""Module for Axolotl trainer sequence parallelism mixin"""
|
|
||||||
|
|
||||||
import torch.distributed as dist
|
|
||||||
from datasets import Dataset
|
|
||||||
from torch.utils.data import DistributedSampler, Sampler
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import (
|
|
||||||
get_ring_attn_group,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SequenceParallelMixin:
|
|
||||||
"""
|
|
||||||
Mixin class for sequence parallelism support in trainers.
|
|
||||||
|
|
||||||
This mixin provides functionality for handling sequence parallelism,
|
|
||||||
specifically for creating appropriate data samplers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
|
||||||
|
|
||||||
def _setup_sequence_parallel(self):
|
|
||||||
"""Set up sequence parallelism environment."""
|
|
||||||
self.ring_attn_group = get_ring_attn_group()
|
|
||||||
|
|
||||||
def _create_sequence_parallel_sampler(
|
|
||||||
self,
|
|
||||||
dataset: Dataset,
|
|
||||||
shuffle: bool = True,
|
|
||||||
is_eval: bool = False,
|
|
||||||
) -> DistributedSampler:
|
|
||||||
"""
|
|
||||||
Helper method to create sampler for sequence parallelism (SP).
|
|
||||||
|
|
||||||
We create a distributed sampler with rank equal to the SP group ID, which
|
|
||||||
means that all ranks in the SP group receive the same sample / set of samples
|
|
||||||
per training step. We also set the number of replicas equal to the number of
|
|
||||||
SP groups, which is a bit of a hack / unintended use, but works!
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: Dataset to sample from.
|
|
||||||
shuffle: Whether to shuffle the dataset.
|
|
||||||
is_eval: Whether we are creating a sampler for evaluation or training.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Distributed sampler.
|
|
||||||
"""
|
|
||||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
|
||||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
|
||||||
|
|
||||||
return DistributedSampler(
|
|
||||||
dataset,
|
|
||||||
num_replicas=num_sp_groups,
|
|
||||||
rank=sp_group_id,
|
|
||||||
seed=self.args.seed if shuffle else None,
|
|
||||||
shuffle=shuffle,
|
|
||||||
drop_last=not is_eval,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
|
|
||||||
"""
|
|
||||||
Get a training sampler configured for sequence parallelism.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: The training dataset
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured sequence parallel sampler.
|
|
||||||
"""
|
|
||||||
return self._create_sequence_parallel_sampler(
|
|
||||||
dataset,
|
|
||||||
shuffle=not self.args.curriculum_sampling,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
|
||||||
"""
|
|
||||||
Get an evaluation sampler configured for sequence parallelism.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
eval_dataset: The evaluation dataset.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured sequence parallel sampler.
|
|
||||||
"""
|
|
||||||
return self._create_sequence_parallel_sampler(
|
|
||||||
eval_dataset, shuffle=False, is_eval=True
|
|
||||||
)
|
|
||||||
@@ -9,8 +9,6 @@ from PIL.Image import Resampling
|
|||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||||
|
|
||||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingMixins:
|
class AxolotlTrainingMixins:
|
||||||
@@ -216,14 +214,16 @@ class AxolotlTrainingMixins:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
sequence_parallel_degree: Optional[int] = field(
|
adam_beta3: Optional[float] = field(
|
||||||
default=1,
|
|
||||||
metadata={"help": "The number of workers to use in sequence parallelism"},
|
|
||||||
)
|
|
||||||
ring_attn_func: Optional[RingAttnFunc] = field(
|
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The ring-flash-attn function to use in sequence parallelism"
|
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
adam_epsilon2: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -20,25 +20,15 @@ from cut_cross_entropy.transformers.utils import (
|
|||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.models.cohere.modeling_cohere import (
|
from transformers.models.cohere.modeling_cohere import (
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
COHERE_INPUTS_DOCSTRING,
|
|
||||||
KwargsForCausalLM,
|
KwargsForCausalLM,
|
||||||
)
|
)
|
||||||
from transformers.processing_utils import Unpack
|
from transformers.processing_utils import Unpack
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward(
|
def cce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor | None = None,
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
|||||||
@@ -17,25 +17,15 @@ from cut_cross_entropy.transformers.utils import (
|
|||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.models.gemma.modeling_gemma import (
|
from transformers.models.gemma.modeling_gemma import (
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
GEMMA_INPUTS_DOCSTRING,
|
|
||||||
KwargsForCausalLM,
|
KwargsForCausalLM,
|
||||||
)
|
)
|
||||||
from transformers.processing_utils import Unpack
|
from transformers.processing_utils import Unpack
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward(
|
def cce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor | None = None,
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
|||||||
@@ -20,15 +20,11 @@ from torch import nn
|
|||||||
from transformers.cache_utils import Cache, HybridCache
|
from transformers.cache_utils import Cache, HybridCache
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.models.gemma3.modeling_gemma3 import (
|
from transformers.models.gemma3.modeling_gemma3 import (
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
GEMMA3_INPUTS_DOCSTRING,
|
|
||||||
Gemma3CausalLMOutputWithPast,
|
Gemma3CausalLMOutputWithPast,
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
is_torchdynamo_compiling,
|
is_torchdynamo_compiling,
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
)
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
@@ -38,10 +34,6 @@ _PATCH_OPTS: PatchOptions | None = None
|
|||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward(
|
def cce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor | None = None,
|
input_ids: torch.LongTensor | None = None,
|
||||||
@@ -170,10 +162,6 @@ def cce_forward(
|
|||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward_multimodal(
|
def cce_forward_multimodal(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor | None = None,
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
|||||||
@@ -19,15 +19,9 @@ from transformers.modeling_outputs import (
|
|||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
)
|
)
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
LLAMA_INPUTS_DOCSTRING,
|
|
||||||
KwargsForCausalLM,
|
KwargsForCausalLM,
|
||||||
)
|
)
|
||||||
from transformers.processing_utils import Unpack
|
from transformers.processing_utils import Unpack
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
from transformers.utils.generic import can_return_tuple
|
from transformers.utils.generic import can_return_tuple
|
||||||
|
|
||||||
@@ -36,10 +30,6 @@ _PATCH_OPTS: PatchOptions | None = None
|
|||||||
|
|
||||||
@can_return_tuple
|
@can_return_tuple
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward(
|
def cce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
|||||||
@@ -16,22 +16,12 @@ from torch import nn
|
|||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.models.llama4.modeling_llama4 import (
|
from transformers.models.llama4.modeling_llama4 import (
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
LLAMA4_INPUTS_DOCSTRING,
|
|
||||||
Llama4CausalLMOutputWithPast,
|
Llama4CausalLMOutputWithPast,
|
||||||
)
|
)
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward(
|
def cce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor | None = None,
|
input_ids: torch.LongTensor | None = None,
|
||||||
@@ -160,9 +150,6 @@ def cce_forward(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward_multimodal(
|
def cce_forward_multimodal(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor | None = None, # type: ignore
|
input_ids: torch.LongTensor | None = None, # type: ignore
|
||||||
|
|||||||
@@ -19,15 +19,11 @@ from transformers.models.mistral3.modeling_mistral3 import (
|
|||||||
Mistral3CausalLMOutputWithPast,
|
Mistral3CausalLMOutputWithPast,
|
||||||
)
|
)
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
MISTRAL_INPUTS_DOCSTRING,
|
|
||||||
KwargsForCausalLM,
|
KwargsForCausalLM,
|
||||||
)
|
)
|
||||||
from transformers.processing_utils import Unpack
|
from transformers.processing_utils import Unpack
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
is_torchdynamo_compiling,
|
is_torchdynamo_compiling,
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
)
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
@@ -35,10 +31,6 @@ _PATCH_OPTS: PatchOptions | None = None
|
|||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward(
|
def cce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor | None = None,
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
|||||||
@@ -13,16 +13,10 @@ from cut_cross_entropy.transformers.utils import (
|
|||||||
apply_lce,
|
apply_lce,
|
||||||
)
|
)
|
||||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
|
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
QWEN2MOE_INPUTS_DOCSTRING,
|
|
||||||
MoeCausalLMOutputWithPast,
|
MoeCausalLMOutputWithPast,
|
||||||
MoeModelOutputWithPast,
|
MoeModelOutputWithPast,
|
||||||
load_balancing_loss_func,
|
load_balancing_loss_func,
|
||||||
)
|
)
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
from transformers.utils.generic import can_return_tuple
|
from transformers.utils.generic import can_return_tuple
|
||||||
|
|
||||||
@@ -31,10 +25,6 @@ _PATCH_OPTS: PatchOptions | None = None
|
|||||||
|
|
||||||
@can_return_tuple
|
@can_return_tuple
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
|||||||
@@ -14,22 +14,12 @@ from cut_cross_entropy.transformers.utils import (
|
|||||||
)
|
)
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
QWEN2_VL_INPUTS_DOCSTRING,
|
|
||||||
Qwen2VLCausalLMOutputWithPast,
|
Qwen2VLCausalLMOutputWithPast,
|
||||||
)
|
)
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward_multimodal(
|
def cce_forward_multimodal(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
|||||||
@@ -12,20 +12,13 @@ from cut_cross_entropy.transformers.utils import (
|
|||||||
TransformersModelT,
|
TransformersModelT,
|
||||||
apply_lce,
|
apply_lce,
|
||||||
)
|
)
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
|
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
QWEN3_MOE_INPUTS_DOCSTRING,
|
|
||||||
KwargsForCausalLM,
|
KwargsForCausalLM,
|
||||||
MoeCausalLMOutputWithPast,
|
MoeCausalLMOutputWithPast,
|
||||||
MoeModelOutputWithPast,
|
MoeModelOutputWithPast,
|
||||||
load_balancing_loss_func,
|
load_balancing_loss_func,
|
||||||
)
|
)
|
||||||
from transformers.processing_utils import Unpack
|
from transformers.processing_utils import Unpack
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
from transformers.utils.generic import can_return_tuple
|
from transformers.utils.generic import can_return_tuple
|
||||||
|
|
||||||
@@ -34,10 +27,6 @@ _PATCH_OPTS: PatchOptions | None = None
|
|||||||
|
|
||||||
@can_return_tuple
|
@can_return_tuple
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
|||||||
@@ -14,10 +14,6 @@ from torch.nn import CrossEntropyLoss
|
|||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
||||||
|
|
||||||
# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
|
|
||||||
# @replace_return_docstrings(
|
|
||||||
# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
# )
|
|
||||||
def lce_forward(
|
def lce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
|
|||||||
@@ -13,21 +13,11 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
|
|||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
||||||
from transformers.models.jamba.modeling_jamba import (
|
from transformers.models.jamba.modeling_jamba import (
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
JAMBA_INPUTS_DOCSTRING,
|
|
||||||
HybridMambaAttentionDynamicCache,
|
HybridMambaAttentionDynamicCache,
|
||||||
load_balancing_loss_func,
|
load_balancing_loss_func,
|
||||||
)
|
)
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def lce_forward(
|
def lce_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
"""Init for ring attention monkeypatch module"""
|
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
|
||||||
# flake8: noqa
|
|
||||||
|
|
||||||
from .patch import (
|
|
||||||
get_ring_attn_group,
|
|
||||||
register_ring_attn,
|
|
||||||
set_ring_attn_group,
|
|
||||||
update_ring_attn_params,
|
|
||||||
)
|
|
||||||
@@ -1,131 +0,0 @@
|
|||||||
"""
|
|
||||||
Ring attention group registration and flash attention patching.
|
|
||||||
|
|
||||||
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
|
|
||||||
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
|
|
||||||
their sequence parallel version of Flash Attention 2.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from accelerate.logging import get_logger
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
|
||||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
RING_ATTN_GROUP = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_ring_attn_group() -> dist.ProcessGroup:
|
|
||||||
"""
|
|
||||||
Getter for ring attention group on this rank.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The process group for ring attention for this rank.
|
|
||||||
"""
|
|
||||||
return RING_ATTN_GROUP
|
|
||||||
|
|
||||||
|
|
||||||
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
|
||||||
"""
|
|
||||||
Setter for ring attention group on this rank.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
Process group for ring attention.
|
|
||||||
"""
|
|
||||||
global RING_ATTN_GROUP # pylint: disable=global-statement
|
|
||||||
RING_ATTN_GROUP = ring_attn_group
|
|
||||||
|
|
||||||
|
|
||||||
def register_ring_attn(
|
|
||||||
sequence_parallel_degree: int,
|
|
||||||
heads_k_stride: int | None,
|
|
||||||
ring_attn_func: RingAttnFunc | None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create ring attention group and substitute flash attn with ring flash attn.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sequence_parallel_degree: Sequence parallelism factor.
|
|
||||||
heads_k_stride: Sequence parallelism K head stride size. Passed
|
|
||||||
through to `ring_flash_attn.substitute_hf_flash_attn`.
|
|
||||||
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
|
|
||||||
packing is enabled, it must be a `varlen` function; otherwise, it must be a
|
|
||||||
`batch` function.
|
|
||||||
"""
|
|
||||||
if get_ring_attn_group() is not None:
|
|
||||||
LOG.info("Ring attention already registered, exiting early...")
|
|
||||||
return
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
"Enabling ring attention sequence parallelism: "
|
|
||||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
|
||||||
)
|
|
||||||
|
|
||||||
rank = dist.get_rank()
|
|
||||||
world_size = dist.get_world_size()
|
|
||||||
|
|
||||||
assert sequence_parallel_degree <= world_size, (
|
|
||||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
|
||||||
f"must be less than or equal to world_size ({world_size})"
|
|
||||||
)
|
|
||||||
assert world_size % sequence_parallel_degree == 0, (
|
|
||||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
|
||||||
f"must evenly divide world_size ({world_size})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assign ranks to sequence parallel groups
|
|
||||||
group_assignments = {}
|
|
||||||
for i in range(world_size // sequence_parallel_degree):
|
|
||||||
ring_attn_ranks = list(
|
|
||||||
range(
|
|
||||||
i * sequence_parallel_degree,
|
|
||||||
(i + 1) * sequence_parallel_degree,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
|
||||||
|
|
||||||
# Track which GPUs are in which groups
|
|
||||||
for r in ring_attn_ranks:
|
|
||||||
group_assignments[r] = i
|
|
||||||
|
|
||||||
if rank in ring_attn_ranks:
|
|
||||||
set_ring_attn_group(group)
|
|
||||||
|
|
||||||
# Log the GPU group assignments
|
|
||||||
if rank == 0:
|
|
||||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
|
||||||
|
|
||||||
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
|
||||||
from ring_flash_attn import substitute_hf_flash_attn
|
|
||||||
|
|
||||||
substitute_hf_flash_attn(
|
|
||||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
|
|
||||||
)
|
|
||||||
elif ring_attn_func is RingAttnFunc.BATCH_RING:
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
|
|
||||||
substitute_hf_flash_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
substitute_hf_flash_attn(
|
|
||||||
process_group=get_ring_attn_group(),
|
|
||||||
ring_attn_func=ring_attn_func,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def update_ring_attn_params(position_ids: torch.Tensor | None):
|
|
||||||
"""
|
|
||||||
Calculate the cumulative sequence lengths for the current forward pass and pass the
|
|
||||||
value to the substituted `ring_flash_attn`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
position_ids: Optional tensor of position IDs (for sample packed data).
|
|
||||||
"""
|
|
||||||
from ring_flash_attn import update_ring_flash_attn_params
|
|
||||||
|
|
||||||
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
|
|
||||||
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
|
|
||||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
|
||||||
@@ -7,24 +7,16 @@ from typing import Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
from transformers.models.gemma3.modeling_gemma3 import (
|
from transformers.models.gemma3.modeling_gemma3 import (
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
GEMMA3_INPUTS_DOCSTRING,
|
|
||||||
Gemma3CausalLMOutputWithPast,
|
Gemma3CausalLMOutputWithPast,
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
is_torchdynamo_compiling,
|
is_torchdynamo_compiling,
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
)
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def new_forward(
|
def new_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
|
|||||||
22
src/axolotl/monkeypatch/ring_attn/__init__.py
Normal file
22
src/axolotl/monkeypatch/ring_attn/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""Init for ring attention monkeypatch module"""
|
||||||
|
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
from .patch import (
|
||||||
|
get_ring_attn_group,
|
||||||
|
patch_prepare_data_loader,
|
||||||
|
patch_prepare_device_mesh,
|
||||||
|
register_ring_attn,
|
||||||
|
set_ring_attn_group,
|
||||||
|
update_ring_attn_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"get_ring_attn_group",
|
||||||
|
"patch_prepare_data_loader",
|
||||||
|
"patch_prepare_device_mesh",
|
||||||
|
"register_ring_attn",
|
||||||
|
"set_ring_attn_group",
|
||||||
|
"update_ring_attn_params",
|
||||||
|
)
|
||||||
223
src/axolotl/monkeypatch/ring_attn/patch.py
Normal file
223
src/axolotl/monkeypatch/ring_attn/patch.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
"""Ring attention group registration and flash attention patching.
|
||||||
|
|
||||||
|
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
|
||||||
|
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
|
||||||
|
their sequence parallel version of Flash Attention 2.
|
||||||
|
|
||||||
|
We also provide some patches for accelerate functions to prepare the dataloader for
|
||||||
|
sequence parallelism training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
import accelerate
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from accelerate.logging import get_logger
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
RING_ATTN_GROUP = None
|
||||||
|
|
||||||
|
ORIGINAL_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
|
||||||
|
submesh_dp_size = 1
|
||||||
|
submesh_tp_size = 1
|
||||||
|
if "tp" in torch_device_mesh.mesh_dim_names:
|
||||||
|
submesh_tp_size = torch_device_mesh["tp"].size()
|
||||||
|
if "dp" in torch_device_mesh.mesh_dim_names:
|
||||||
|
submesh_dp_size = torch_device_mesh["dp"].size()
|
||||||
|
if "fsdp" in torch_device_mesh.mesh_dim_names:
|
||||||
|
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
|
||||||
|
process_index = process_index // submesh_tp_size"""
|
||||||
|
|
||||||
|
NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
|
||||||
|
submesh_dp_size = 1
|
||||||
|
submesh_tp_size = 1
|
||||||
|
submesh_cp_size = 1
|
||||||
|
if "cp" in torch_device_mesh.mesh_dim_names:
|
||||||
|
submesh_cp_size = torch_device_mesh["cp"].size()
|
||||||
|
if "tp" in torch_device_mesh.mesh_dim_names:
|
||||||
|
submesh_tp_size = torch_device_mesh["tp"].size()
|
||||||
|
if "dp" in torch_device_mesh.mesh_dim_names:
|
||||||
|
submesh_dp_size = torch_device_mesh["dp"].size()
|
||||||
|
if "fsdp" in torch_device_mesh.mesh_dim_names:
|
||||||
|
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
|
||||||
|
process_index = process_index // (submesh_tp_size * submesh_cp_size)"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_ring_attn_group() -> dist.ProcessGroup:
|
||||||
|
"""Getter for ring attention group on this rank."""
|
||||||
|
return RING_ATTN_GROUP
|
||||||
|
|
||||||
|
|
||||||
|
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||||
|
"""Setter for ring attention group on this rank."""
|
||||||
|
global RING_ATTN_GROUP # pylint: disable=global-statement
|
||||||
|
RING_ATTN_GROUP = ring_attn_group
|
||||||
|
|
||||||
|
|
||||||
|
def register_ring_attn(
|
||||||
|
sequence_parallel_degree: int,
|
||||||
|
heads_k_stride: int | None,
|
||||||
|
ring_attn_func: RingAttnFunc | None,
|
||||||
|
):
|
||||||
|
"""Create ring attention group and substitute flash attn with ring flash attn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sequence_parallel_degree: Sequence parallelism factor.
|
||||||
|
heads_k_stride: Sequence parallelism K head stride size. Passed
|
||||||
|
through to `ring_flash_attn.substitute_hf_flash_attn`.
|
||||||
|
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
|
||||||
|
packing is enabled, it must be a `varlen` function; otherwise, it must be a
|
||||||
|
`batch` function.
|
||||||
|
"""
|
||||||
|
rank = dist.get_rank()
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
LOG.info(
|
||||||
|
"Enabling ring attention sequence parallelism: "
|
||||||
|
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert sequence_parallel_degree <= world_size, (
|
||||||
|
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||||
|
f"must be less than or equal to world_size ({world_size})"
|
||||||
|
)
|
||||||
|
assert world_size % sequence_parallel_degree == 0, (
|
||||||
|
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||||
|
f"must evenly divide world_size ({world_size})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign ranks to sequence parallel groups
|
||||||
|
group_assignments = {}
|
||||||
|
for i in range(world_size // sequence_parallel_degree):
|
||||||
|
ring_attn_ranks = list(
|
||||||
|
range(
|
||||||
|
i * sequence_parallel_degree,
|
||||||
|
(i + 1) * sequence_parallel_degree,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||||
|
|
||||||
|
# Track which GPUs are in which groups
|
||||||
|
for r in ring_attn_ranks:
|
||||||
|
group_assignments[r] = i
|
||||||
|
|
||||||
|
if rank in ring_attn_ranks:
|
||||||
|
set_ring_attn_group(group)
|
||||||
|
|
||||||
|
# Log the GPU group assignments
|
||||||
|
if rank == 0:
|
||||||
|
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||||
|
|
||||||
|
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
||||||
|
from ring_flash_attn import substitute_hf_flash_attn
|
||||||
|
|
||||||
|
substitute_hf_flash_attn(
|
||||||
|
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
|
||||||
|
)
|
||||||
|
elif ring_attn_func is RingAttnFunc.BATCH_RING:
|
||||||
|
from axolotl.monkeypatch.ring_attn.adapters.batch import (
|
||||||
|
substitute_hf_flash_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
substitute_hf_flash_attn(
|
||||||
|
process_group=get_ring_attn_group(),
|
||||||
|
ring_attn_func=ring_attn_func,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def update_ring_attn_params(position_ids: torch.Tensor | None):
|
||||||
|
"""
|
||||||
|
Calculate the cumulative sequence lengths for the current forward pass and pass the
|
||||||
|
value to the substituted `ring_flash_attn`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
position_ids: Optional tensor of position IDs (for sample packed data).
|
||||||
|
"""
|
||||||
|
from ring_flash_attn import update_ring_flash_attn_params
|
||||||
|
|
||||||
|
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
|
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
|
||||||
|
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||||
|
|
||||||
|
|
||||||
|
def patch_prepare_data_loader():
|
||||||
|
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
|
||||||
|
|
||||||
|
Raies:
|
||||||
|
RuntimeError: If source code to patch does not exist.
|
||||||
|
"""
|
||||||
|
original_fn = accelerate.data_loader.prepare_data_loader
|
||||||
|
original_source = inspect.getsource(original_fn)
|
||||||
|
|
||||||
|
if ORIGINAL_PREPARE_DATALOADER_CODE not in original_source:
|
||||||
|
raise RuntimeError(
|
||||||
|
"SP patch failed - target snippet not found. "
|
||||||
|
"Check accelerate's version or update the patch."
|
||||||
|
)
|
||||||
|
|
||||||
|
patched_source = original_source.replace(
|
||||||
|
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a new function from the patched source
|
||||||
|
namespace = {}
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
patched_source, accelerate.data_loader.__dict__, namespace
|
||||||
|
)
|
||||||
|
patched_function = namespace["prepare_data_loader"]
|
||||||
|
|
||||||
|
accelerate.data_loader.prepare_data_loader = patched_function
|
||||||
|
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
||||||
|
|
||||||
|
|
||||||
|
def patch_prepare_device_mesh(sequence_parallel_degree: int):
|
||||||
|
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
||||||
|
that includes sequence parallelism with the specified degree.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sequence_parallel_degree (int): The degree of sequence parallelism to use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _prepare_device_mesh(self):
|
||||||
|
"""Prepare the device mesh for distributed training. The dataloader will
|
||||||
|
determine how to load data based on the device mesh.
|
||||||
|
"""
|
||||||
|
if self.state.torch_tp_plugin:
|
||||||
|
return self.state.torch_tp_plugin.torch_device_mesh
|
||||||
|
if (
|
||||||
|
self.distributed_type == accelerate.accelerator.DistributedType.DEEPSPEED
|
||||||
|
and hasattr(self.state, "ds_device_mesh")
|
||||||
|
):
|
||||||
|
return self.state.ds_device_mesh
|
||||||
|
|
||||||
|
# Create device mesh with sequence parallelism
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
mesh_shape = (
|
||||||
|
world_size // sequence_parallel_degree,
|
||||||
|
sequence_parallel_degree,
|
||||||
|
)
|
||||||
|
device_ids = list(range(world_size))
|
||||||
|
|
||||||
|
# Note that we use "cp" instead of "sp" to match the PyTorch native "context
|
||||||
|
# parallelism" implementation naming
|
||||||
|
return dist.DeviceMesh(
|
||||||
|
"cuda",
|
||||||
|
torch.tensor(device_ids).reshape(mesh_shape),
|
||||||
|
mesh_dim_names=("dp", "cp"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replace the original method with our new method
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
accelerate.accelerator.Accelerator._prepare_device_mesh = _prepare_device_mesh
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"Successfully patched Accelerator._prepare_device_mesh "
|
||||||
|
f"with sequence_parallel_degree={sequence_parallel_degree}"
|
||||||
|
)
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""MLFlow module for trainer callbacks"""
|
"""MLFlow module for trainer callbacks"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
@@ -16,6 +17,11 @@ if TYPE_CHECKING:
|
|||||||
LOG = logging.getLogger("axolotl.callbacks")
|
LOG = logging.getLogger("axolotl.callbacks")
|
||||||
|
|
||||||
|
|
||||||
|
def should_log_artifacts() -> bool:
|
||||||
|
truths = ["TRUE", "1", "YES"]
|
||||||
|
return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths
|
||||||
|
|
||||||
|
|
||||||
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
|
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
"""Callback to save axolotl config to mlflow"""
|
"""Callback to save axolotl config to mlflow"""
|
||||||
@@ -32,13 +38,18 @@ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
|
|||||||
):
|
):
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
try:
|
try:
|
||||||
with NamedTemporaryFile(
|
if should_log_artifacts():
|
||||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
with NamedTemporaryFile(
|
||||||
) as temp_file:
|
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||||
copyfile(self.axolotl_config_path, temp_file.name)
|
) as temp_file:
|
||||||
mlflow.log_artifact(temp_file.name, artifact_path="")
|
copyfile(self.axolotl_config_path, temp_file.name)
|
||||||
|
mlflow.log_artifact(temp_file.name, artifact_path="")
|
||||||
|
LOG.info(
|
||||||
|
"The Axolotl config has been saved to the MLflow artifacts."
|
||||||
|
)
|
||||||
|
else:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"The Axolotl config has been saved to the MLflow artifacts."
|
"Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)"
|
||||||
)
|
)
|
||||||
except (FileNotFoundError, ConnectionError) as err:
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
|
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Module for Axolotl trainer sequence parallelism manager and utilities"""
|
"""Module for Axolotl trainer sequence parallelism manager and utilities"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import inspect
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -9,7 +10,7 @@ from torch.utils.hooks import RemovableHandle
|
|||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.utils import ModelOutput
|
from transformers.utils import ModelOutput
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn.patch import (
|
from axolotl.monkeypatch.ring_attn.patch import (
|
||||||
get_ring_attn_group,
|
get_ring_attn_group,
|
||||||
update_ring_attn_params,
|
update_ring_attn_params,
|
||||||
)
|
)
|
||||||
@@ -206,12 +207,25 @@ class SequenceParallelContextManager:
|
|||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
# Forward pre-hook to apply sequence parallelism
|
# Forward pre-hook to apply sequence parallelism
|
||||||
def sequence_parallel_pre_hook(_, args, kwargs):
|
def sequence_parallel_pre_hook(_, args, kwargs):
|
||||||
# Apply sequence parallelism to kwargs and get original sequence length and padding info
|
# Get parameter names from the model's forward function
|
||||||
kwargs, self.original_seq_len, self.pad_len = (
|
forward_params = list(
|
||||||
self.apply_sequence_parallelism(batch=kwargs)
|
inspect.signature(self.models[0].forward).parameters.keys()
|
||||||
)
|
)
|
||||||
|
|
||||||
return args, kwargs
|
updated_kwargs = kwargs.copy()
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
if i < len(forward_params):
|
||||||
|
updated_kwargs[forward_params[i]] = arg
|
||||||
|
|
||||||
|
# Any excess positional arguments are kept as-is
|
||||||
|
remaining_args = args[len(forward_params) :]
|
||||||
|
|
||||||
|
# Apply sequence parallelism to updated kwargs
|
||||||
|
updated_kwargs, self.original_seq_len, self.pad_len = (
|
||||||
|
self.apply_sequence_parallelism(updated_kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
|
return remaining_args, updated_kwargs
|
||||||
|
|
||||||
# Forward post-hook to gather outputs
|
# Forward post-hook to gather outputs
|
||||||
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
|
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
|||||||
num_proc=cfg.dataset_processes,
|
num_proc=cfg.dataset_processes,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Mapping RL Dataset",
|
desc="Mapping RL Dataset",
|
||||||
|
num_proc=cfg.dataset_processes,
|
||||||
**map_kwargs,
|
**map_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -484,7 +484,7 @@ def get_dataset_wrapper(
|
|||||||
}
|
}
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
|
f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -5,8 +5,11 @@ from functools import partial
|
|||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from axolotl.utils.gradient_checkpointing.unsloth import (
|
from axolotl.utils.gradient_checkpointing.offload_cpu import (
|
||||||
Unsloth_Offloaded_Gradient_Checkpointer,
|
CPU_Offloaded_Gradient_Checkpointer,
|
||||||
|
)
|
||||||
|
from axolotl.utils.gradient_checkpointing.offload_disk import (
|
||||||
|
Disco,
|
||||||
)
|
)
|
||||||
|
|
||||||
transformers_version = version.parse(importlib.metadata.version("transformers"))
|
transformers_version = version.parse(importlib.metadata.version("transformers"))
|
||||||
@@ -26,12 +29,31 @@ def hf_grad_checkpoint_offload_wrapper(
|
|||||||
decoder_layer, *args, use_reentrant=None
|
decoder_layer, *args, use_reentrant=None
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
if uses_gc_layers(decoder_layer):
|
if uses_gc_layers(decoder_layer):
|
||||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
return CPU_Offloaded_Gradient_Checkpointer.apply(
|
||||||
decoder_layer,
|
decoder_layer,
|
||||||
*args,
|
*args,
|
||||||
)
|
)
|
||||||
|
|
||||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
return CPU_Offloaded_Gradient_Checkpointer.apply(
|
||||||
|
(
|
||||||
|
decoder_layer.func.__self__
|
||||||
|
if isinstance(decoder_layer, partial)
|
||||||
|
else decoder_layer.__self__
|
||||||
|
),
|
||||||
|
*args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def hf_grad_checkpoint_disk_offload_wrapper(
|
||||||
|
decoder_layer, *args, use_reentrant=None
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
if uses_gc_layers(decoder_layer):
|
||||||
|
return Disco.apply(
|
||||||
|
decoder_layer,
|
||||||
|
*args,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Disco.apply(
|
||||||
(
|
(
|
||||||
decoder_layer.func.__self__
|
decoder_layer.func.__self__
|
||||||
if isinstance(decoder_layer, partial)
|
if isinstance(decoder_layer, partial)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Unsloth checkpointing"""
|
"""CPU offloaded checkpointing"""
|
||||||
|
|
||||||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||||
#
|
#
|
||||||
@@ -26,7 +26,7 @@ else:
|
|||||||
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||||
|
|
||||||
|
|
||||||
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||||
torch.autograd.Function
|
torch.autograd.Function
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
531
src/axolotl/utils/gradient_checkpointing/offload_disk.py
Normal file
531
src/axolotl/utils/gradient_checkpointing/offload_disk.py
Normal file
@@ -0,0 +1,531 @@
|
|||||||
|
"""
|
||||||
|
DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Copyright 2025 Axolotl AI. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import atexit
|
||||||
|
import concurrent.futures
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections import deque
|
||||||
|
from concurrent.futures import Future
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||||
|
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||||
|
|
||||||
|
# Setup logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DiskOffloadManager:
|
||||||
|
"""
|
||||||
|
Manages offloaded tensors and handles prefetching in a separate thread.
|
||||||
|
Includes synchronization to prevent race conditions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefetch_size: int = 3,
|
||||||
|
prefetch_to_gpu: bool = True,
|
||||||
|
save_workers: int = 4,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
prefetch_size: Maximum number of tensors to prefetch in the background.
|
||||||
|
prefetch_to_gpu: Whether to prefetch tensors directly to GPU memory.
|
||||||
|
save_workers: Maximum number of concurrent save operations.
|
||||||
|
"""
|
||||||
|
self.temp_dir = tempfile.mkdtemp(prefix="disco_")
|
||||||
|
|
||||||
|
# Track tensor paths and their status
|
||||||
|
self.tensor_paths: deque = deque() # Ordered history of tensor paths (LIFO)
|
||||||
|
self.file_locks: Dict[str, threading.Lock] = (
|
||||||
|
{}
|
||||||
|
) # Maps file_path -> threading.Lock()
|
||||||
|
# Maps file_path -> status ("saving", "ready", "prefetching", "loaded", "deleted")
|
||||||
|
self.file_status: Dict[str, str] = {}
|
||||||
|
|
||||||
|
self.max_prefetch = prefetch_size
|
||||||
|
self.prefetch_to_gpu = prefetch_to_gpu
|
||||||
|
|
||||||
|
# Thread synchronization
|
||||||
|
self.manager_lock = threading.RLock() # Used for thread-safe operations
|
||||||
|
|
||||||
|
# Prefetch queue and cache
|
||||||
|
self.prefetch_queue: queue.Queue = queue.Queue()
|
||||||
|
self.prefetch_cache: Dict[str, torch.Tensor] = {} # Maps file_path -> tensor
|
||||||
|
|
||||||
|
# Save queue and thread pool
|
||||||
|
self.save_queue: queue.Queue = queue.Queue()
|
||||||
|
self.save_pool = concurrent.futures.ThreadPoolExecutor(max_workers=save_workers)
|
||||||
|
self.save_futures: Dict[str, Future] = {}
|
||||||
|
self.save_semaphore = threading.Semaphore(
|
||||||
|
save_workers * 2
|
||||||
|
) # Limit concurrent save operations
|
||||||
|
|
||||||
|
# Start prefetch worker thread
|
||||||
|
self.stop_event = threading.Event()
|
||||||
|
# start multiple threads for prefetching
|
||||||
|
self.prefetch_worker_count = 2
|
||||||
|
self.prefetch_workers = []
|
||||||
|
for _ in range(self.prefetch_worker_count):
|
||||||
|
worker = threading.Thread(target=self._prefetch_worker, daemon=True)
|
||||||
|
worker.start()
|
||||||
|
self.prefetch_workers.append(worker)
|
||||||
|
|
||||||
|
# Start save worker thread
|
||||||
|
self.save_worker = threading.Thread(target=self._save_worker, daemon=True)
|
||||||
|
self.save_worker.start()
|
||||||
|
self.idx = 0
|
||||||
|
|
||||||
|
atexit.register(self.cleanup)
|
||||||
|
|
||||||
|
def _save_worker(self):
|
||||||
|
"""Background thread that processes the save queue"""
|
||||||
|
while not self.stop_event.is_set():
|
||||||
|
try:
|
||||||
|
save_item = self.save_queue.get(timeout=0.5)
|
||||||
|
if save_item is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tensor, file_path = save_item
|
||||||
|
|
||||||
|
# Submit the save task to the thread pool
|
||||||
|
future = self.save_pool.submit(
|
||||||
|
self._save_tensor_to_disk, tensor, file_path
|
||||||
|
)
|
||||||
|
with self.manager_lock:
|
||||||
|
self.save_futures[file_path] = future
|
||||||
|
|
||||||
|
self.save_queue.task_done()
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
time.sleep(0.01) # Small sleep to prevent CPU spinning
|
||||||
|
continue
|
||||||
|
|
||||||
|
def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str):
|
||||||
|
"""Actually save the tensor to disk"""
|
||||||
|
try:
|
||||||
|
# Save tensor to disk
|
||||||
|
cpu_tensor = tensor.detach().cpu()
|
||||||
|
torch.save(cpu_tensor, file_path)
|
||||||
|
del cpu_tensor
|
||||||
|
|
||||||
|
with self.manager_lock:
|
||||||
|
# Mark file as ready
|
||||||
|
self.file_status[file_path] = "ready"
|
||||||
|
|
||||||
|
# Release semaphore
|
||||||
|
self.save_semaphore.release()
|
||||||
|
|
||||||
|
return True
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.error(f"Error saving tensor to {file_path}: {e}")
|
||||||
|
with self.manager_lock:
|
||||||
|
self.file_status[file_path] = "error"
|
||||||
|
|
||||||
|
# Release semaphore
|
||||||
|
self.save_semaphore.release()
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _prefetch_worker(self):
|
||||||
|
"""Background thread that loads tensors from disk ahead of time"""
|
||||||
|
while not self.stop_event.is_set():
|
||||||
|
try:
|
||||||
|
file_path = self.prefetch_queue.get(timeout=0.5)
|
||||||
|
if file_path is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if file is available and not already in cache
|
||||||
|
with self.manager_lock:
|
||||||
|
if (
|
||||||
|
file_path not in self.file_status
|
||||||
|
or self.file_status[file_path] == "deleted"
|
||||||
|
):
|
||||||
|
self.prefetch_queue.task_done()
|
||||||
|
if file_path in self.prefetch_cache:
|
||||||
|
self.prefetch_queue.task_done()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If file is still being saved, wait for it
|
||||||
|
if (
|
||||||
|
self.file_status[file_path] == "saving"
|
||||||
|
and file_path in self.save_futures
|
||||||
|
):
|
||||||
|
# Re-queue this prefetch request with a little delay
|
||||||
|
self.prefetch_queue.task_done()
|
||||||
|
time.sleep(0.1)
|
||||||
|
self.prefetch_queue.put(file_path)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Mark file as being prefetched
|
||||||
|
self.file_status[file_path] = "prefetching"
|
||||||
|
|
||||||
|
# Load tensor from disk and store in cache
|
||||||
|
try:
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
if self.prefetch_to_gpu:
|
||||||
|
tensor = torch.load(
|
||||||
|
file_path,
|
||||||
|
map_location=torch.device("cuda"),
|
||||||
|
weights_only=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tensor = torch.load(file_path, weights_only=True)
|
||||||
|
|
||||||
|
with self.manager_lock:
|
||||||
|
self.prefetch_cache[file_path] = tensor
|
||||||
|
self.file_status[file_path] = "ready"
|
||||||
|
else:
|
||||||
|
with self.manager_lock:
|
||||||
|
if self.file_status.get(file_path) != "deleted":
|
||||||
|
logger.warning(
|
||||||
|
f"Prefetch error: File not found {file_path}"
|
||||||
|
)
|
||||||
|
self.file_status[file_path] = "missing"
|
||||||
|
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
with self.manager_lock:
|
||||||
|
if self.file_status.get(file_path) != "deleted":
|
||||||
|
logger.warning(f"Prefetch error for {file_path}: {e}")
|
||||||
|
self.file_status[file_path] = "error"
|
||||||
|
|
||||||
|
self.prefetch_queue.task_done()
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
time.sleep(0.01) # Small sleep to prevent CPU spinning
|
||||||
|
continue
|
||||||
|
|
||||||
|
def save_tensor(self, tensor: torch.Tensor):
|
||||||
|
"""Save tensor to disk asynchronously and return file path with thread-safe operations"""
|
||||||
|
# Generate unique file path
|
||||||
|
self.idx += 1
|
||||||
|
file_path: str = os.path.join(
|
||||||
|
self.temp_dir, f"{self.idx:06d}-{uuid.uuid4()}.pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.manager_lock:
|
||||||
|
# Mark file as being saved
|
||||||
|
self.file_locks[file_path] = threading.Lock()
|
||||||
|
self.file_status[file_path] = "saving"
|
||||||
|
# Add to history
|
||||||
|
self.tensor_paths.append(file_path)
|
||||||
|
|
||||||
|
# Acquire semaphore to limit concurrent save operations
|
||||||
|
self.save_semaphore.acquire() # pylint: disable=consider-using-with
|
||||||
|
# Queue tensor for saving in background
|
||||||
|
self.save_queue.put((tensor.detach(), file_path))
|
||||||
|
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
def wait_for_save(self, file_path, timeout=None) -> None:
|
||||||
|
"""Wait for a tensor to be saved to disk"""
|
||||||
|
start_time = time.time()
|
||||||
|
while timeout is None or time.time() - start_time < timeout:
|
||||||
|
with self.manager_lock:
|
||||||
|
if self.file_status.get(file_path) == "ready":
|
||||||
|
return
|
||||||
|
if self.file_status.get(file_path) in ["error", "missing", "deleted"]:
|
||||||
|
return
|
||||||
|
|
||||||
|
if file_path in self.save_futures:
|
||||||
|
future = self.save_futures[file_path]
|
||||||
|
if future.done():
|
||||||
|
return
|
||||||
|
|
||||||
|
# Small sleep to prevent CPU spinning
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
# Timeout
|
||||||
|
logger.warning(f"Timeout waiting for tensor to be saved: {file_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
def load_tensor(self, file_path, target_device="cuda"):
|
||||||
|
"""Load tensor from disk or prefetch cache with proper synchronization"""
|
||||||
|
# Wait for tensor to be saved if it's still in progress
|
||||||
|
self.wait_for_save(file_path)
|
||||||
|
|
||||||
|
tensor = None
|
||||||
|
|
||||||
|
# Try to get from cache first
|
||||||
|
with self.manager_lock:
|
||||||
|
# Check if tensor is already in cache
|
||||||
|
if file_path in self.prefetch_cache:
|
||||||
|
tensor = self.prefetch_cache[file_path]
|
||||||
|
del self.prefetch_cache[file_path]
|
||||||
|
self.file_status[file_path] = "loaded"
|
||||||
|
|
||||||
|
if tensor is not None:
|
||||||
|
# Ensure tensor is on correct device
|
||||||
|
if target_device != "cpu" and tensor.device.type == "cpu":
|
||||||
|
tensor = tensor.to(target_device, non_blocking=True)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
# If not in cache, load directly from disk
|
||||||
|
try:
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
logger.error(f"File not found for loading: {file_path}")
|
||||||
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
|
tensor = torch.load(file_path, weights_only=True)
|
||||||
|
|
||||||
|
with self.manager_lock:
|
||||||
|
self.file_status[file_path] = "loaded"
|
||||||
|
|
||||||
|
if target_device != "cpu":
|
||||||
|
tensor = tensor.to(target_device, non_blocking=True)
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading tensor from {file_path}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _safe_delete_file(self, file_path):
|
||||||
|
"""Safely delete a file with proper synchronization"""
|
||||||
|
with self.manager_lock:
|
||||||
|
# Make sure any save operation is completed
|
||||||
|
if file_path in self.save_futures:
|
||||||
|
future = self.save_futures[file_path]
|
||||||
|
try:
|
||||||
|
if not future.done():
|
||||||
|
future.cancel()
|
||||||
|
del self.save_futures[file_path]
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error canceling save operation for {file_path}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only delete if file exists and is not being prefetched
|
||||||
|
status = self.file_status.get(file_path)
|
||||||
|
if status in ["ready", "loaded", "error", "missing"]:
|
||||||
|
try:
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
os.remove(file_path)
|
||||||
|
self.file_status[file_path] = "deleted"
|
||||||
|
return True
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.warning(f"Error deleting file {file_path}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def trigger_prefetch(self, n=None):
|
||||||
|
"""Trigger prefetching of the next N tensors with proper synchronization"""
|
||||||
|
if n is None:
|
||||||
|
n = self.max_prefetch
|
||||||
|
|
||||||
|
prefetch_paths = []
|
||||||
|
with self.manager_lock:
|
||||||
|
# Find files that are ready to be prefetched (not already in cache or being prefetched)
|
||||||
|
for path in reversed(self.tensor_paths):
|
||||||
|
if (
|
||||||
|
path not in self.prefetch_cache
|
||||||
|
and self.file_status.get(path) == "ready"
|
||||||
|
):
|
||||||
|
prefetch_paths.append(path)
|
||||||
|
if len(prefetch_paths) >= n:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Queue files for prefetching
|
||||||
|
for path in prefetch_paths:
|
||||||
|
self.prefetch_queue.put(path)
|
||||||
|
|
||||||
|
def cleanup_tensor(self, file_path: str):
|
||||||
|
"""Clean up a specific tensor file after it's been used"""
|
||||||
|
with self.manager_lock:
|
||||||
|
if file_path in self.tensor_paths:
|
||||||
|
self.tensor_paths.remove(file_path)
|
||||||
|
|
||||||
|
# Remove from prefetch cache if present
|
||||||
|
if file_path in self.prefetch_cache:
|
||||||
|
del self.prefetch_cache[file_path]
|
||||||
|
|
||||||
|
# Remove from save futures if present
|
||||||
|
if file_path in self.save_futures:
|
||||||
|
future = self.save_futures[file_path]
|
||||||
|
if not future.done():
|
||||||
|
future.cancel()
|
||||||
|
del self.save_futures[file_path]
|
||||||
|
|
||||||
|
# Try to delete the file
|
||||||
|
self._safe_delete_file(file_path)
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Clean up all temp files and stop prefetch thread with proper synchronization"""
|
||||||
|
self.stop_event.set()
|
||||||
|
|
||||||
|
# Cancel all pending save operations
|
||||||
|
with self.manager_lock:
|
||||||
|
for _, future in self.save_futures.items():
|
||||||
|
if not future.done():
|
||||||
|
future.cancel()
|
||||||
|
self.save_futures.clear()
|
||||||
|
|
||||||
|
# Drain the save queue
|
||||||
|
while not self.save_queue.empty():
|
||||||
|
try:
|
||||||
|
self.save_queue.get_nowait()
|
||||||
|
self.save_queue.task_done()
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Shutdown the save pool
|
||||||
|
self.save_pool.shutdown(wait=False)
|
||||||
|
|
||||||
|
# Join the save worker thread
|
||||||
|
if self.save_worker.is_alive():
|
||||||
|
self.save_worker.join(timeout=2.0)
|
||||||
|
|
||||||
|
# Join the prefetch worker threads
|
||||||
|
for thread in self.prefetch_workers:
|
||||||
|
if thread.is_alive():
|
||||||
|
thread.join(timeout=2.0)
|
||||||
|
|
||||||
|
# Clear cache and remove all temporary files
|
||||||
|
with self.manager_lock:
|
||||||
|
self.prefetch_cache.clear()
|
||||||
|
paths_to_delete = list(self.tensor_paths)
|
||||||
|
self.tensor_paths.clear()
|
||||||
|
|
||||||
|
# Delete all temporary files
|
||||||
|
for path in paths_to_delete:
|
||||||
|
self._safe_delete_file(path)
|
||||||
|
|
||||||
|
# Remove temp directory
|
||||||
|
try:
|
||||||
|
if os.path.exists(self.temp_dir):
|
||||||
|
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.warning(f"Error removing temporary directory {self.temp_dir}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class Disco(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
Disco: DIsk-based Storage and Checkpointing with Optimized prefetching
|
||||||
|
Advanced disk-based gradient checkpointer with prefetching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Shared manager instance across all checkpointing operations
|
||||||
|
_manager = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4):
|
||||||
|
"""Get or create the offload manager"""
|
||||||
|
if Disco._manager is None:
|
||||||
|
Disco._manager = DiskOffloadManager(
|
||||||
|
prefetch_size=prefetch_size,
|
||||||
|
prefetch_to_gpu=prefetch_to_gpu,
|
||||||
|
save_workers=save_workers,
|
||||||
|
)
|
||||||
|
return Disco._manager
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch_cuda_amp_custom_fwd
|
||||||
|
def forward(
|
||||||
|
ctx,
|
||||||
|
forward_function,
|
||||||
|
hidden_states,
|
||||||
|
*args,
|
||||||
|
prefetch_size=1,
|
||||||
|
prefetch_to_gpu=True,
|
||||||
|
save_workers=4,
|
||||||
|
):
|
||||||
|
"""Forward pass that offloads activations to disk asynchronously"""
|
||||||
|
# Get or create the manager
|
||||||
|
manager = Disco.get_instance(
|
||||||
|
prefetch_size=prefetch_size,
|
||||||
|
prefetch_to_gpu=prefetch_to_gpu,
|
||||||
|
save_workers=save_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save tensor to disk asynchronously
|
||||||
|
file_path = manager.save_tensor(hidden_states)
|
||||||
|
|
||||||
|
# Run forward pass immediately without waiting for save to complete
|
||||||
|
with torch.no_grad():
|
||||||
|
output = forward_function(hidden_states, *args)
|
||||||
|
|
||||||
|
# Store what we need for backward
|
||||||
|
ctx.save_for_backward(torch.tensor([0])) # Dummy tensor
|
||||||
|
ctx.file_path = file_path
|
||||||
|
ctx.forward_function = forward_function
|
||||||
|
ctx.args = args
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch_cuda_amp_custom_bwd
|
||||||
|
def backward(ctx, *grad_outputs):
|
||||||
|
"""Backward pass that loads activations from disk with prefetching"""
|
||||||
|
# Get the manager
|
||||||
|
manager = Disco._manager
|
||||||
|
|
||||||
|
# Trigger prefetching for future tensors
|
||||||
|
# This happens at the start of backward, so should have time to complete
|
||||||
|
manager.trigger_prefetch()
|
||||||
|
|
||||||
|
# Load hidden states from disk or prefetch cache
|
||||||
|
file_path = ctx.file_path
|
||||||
|
try:
|
||||||
|
# Ensure the file is saved before we try to load it
|
||||||
|
manager.wait_for_save(file_path)
|
||||||
|
|
||||||
|
hidden_states = manager.load_tensor(file_path)
|
||||||
|
hidden_states.requires_grad = True
|
||||||
|
|
||||||
|
# Compute gradients
|
||||||
|
with torch.enable_grad():
|
||||||
|
output = ctx.forward_function(hidden_states, *ctx.args)
|
||||||
|
|
||||||
|
# Handle tuple outputs properly
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
if len(grad_outputs) == len(output):
|
||||||
|
torch.autograd.backward(output, grad_outputs)
|
||||||
|
else:
|
||||||
|
torch.autograd.backward(output, grad_outputs[0])
|
||||||
|
else:
|
||||||
|
torch.autograd.backward(output, grad_outputs[0])
|
||||||
|
|
||||||
|
# Clean up the file after we're done with it
|
||||||
|
manager.cleanup_tensor(file_path)
|
||||||
|
|
||||||
|
return (
|
||||||
|
(
|
||||||
|
None, # forward_function
|
||||||
|
hidden_states.grad, # hidden_states grad
|
||||||
|
)
|
||||||
|
+ (None,) * len(ctx.args) # for each arg
|
||||||
|
+ (
|
||||||
|
None, # prefetch_size
|
||||||
|
None, # prefetch_to_gpu
|
||||||
|
None, # save_workers
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in backward pass: {e}")
|
||||||
|
# Clean up the file even on error
|
||||||
|
manager.cleanup_tensor(file_path)
|
||||||
|
raise
|
||||||
@@ -59,6 +59,7 @@ from axolotl.monkeypatch.multipack import (
|
|||||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||||
patch_for_multipack,
|
patch_for_multipack,
|
||||||
)
|
)
|
||||||
|
from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
@@ -70,7 +71,10 @@ from axolotl.utils.distributed import (
|
|||||||
is_local_main_process,
|
is_local_main_process,
|
||||||
is_main_process,
|
is_main_process,
|
||||||
)
|
)
|
||||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
from axolotl.utils.gradient_checkpointing import (
|
||||||
|
hf_grad_checkpoint_disk_offload_wrapper,
|
||||||
|
hf_grad_checkpoint_offload_wrapper,
|
||||||
|
)
|
||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
from axolotl.utils.schemas.enums import RLType
|
from axolotl.utils.schemas.enums import RLType
|
||||||
@@ -620,6 +624,10 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
||||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
||||||
|
if self.cfg.gradient_checkpointing == "offload_disk":
|
||||||
|
transformers.modeling_utils.checkpoint = (
|
||||||
|
hf_grad_checkpoint_disk_offload_wrapper
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
self.patch_attention()
|
self.patch_attention()
|
||||||
@@ -674,16 +682,25 @@ class ModelLoader:
|
|||||||
patch_self_attn_lora(self.cfg)
|
patch_self_attn_lora(self.cfg)
|
||||||
|
|
||||||
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||||
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
from axolotl.monkeypatch.ring_attn import (
|
||||||
|
patch_prepare_data_loader,
|
||||||
|
patch_prepare_device_mesh,
|
||||||
|
register_ring_attn,
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize ring attn for sequence parallelism. This must be done after
|
# Initialize ring attn for sequence parallelism. This must be done after
|
||||||
# model init but before the first forward pass, since it modifies flash
|
# model init but before the first forward pass, since it modifies flash
|
||||||
# attn to use ring comm for SP training across multiple GPUs.
|
# attn to use ring comm for SP training across multiple GPUs.
|
||||||
register_ring_attn(
|
if get_ring_attn_group() is None: # If already set, this is already patched
|
||||||
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
|
register_ring_attn(
|
||||||
heads_k_stride=self.cfg.heads_k_stride,
|
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
|
||||||
ring_attn_func=self.cfg.ring_attn_func,
|
heads_k_stride=self.cfg.heads_k_stride,
|
||||||
)
|
ring_attn_func=self.cfg.ring_attn_func,
|
||||||
|
)
|
||||||
|
patch_prepare_data_loader()
|
||||||
|
patch_prepare_device_mesh(
|
||||||
|
sequence_parallel_degree=self.cfg.sequence_parallel_degree
|
||||||
|
)
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
if hasattr(self.model_config, "model_type"):
|
if hasattr(self.model_config, "model_type"):
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
# torch_dtype: torch.dtype | None
|
# torch_dtype: torch.dtype | None
|
||||||
|
|
||||||
gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
|
gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field(
|
||||||
default=False
|
default=False
|
||||||
)
|
)
|
||||||
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
||||||
|
|||||||
@@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skip(reason="flaky test")
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"num_gpus",
|
"num_gpus",
|
||||||
[1, 2],
|
[1, 2],
|
||||||
@@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"NCCL_P2P_LEVEL": "LOC",
|
"NCCL_P2P_LEVEL": "LOC",
|
||||||
**current_env,
|
**current_env,
|
||||||
"CUDA_VISIBLE_DEVICES": "1",
|
"CUDA_VISIBLE_DEVICES": "1",
|
||||||
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
|
||||||
# "VLLM_USE_V1": "0",
|
|
||||||
}
|
}
|
||||||
vllm_process = start_vllm(
|
vllm_process = start_vllm(
|
||||||
cfg.base_model,
|
cfg.base_model,
|
||||||
@@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
finally:
|
finally:
|
||||||
recursive_kill(vllm_process)
|
recursive_kill(vllm_process)
|
||||||
|
|
||||||
@pytest.mark.skip(reason="flaky test")
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"num_gpus",
|
"num_gpus",
|
||||||
[1, 2],
|
[1, 2],
|
||||||
@@ -325,8 +321,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
|
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
|
||||||
**current_env,
|
**current_env,
|
||||||
"CUDA_VISIBLE_DEVICES": "1",
|
"CUDA_VISIBLE_DEVICES": "1",
|
||||||
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
|
||||||
# "VLLM_USE_V1": "0",
|
|
||||||
}
|
}
|
||||||
vllm_process = start_vllm(
|
vllm_process = start_vllm(
|
||||||
cfg.base_model,
|
cfg.base_model,
|
||||||
|
|||||||
@@ -26,10 +26,15 @@ class TestActivationCheckpointing:
|
|||||||
E2E tests for activation checkpointing
|
E2E tests for activation checkpointing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"gradient_checkpointing",
|
||||||
|
["offload", "offload_disk"],
|
||||||
|
)
|
||||||
def test_activation_checkpointing_offload(
|
def test_activation_checkpointing_offload(
|
||||||
self,
|
self,
|
||||||
temp_dir,
|
temp_dir,
|
||||||
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
|
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
|
||||||
|
gradient_checkpointing,
|
||||||
):
|
):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -64,7 +69,7 @@ class TestActivationCheckpointing:
|
|||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"bf16": True,
|
"bf16": True,
|
||||||
"save_safetensors": True,
|
"save_safetensors": True,
|
||||||
"gradient_checkpointing": "offload",
|
"gradient_checkpointing": gradient_checkpointing,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
from accelerate.state import PartialState
|
from accelerate.state import PartialState
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import (
|
from axolotl.monkeypatch.ring_attn import (
|
||||||
get_ring_attn_group,
|
get_ring_attn_group,
|
||||||
register_ring_attn,
|
register_ring_attn,
|
||||||
set_ring_attn_group,
|
set_ring_attn_group,
|
||||||
@@ -313,13 +313,13 @@ class TestApplySequenceParallelism:
|
|||||||
|
|
||||||
# Mock the process group
|
# Mock the process group
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group",
|
"axolotl.monkeypatch.ring_attn.get_ring_attn_group",
|
||||||
MagicMock,
|
MagicMock,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock update_ring_attn_params
|
# Mock update_ring_attn_params
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params",
|
"axolotl.monkeypatch.ring_attn.update_ring_attn_params",
|
||||||
lambda **kwargs: None,
|
lambda **kwargs: None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user