Compare commits
10 Commits
wait-distr
...
fix/kd-tra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
348409c2ff | ||
|
|
a27b909c5c | ||
|
|
6cb07b9d12 | ||
|
|
288653adb6 | ||
|
|
3a5b495a74 | ||
|
|
f661858fc4 | ||
|
|
c837c4a424 | ||
|
|
c9797de6bb | ||
|
|
8f8a7afb05 | ||
|
|
86472715da |
10
.github/workflows/main.yml
vendored
10
.github/workflows/main.yml
vendored
@@ -31,6 +31,11 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -94,6 +99,11 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
9
.github/workflows/tests.yml
vendored
9
.github/workflows/tests.yml
vendored
@@ -295,6 +295,7 @@ jobs:
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
docker-e2e-tests-1st:
|
||||
# Run this job first as a gate for running the remainder of the test matrix
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
@@ -341,6 +342,8 @@ jobs:
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 90
|
||||
# Only run the remainder of the matrix if the first e2e check passed;
|
||||
# this is to save on wasted compute costs for known failures that get caught in the first run
|
||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||
|
||||
strategy:
|
||||
@@ -365,6 +368,12 @@ jobs:
|
||||
pytorch: 2.7.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=90 * 60,
|
||||
cpu=8.0,
|
||||
cpu=16.0,
|
||||
memory=131072 * N_GPUS,
|
||||
volumes=VOLUME_CONFIG,
|
||||
)
|
||||
|
||||
@@ -633,7 +633,9 @@ weight_decay:
|
||||
# adamw hyperparams
|
||||
adam_beta1:
|
||||
adam_beta2:
|
||||
adam_beta3: # only used for CAME Optimizer
|
||||
adam_epsilon:
|
||||
adam_epsilon2: # only used for CAME Optimizer
|
||||
# Gradient clipping max norm
|
||||
max_grad_norm:
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ the `alpaca` dataset format, which has the following format:
|
||||
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
|
||||
format them.
|
||||
|
||||
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
|
||||
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca`
|
||||
format):
|
||||
|
||||
```json
|
||||
@@ -120,6 +120,12 @@ axolotl train my_training.yml
|
||||
|
||||
## Common Tasks {#sec-common-tasks}
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
The same yaml file is used for training, inference, and merging.
|
||||
|
||||
:::
|
||||
|
||||
### Testing Your Model {#sec-testing}
|
||||
|
||||
After training, test your model:
|
||||
@@ -128,6 +134,16 @@ After training, test your model:
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||
```
|
||||
|
||||
More details can be found in [Inference](inference.qmd).
|
||||
|
||||
### Using a UI {#sec-ui}
|
||||
|
||||
Launch a Gradio interface:
|
||||
|
||||
```bash
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
|
||||
```
|
||||
|
||||
### Preprocessing Data {#sec-preprocessing}
|
||||
|
||||
For large datasets, preprocess first:
|
||||
@@ -136,14 +152,22 @@ For large datasets, preprocess first:
|
||||
axolotl preprocess my_training.yml
|
||||
```
|
||||
|
||||
### Using a UI {#sec-ui}
|
||||
Please make sure to set `dataset_prepared_path: ` in your config to set the path to save the prepared dataset.
|
||||
|
||||
Launch a Gradio interface:
|
||||
More details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd).
|
||||
|
||||
### Merging LoRA weights {#sec-merging-lora}
|
||||
|
||||
To merge the LoRA weights back into the base model, run:
|
||||
|
||||
```bash
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
|
||||
axolotl merge-lora my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||
```
|
||||
|
||||
The merged model will be saved in the `{output_dir}/merged` directory.
|
||||
|
||||
More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).
|
||||
|
||||
## Next Steps {#sec-next-steps}
|
||||
|
||||
Now that you have the basics, you might want to:
|
||||
@@ -156,6 +180,7 @@ Now that you have the basics, you might want to:
|
||||
Check our other guides for details on these topics:
|
||||
|
||||
- [Configuration Guide](config.qmd) - Full configuration options
|
||||
- [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources
|
||||
- [Dataset Formats](dataset-formats) - Working with different data formats
|
||||
- [Multi-GPU Training](multi-gpu.qmd)
|
||||
- [Multi-Node Training](multi-node.qmd)
|
||||
|
||||
@@ -387,8 +387,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
|
||||
if self.cfg.adam_beta2:
|
||||
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
|
||||
if self.cfg.adam_beta3:
|
||||
training_arguments_kwargs["adam_beta3"] = self.cfg.adam_beta3
|
||||
if self.cfg.adam_epsilon:
|
||||
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
|
||||
if self.cfg.adam_epsilon2:
|
||||
training_arguments_kwargs["adam_epsilon2"] = self.cfg.adam_epsilon2
|
||||
if self.cfg.max_grad_norm:
|
||||
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
|
||||
|
||||
@@ -713,7 +717,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
|
||||
beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
|
||||
beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
|
||||
beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999)
|
||||
eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
|
||||
eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
|
||||
adam_kwargs["betas"] = (beta1, beta2, beta3)
|
||||
@@ -1170,7 +1174,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.eval_dataset:
|
||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||
if self.cfg.adapter and self.peft_config:
|
||||
trainer_kwargs["peft_config"] = self.peft_config
|
||||
if self.cfg.rl is not RLType.GRPO:
|
||||
trainer_kwargs["peft_config"] = self.peft_config
|
||||
if self.cfg.precompute_ref_log_probs is not None:
|
||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||
self.cfg.precompute_ref_log_probs
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
||||
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
@@ -14,7 +13,7 @@ from accelerate.utils import (
|
||||
broadcast_object_list,
|
||||
gather,
|
||||
gather_object,
|
||||
is_peft_model,
|
||||
is_peft_available,
|
||||
)
|
||||
from datasets import Dataset, IterableDataset
|
||||
from torch import nn
|
||||
@@ -30,15 +29,13 @@ from transformers import (
|
||||
TrainerCallback,
|
||||
)
|
||||
from transformers.trainer_utils import seed_worker
|
||||
from transformers.utils import is_peft_available
|
||||
from trl import GRPOTrainer
|
||||
from trl.data_utils import (
|
||||
apply_chat_template,
|
||||
is_conversational,
|
||||
maybe_apply_chat_template,
|
||||
)
|
||||
from trl.extras.profiling import profiling_context, profiling_decorator
|
||||
from trl.import_utils import is_deepspeed_available
|
||||
from trl.extras.profiling import profiling_context
|
||||
from trl.models import unwrap_model_for_generation
|
||||
from trl.trainer.grpo_config import GRPOConfig
|
||||
from trl.trainer.grpo_trainer import RewardFunc, nanstd
|
||||
@@ -52,62 +49,12 @@ if is_peft_available():
|
||||
# pylint: disable=unused-import
|
||||
from peft import PeftConfig
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
|
||||
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
@profiling_decorator
|
||||
def _move_model_to_vllm(self):
|
||||
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
|
||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
|
||||
gather_if_zero3 = (
|
||||
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
|
||||
)
|
||||
|
||||
if is_peft_model(self.model):
|
||||
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
|
||||
# adapters in a sharded manner is not supported.
|
||||
with gather_if_zero3(list(self.model.parameters())):
|
||||
self.model.merge_adapter()
|
||||
|
||||
# Update vLLM weights while parameters are gathered
|
||||
for name, param in self.model.named_parameters():
|
||||
# When using PEFT, we need to recover the original parameter name and discard some parameters
|
||||
name = (
|
||||
name.removeprefix("base_model.model.")
|
||||
.removeprefix("base_model.model.")
|
||||
.replace(".base_layer", "")
|
||||
)
|
||||
if self.model.prefix in name:
|
||||
continue
|
||||
# When module to save, remove its prefix and discard the original module
|
||||
if "original_module" in name:
|
||||
continue
|
||||
name = name.replace("modules_to_save.default.", "")
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.update_named_param(name, param.data)
|
||||
|
||||
# Unmerge adapters while parameters are still gathered
|
||||
self.model.unmerge_adapter()
|
||||
# Parameters will automatically be repartitioned when exiting the context
|
||||
else:
|
||||
# For non-PEFT models, simply gather and update each parameter individually.
|
||||
for name, param in self.model.named_parameters():
|
||||
with gather_if_zero3([param]):
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.update_named_param(name, param.data)
|
||||
|
||||
# Reset cache on main process
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.reset_prefix_cache()
|
||||
|
||||
|
||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||
|
||||
@@ -227,6 +227,19 @@ class AxolotlTrainingMixins:
|
||||
},
|
||||
)
|
||||
|
||||
adam_beta3: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
adam_epsilon2: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
image_size: int | tuple[int, int] | None = field(
|
||||
|
||||
@@ -20,25 +20,15 @@ from cut_cross_entropy.transformers.utils import (
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.cohere.modeling_cohere import (
|
||||
_CONFIG_FOR_DOC,
|
||||
COHERE_INPUTS_DOCSTRING,
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
|
||||
@@ -17,25 +17,15 @@ from cut_cross_entropy.transformers.utils import (
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.gemma.modeling_gemma import (
|
||||
_CONFIG_FOR_DOC,
|
||||
GEMMA_INPUTS_DOCSTRING,
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
|
||||
@@ -20,15 +20,11 @@ from torch import nn
|
||||
from transformers.cache_utils import Cache, HybridCache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.gemma3.modeling_gemma3 import (
|
||||
_CONFIG_FOR_DOC,
|
||||
GEMMA3_INPUTS_DOCSTRING,
|
||||
Gemma3CausalLMOutputWithPast,
|
||||
logger,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
@@ -38,10 +34,6 @@ _PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -170,10 +162,6 @@ def cce_forward(
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
|
||||
@@ -19,15 +19,9 @@ from transformers.modeling_outputs import (
|
||||
CausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
_CONFIG_FOR_DOC,
|
||||
LLAMA_INPUTS_DOCSTRING,
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
from transformers.utils.generic import can_return_tuple
|
||||
|
||||
@@ -36,10 +30,6 @@ _PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
|
||||
@@ -16,22 +16,12 @@ from torch import nn
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.llama4.modeling_llama4 import (
|
||||
_CONFIG_FOR_DOC,
|
||||
LLAMA4_INPUTS_DOCSTRING,
|
||||
Llama4CausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -160,9 +150,6 @@ def cce_forward(
|
||||
)
|
||||
|
||||
|
||||
@replace_return_docstrings(
|
||||
output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None, # type: ignore
|
||||
|
||||
@@ -19,15 +19,11 @@ from transformers.models.mistral3.modeling_mistral3 import (
|
||||
Mistral3CausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
_CONFIG_FOR_DOC,
|
||||
MISTRAL_INPUTS_DOCSTRING,
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
@@ -35,10 +31,6 @@ _PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
|
||||
@@ -13,16 +13,10 @@ from cut_cross_entropy.transformers.utils import (
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
|
||||
_CONFIG_FOR_DOC,
|
||||
QWEN2MOE_INPUTS_DOCSTRING,
|
||||
MoeCausalLMOutputWithPast,
|
||||
MoeModelOutputWithPast,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
from transformers.utils.generic import can_return_tuple
|
||||
|
||||
@@ -31,10 +25,6 @@ _PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
|
||||
@@ -14,22 +14,12 @@ from cut_cross_entropy.transformers.utils import (
|
||||
)
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||
_CONFIG_FOR_DOC,
|
||||
QWEN2_VL_INPUTS_DOCSTRING,
|
||||
Qwen2VLCausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
|
||||
@@ -12,20 +12,13 @@ from cut_cross_entropy.transformers.utils import (
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
|
||||
_CONFIG_FOR_DOC,
|
||||
QWEN3_MOE_INPUTS_DOCSTRING,
|
||||
KwargsForCausalLM,
|
||||
MoeCausalLMOutputWithPast,
|
||||
MoeModelOutputWithPast,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
from transformers.utils.generic import can_return_tuple
|
||||
|
||||
@@ -34,10 +27,6 @@ _PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
|
||||
@@ -74,6 +74,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||
|
||||
if num_items_in_batch is None:
|
||||
num_items_in_batch = -1
|
||||
|
||||
if self.args.kd_zscore_base_temp:
|
||||
loss_kd = topk_kd_loss_with_zscore(
|
||||
shift_logits,
|
||||
|
||||
0
src/axolotl/integrations/liger/models/__init__.py
Normal file
0
src/axolotl/integrations/liger/models/__init__.py
Normal file
@@ -14,10 +14,6 @@ from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
|
||||
# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
|
||||
# @replace_return_docstrings(
|
||||
# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
# )
|
||||
def lce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
|
||||
@@ -13,21 +13,11 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
||||
from transformers.models.jamba.modeling_jamba import (
|
||||
_CONFIG_FOR_DOC,
|
||||
JAMBA_INPUTS_DOCSTRING,
|
||||
HybridMambaAttentionDynamicCache,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def lce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
|
||||
@@ -7,24 +7,16 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.models.gemma3.modeling_gemma3 import (
|
||||
_CONFIG_FOR_DOC,
|
||||
GEMMA3_INPUTS_DOCSTRING,
|
||||
Gemma3CausalLMOutputWithPast,
|
||||
logger,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def new_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""MLFlow module for trainer callbacks"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -16,6 +17,11 @@ if TYPE_CHECKING:
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
|
||||
def should_log_artifacts() -> bool:
|
||||
truths = ["TRUE", "1", "YES"]
|
||||
return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
|
||||
# pylint: disable=duplicate-code
|
||||
"""Callback to save axolotl config to mlflow"""
|
||||
@@ -32,13 +38,18 @@ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
with NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||
) as temp_file:
|
||||
copyfile(self.axolotl_config_path, temp_file.name)
|
||||
mlflow.log_artifact(temp_file.name, artifact_path="")
|
||||
if should_log_artifacts():
|
||||
with NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||
) as temp_file:
|
||||
copyfile(self.axolotl_config_path, temp_file.name)
|
||||
mlflow.log_artifact(temp_file.name, artifact_path="")
|
||||
LOG.info(
|
||||
"The Axolotl config has been saved to the MLflow artifacts."
|
||||
)
|
||||
else:
|
||||
LOG.info(
|
||||
"The Axolotl config has been saved to the MLflow artifacts."
|
||||
"Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)"
|
||||
)
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
|
||||
|
||||
@@ -72,6 +72,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
||||
data_set = data_set.map(
|
||||
ds_transform_fn,
|
||||
desc="Mapping RL Dataset",
|
||||
num_proc=cfg.dataset_processes,
|
||||
**map_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -484,7 +484,7 @@ def get_dataset_wrapper(
|
||||
}
|
||||
|
||||
LOG.info(
|
||||
f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
|
||||
f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
|
||||
)
|
||||
|
||||
if (
|
||||
|
||||
@@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"""
|
||||
)
|
||||
|
||||
@pytest.mark.skip(reason="flaky test")
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
@@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"NCCL_P2P_LEVEL": "LOC",
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
||||
# "VLLM_USE_V1": "0",
|
||||
}
|
||||
vllm_process = start_vllm(
|
||||
cfg.base_model,
|
||||
@@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
finally:
|
||||
recursive_kill(vllm_process)
|
||||
|
||||
@pytest.mark.skip(reason="flaky test")
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
@@ -325,8 +321,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
||||
# "VLLM_USE_V1": "0",
|
||||
}
|
||||
vllm_process = start_vllm(
|
||||
cfg.base_model,
|
||||
|
||||
Reference in New Issue
Block a user