Compare commits
8 Commits
mm_mc_chat
...
feat/soap-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a7f048c6b | ||
|
|
76d26366ad | ||
|
|
64fe284765 | ||
|
|
cf0c79d52e | ||
|
|
4ba80a0e5a | ||
|
|
c49682132b | ||
|
|
e46239f8d3 | ||
|
|
05f03b541a |
2
.github/workflows/tests-nightly.yml
vendored
2
.github/workflows/tests-nightly.yml
vendored
@@ -136,4 +136,4 @@ jobs:
|
|||||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.tests
|
modal run cicd.e2e_tests
|
||||||
|
|||||||
11
.github/workflows/tests.yml
vendored
11
.github/workflows/tests.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
|||||||
path: |
|
path: |
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
/home/runner/.cache/huggingface/hub/datasets--*
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
/home/runner/.cache/huggingface/hub/models--*
|
||||||
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
|
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
@@ -137,7 +137,7 @@ jobs:
|
|||||||
path: |
|
path: |
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
/home/runner/.cache/huggingface/hub/datasets--*
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
/home/runner/.cache/huggingface/hub/models--*
|
||||||
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
|
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
@@ -171,6 +171,9 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|
||||||
|
- name: Show HF cache
|
||||||
|
run: huggingface-cli scan-cache
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
@@ -229,7 +232,7 @@ jobs:
|
|||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.tests
|
modal run cicd.e2e_tests
|
||||||
|
|
||||||
docker-e2e-tests:
|
docker-e2e-tests:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
@@ -276,4 +279,4 @@ jobs:
|
|||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.tests
|
modal run cicd.e2e_tests
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
[settings]
|
[settings]
|
||||||
profile=black
|
profile=black
|
||||||
known_third_party=wandb,comet_ml
|
known_third_party=wandb,comet_ml
|
||||||
|
known_local_folder=src,tests
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ load_in_4bit: true
|
|||||||
strict: false
|
strict: false
|
||||||
|
|
||||||
# huggingface repo
|
# huggingface repo
|
||||||
chat_template: gemma3_text
|
chat_template: gemma3
|
||||||
datasets:
|
datasets:
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
type: chat_template
|
type: chat_template
|
||||||
@@ -19,7 +19,6 @@ val_set_size: 0.0
|
|||||||
output_dir: ./outputs/lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
dataset_exact_deduplication: true
|
dataset_exact_deduplication: true
|
||||||
test_value: true
|
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ triton>=3.0.0
|
|||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
autoawq==0.2.7.post3
|
autoawq==0.2.7.post3
|
||||||
liger-kernel==0.5.3
|
liger-kernel==0.5.5
|
||||||
# END section
|
# END section
|
||||||
|
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
@@ -15,7 +15,7 @@ peft==0.15.0
|
|||||||
transformers==4.50.0
|
transformers==4.50.0
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.5.2
|
accelerate==1.5.2
|
||||||
datasets==3.4.1
|
datasets==3.5.0
|
||||||
deepspeed==0.16.4
|
deepspeed==0.16.4
|
||||||
trl==0.15.1
|
trl==0.15.1
|
||||||
|
|
||||||
|
|||||||
@@ -663,6 +663,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
optimizer_cls = MuonOptimizerFactory
|
optimizer_cls = MuonOptimizerFactory
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "soap":
|
||||||
|
from axolotl.utils.optimizers.soap import SOAP
|
||||||
|
|
||||||
|
optimizer_cls = SOAP
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
elif self.cfg.optimizer == "optimi_adamw":
|
elif self.cfg.optimizer == "optimi_adamw":
|
||||||
from optimi import AdamW
|
from optimi import AdamW
|
||||||
|
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ import torch
|
|||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.utils import get_pytorch_version
|
from axolotl.utils import get_pytorch_version
|
||||||
|
from axolotl.utils.distributed import zero_only
|
||||||
|
|
||||||
from ...utils.distributed import zero_only
|
|
||||||
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import transformers
|
|||||||
from cut_cross_entropy.transformers.utils import (
|
from cut_cross_entropy.transformers.utils import (
|
||||||
PatchOptions,
|
PatchOptions,
|
||||||
TransformersModelT,
|
TransformersModelT,
|
||||||
apply_lce,
|
|
||||||
)
|
)
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.cache_utils import Cache, HybridCache
|
from transformers.cache_utils import Cache, HybridCache
|
||||||
@@ -33,6 +32,8 @@ from transformers.utils import (
|
|||||||
)
|
)
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
@@ -134,25 +135,17 @@ def cce_forward(
|
|||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
assert labels is not None
|
assert labels is not None
|
||||||
if self.config.final_logit_softcapping is not None:
|
|
||||||
logger.warning_once(
|
|
||||||
"final_logit_softcapping is not supported for gemma3_text with CCE. Disabling."
|
|
||||||
)
|
|
||||||
loss = apply_lce(
|
loss = apply_lce(
|
||||||
hidden_states[:, slice_indices, :],
|
hidden_states[:, slice_indices, :],
|
||||||
self.lm_head.weight,
|
self.lm_head.weight,
|
||||||
labels,
|
labels,
|
||||||
_PATCH_OPTS,
|
_PATCH_OPTS,
|
||||||
|
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||||
**loss_kwargs,
|
**loss_kwargs,
|
||||||
)
|
)
|
||||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||||
# defer logits calculation to the ConditionalGeneration forward
|
# defer logits calculation to the ConditionalGeneration forward
|
||||||
logits = hidden_states[:, slice_indices, :]
|
logits = hidden_states[:, slice_indices, :]
|
||||||
|
|
||||||
if self.config.final_logit_softcapping is not None:
|
|
||||||
logger.warning_once(
|
|
||||||
"final_logit_softcapping is not supported for gemma3 with CCE. Disabling."
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
if self.config.final_logit_softcapping is not None:
|
if self.config.final_logit_softcapping is not None:
|
||||||
@@ -353,6 +346,7 @@ def cce_forward_multimodal(
|
|||||||
self.language_model.lm_head.weight,
|
self.language_model.lm_head.weight,
|
||||||
labels,
|
labels,
|
||||||
_PATCH_OPTS,
|
_PATCH_OPTS,
|
||||||
|
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||||
**lm_kwargs,
|
**lm_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
"""Monkeypatch for apply_lce to add softcap."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from cut_cross_entropy import linear_cross_entropy
|
||||||
|
from cut_cross_entropy.transformers.utils import PatchOptions
|
||||||
|
|
||||||
|
|
||||||
|
def apply_lce(
|
||||||
|
e: torch.Tensor,
|
||||||
|
c: torch.Tensor,
|
||||||
|
labels: torch.Tensor,
|
||||||
|
opts: PatchOptions,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
softcap: float | None = None,
|
||||||
|
**loss_kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Monkey patch for apply_lce to support softcap kwarg."""
|
||||||
|
num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
|
||||||
|
cce_kwargs = opts.to_kwargs()
|
||||||
|
if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
|
||||||
|
cce_kwargs["reduction"] = "sum"
|
||||||
|
else:
|
||||||
|
num_items_in_batch = None
|
||||||
|
|
||||||
|
loss = linear_cross_entropy(
|
||||||
|
e,
|
||||||
|
c,
|
||||||
|
labels.to(e.device),
|
||||||
|
bias=bias,
|
||||||
|
shift=True,
|
||||||
|
softcap=softcap,
|
||||||
|
**cce_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_items_in_batch is not None:
|
||||||
|
loss = loss / num_items_in_batch
|
||||||
|
|
||||||
|
return loss
|
||||||
@@ -20,6 +20,26 @@ liger_layer_norm: true
|
|||||||
liger_fused_linear_cross_entropy: true
|
liger_fused_linear_cross_entropy: true
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Supported Models
|
||||||
|
|
||||||
|
- deepseek_v2
|
||||||
|
- gemma
|
||||||
|
- gemma2
|
||||||
|
- gemma3 (partial support, no support for FLCE yet)
|
||||||
|
- granite
|
||||||
|
- jamba
|
||||||
|
- llama
|
||||||
|
- mistral
|
||||||
|
- mixtral
|
||||||
|
- mllama
|
||||||
|
- mllama_text_model
|
||||||
|
- olmo2
|
||||||
|
- paligemma
|
||||||
|
- phi3
|
||||||
|
- qwen2
|
||||||
|
- qwen2_5_vl
|
||||||
|
- qwen2_vl
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
```bib
|
```bib
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ It is designed to be performant, correct, and light-weight.
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
|
||||||
@@ -41,11 +42,18 @@ class LigerPlugin(BasePlugin):
|
|||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||||
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||||
|
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||||
|
|
||||||
|
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
||||||
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
||||||
liger_fn_sig = inspect.signature(apply_liger_fn)
|
liger_fn_sig = inspect.signature(apply_liger_fn)
|
||||||
@@ -82,6 +90,8 @@ class LigerPlugin(BasePlugin):
|
|||||||
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
||||||
if cfg.liger_glu_activation:
|
if cfg.liger_glu_activation:
|
||||||
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
||||||
|
if cfg.liger_layer_norm:
|
||||||
|
modeling_jamba.nn.LayerNorm = LigerLayerNorm
|
||||||
if cfg.liger_cross_entropy:
|
if cfg.liger_cross_entropy:
|
||||||
from transformers.loss.loss_utils import nn
|
from transformers.loss.loss_utils import nn
|
||||||
|
|
||||||
@@ -104,15 +114,51 @@ class LigerPlugin(BasePlugin):
|
|||||||
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
|
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
|
||||||
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
|
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
|
||||||
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
|
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
|
||||||
|
if cfg.liger_glu_activation:
|
||||||
|
logging.warning("liger_glu_activation is not supported for DeepseekV2.")
|
||||||
if cfg.liger_rms_norm:
|
if cfg.liger_rms_norm:
|
||||||
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
||||||
if cfg.liger_glu_activation:
|
if cfg.liger_glu_activation:
|
||||||
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||||
|
if cfg.liger_layer_norm:
|
||||||
|
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
|
||||||
if cfg.liger_cross_entropy:
|
if cfg.liger_cross_entropy:
|
||||||
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
||||||
# nn.CrossEntropyLoss in the forward method.
|
# nn.CrossEntropyLoss in the forward method.
|
||||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||||
elif cfg.model_config_type in ["gemma3_text", "deepseek_v3"]:
|
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
|
||||||
|
from transformers.models.gemma3 import modeling_gemma3
|
||||||
|
|
||||||
|
if cfg.liger_rope:
|
||||||
|
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
|
||||||
|
def _liger_rms_norm_wrapper(dim, **kwargs):
|
||||||
|
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
|
||||||
|
return LigerRMSNorm(hidden_size=dim, **kwargs)
|
||||||
|
|
||||||
|
modeling_gemma3.Gemma3RMSNorm = partial(
|
||||||
|
_liger_rms_norm_wrapper,
|
||||||
|
offset=1.0,
|
||||||
|
casting_mode="gemma",
|
||||||
|
init_fn="zeros",
|
||||||
|
in_place=False,
|
||||||
|
)
|
||||||
|
if cfg.liger_glu_activation:
|
||||||
|
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
|
||||||
|
if cfg.liger_layer_norm:
|
||||||
|
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
|
||||||
|
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
from transformers.loss.loss_utils import nn
|
||||||
|
|
||||||
|
nn.functional.cross_entropy = liger_cross_entropy
|
||||||
|
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Fused linear cross entropy is not yet supported for Gemma3."
|
||||||
|
)
|
||||||
|
elif cfg.model_config_type in ["deepseek_v3"]:
|
||||||
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
|
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
|
||||||
|
|||||||
@@ -411,11 +411,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
if turn_idx >= len(turns):
|
if turn_idx >= len(turns):
|
||||||
raise ValueError(f"Turn index {turn_idx} out of range")
|
raise ValueError(f"Turn index {turn_idx} out of range")
|
||||||
|
|
||||||
# mistral does not output message if it contains only system message
|
# mistral/gemma3 does not output message if it contains only system message
|
||||||
if (
|
if (
|
||||||
turn_idx == 0
|
turn_idx == 0
|
||||||
and turns[0].get("role") == "system"
|
and turns[0].get("role") == "system"
|
||||||
and "mistral" in self.tokenizer.name_or_path.lower()
|
and (
|
||||||
|
"mistral" in self.tokenizer.name_or_path.lower()
|
||||||
|
# gemma3 uses gemma tokenizer
|
||||||
|
or "gemma" in self.tokenizer.name_or_path.lower()
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return -1, -1
|
return -1, -1
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import transformers.modelcard
|
|||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import save_fsdp_model
|
from accelerate.utils import save_fsdp_model
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from huggingface_hub.errors import OfflineModeIsEnabled
|
||||||
from peft import PeftConfig, PeftModel
|
from peft import PeftConfig, PeftModel
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
@@ -302,7 +303,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
|
|||||||
model_card_kwarg["dataset_tags"] = dataset_tags
|
model_card_kwarg["dataset_tags"] = dataset_tags
|
||||||
|
|
||||||
trainer.create_model_card(**model_card_kwarg)
|
trainer.create_model_card(**model_card_kwarg)
|
||||||
except (AttributeError, UnicodeDecodeError):
|
except (AttributeError, UnicodeDecodeError, OfflineModeIsEnabled):
|
||||||
pass
|
pass
|
||||||
elif cfg.hub_model_id:
|
elif cfg.hub_model_id:
|
||||||
# Defensively push to the hub to ensure the model card is updated
|
# Defensively push to the hub to ensure the model card is updated
|
||||||
|
|||||||
@@ -6,8 +6,12 @@ from pathlib import Path
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download, snapshot_download
|
||||||
from huggingface_hub.errors import HFValidationError
|
from huggingface_hub.errors import (
|
||||||
|
HFValidationError,
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
|
)
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -70,20 +74,25 @@ def load_dataset_w_config(
|
|||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
ds_trust_remote_code = config_dataset.trust_remote_code
|
|
||||||
try:
|
try:
|
||||||
# this is just a basic check to see if the path is a
|
# this is just a basic check to see if the path is a
|
||||||
# valid HF dataset that's loadable
|
# valid HF dataset that's loadable
|
||||||
load_dataset(
|
snapshot_download(
|
||||||
config_dataset.path,
|
repo_id=config_dataset.path,
|
||||||
name=config_dataset.name,
|
repo_type="dataset",
|
||||||
streaming=True,
|
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
revision=config_dataset.revision,
|
revision=config_dataset.revision,
|
||||||
trust_remote_code=ds_trust_remote_code,
|
ignore_patterns=["*"],
|
||||||
)
|
)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
except (
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
|
FileNotFoundError,
|
||||||
|
ConnectionError,
|
||||||
|
HFValidationError,
|
||||||
|
ValueError,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
ds_from_cloud = False
|
ds_from_cloud = False
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import types
|
import types
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import addict
|
import addict
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
@@ -25,7 +25,7 @@ from peft import (
|
|||||||
prepare_model_for_kbit_training,
|
prepare_model_for_kbit_training,
|
||||||
)
|
)
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import ( # noqa: F401
|
from transformers import (
|
||||||
AddedToken,
|
AddedToken,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
@@ -39,6 +39,7 @@ from transformers import ( # noqa: F401
|
|||||||
LlavaForConditionalGeneration,
|
LlavaForConditionalGeneration,
|
||||||
Mistral3ForConditionalGeneration,
|
Mistral3ForConditionalGeneration,
|
||||||
MllamaForConditionalGeneration,
|
MllamaForConditionalGeneration,
|
||||||
|
PretrainedConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
ProcessorMixin,
|
ProcessorMixin,
|
||||||
@@ -107,14 +108,21 @@ def get_module_class_from_name(module, name):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
|
||||||
|
# Set use_cache to False
|
||||||
|
if hasattr(model_config, "use_cache"):
|
||||||
|
model_config.use_cache = False
|
||||||
|
|
||||||
if cfg.is_multimodal:
|
if cfg.is_multimodal:
|
||||||
if hasattr(model_config, "text_config"):
|
# For multimodal configs, use_cache is set in the text_config
|
||||||
model_config = model_config.text_config
|
if hasattr(model_config, "get_text_config"):
|
||||||
model_config.use_cache = False
|
text_config = model_config.get_text_config()
|
||||||
elif hasattr(model_config, "get_text_config"):
|
if hasattr(text_config, "use_cache"):
|
||||||
model_config = model_config.get_text_config()
|
text_config.use_cache = False
|
||||||
model_config.use_cache = False
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"No text config found for multimodal model. Please raise an Issue with model details."
|
||||||
|
)
|
||||||
|
|
||||||
# check if image_size is not set and load image size from model config if available
|
# check if image_size is not set and load image size from model config if available
|
||||||
if (
|
if (
|
||||||
@@ -523,14 +531,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
# init model config
|
# init model config
|
||||||
self.model_config = load_model_config(cfg)
|
self.model_config = load_model_config(cfg)
|
||||||
if cfg.is_multimodal:
|
|
||||||
if hasattr(self.model_config, "text_config"):
|
|
||||||
self.text_model_config = self.model_config.text_config
|
|
||||||
else:
|
|
||||||
# for qwen2_vl
|
|
||||||
self.text_model_config = self.model_config.get_text_config()
|
|
||||||
else:
|
|
||||||
self.text_model_config = self.model_config
|
|
||||||
|
|
||||||
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||||
|
|
||||||
@@ -947,8 +947,6 @@ class ModelLoader:
|
|||||||
quantization_config = (
|
quantization_config = (
|
||||||
quantization_config or self.model_kwargs["quantization_config"]
|
quantization_config or self.model_kwargs["quantization_config"]
|
||||||
)
|
)
|
||||||
if self.cfg.is_multimodal:
|
|
||||||
self.model_config.text_config = self.text_model_config
|
|
||||||
self.model = load_sharded_model_quant(
|
self.model = load_sharded_model_quant(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
self.model_config,
|
self.model_config,
|
||||||
@@ -969,9 +967,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
_ = _configure_zero3_memory_efficient_loading()
|
_ = _configure_zero3_memory_efficient_loading()
|
||||||
|
|
||||||
if self.cfg.is_multimodal:
|
|
||||||
self.model_config.text_config = self.text_model_config
|
|
||||||
|
|
||||||
# Load model with random initialization if specified
|
# Load model with random initialization if specified
|
||||||
if self.cfg.random_init_weights:
|
if self.cfg.random_init_weights:
|
||||||
# AutoModel classes support the from_config method
|
# AutoModel classes support the from_config method
|
||||||
@@ -1026,8 +1021,6 @@ class ModelLoader:
|
|||||||
and self.model_type != "AutoModelForCausalLM"
|
and self.model_type != "AutoModelForCausalLM"
|
||||||
and not self.cfg.trust_remote_code
|
and not self.cfg.trust_remote_code
|
||||||
):
|
):
|
||||||
if self.cfg.is_multimodal:
|
|
||||||
self.model_config.text_config = self.text_model_config
|
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
@@ -1043,25 +1036,7 @@ class ModelLoader:
|
|||||||
**self.model_kwargs,
|
**self.model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
|
||||||
# when training starts
|
|
||||||
if (
|
|
||||||
hasattr(self.text_model_config, "max_seq_len")
|
|
||||||
and self.text_model_config.max_seq_len
|
|
||||||
and self.cfg.sequence_len > self.text_model_config.max_seq_len
|
|
||||||
):
|
|
||||||
self.text_model_config.max_seq_len = self.cfg.sequence_len
|
|
||||||
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
|
|
||||||
elif (
|
|
||||||
hasattr(self.text_model_config, "max_sequence_length")
|
|
||||||
and self.text_model_config.max_sequence_length
|
|
||||||
and self.cfg.sequence_len > self.text_model_config.max_sequence_length
|
|
||||||
):
|
|
||||||
self.text_model_config.max_sequence_length = self.cfg.sequence_len
|
|
||||||
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
|
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
if self.cfg.is_multimodal:
|
|
||||||
self.model_config.text_config = self.text_model_config
|
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
@@ -1080,8 +1055,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
_ = _configure_zero3_memory_efficient_loading()
|
_ = _configure_zero3_memory_efficient_loading()
|
||||||
|
|
||||||
if self.cfg.is_multimodal:
|
|
||||||
self.model_config.text_config = self.text_model_config
|
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
@@ -1346,8 +1319,6 @@ class ModelLoader:
|
|||||||
requires_grad.append(f"{name}: {param.requires_grad}")
|
requires_grad.append(f"{name}: {param.requires_grad}")
|
||||||
if len(requires_grad) == 0:
|
if len(requires_grad) == 0:
|
||||||
LOG.warning("there are no parameters that require gradient updates")
|
LOG.warning("there are no parameters that require gradient updates")
|
||||||
if hasattr(self.model, "config"):
|
|
||||||
self.model.config.use_cache = False
|
|
||||||
|
|
||||||
if self.cfg.flash_optimum:
|
if self.cfg.flash_optimum:
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
|
|||||||
21
src/axolotl/utils/optimizers/soap/LICENSE
Normal file
21
src/axolotl/utils/optimizers/soap/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2024 Nikhil Vyas
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
495
src/axolotl/utils/optimizers/soap/__init__.py
Normal file
495
src/axolotl/utils/optimizers/soap/__init__.py
Normal file
@@ -0,0 +1,495 @@
|
|||||||
|
# pylint: skip-file
|
||||||
|
# Copied from https://github.com/nikhilvyas/SOAP
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.optim as optim
|
||||||
|
|
||||||
|
# Parts of the code are modifications of Pytorch's AdamW optimizer
|
||||||
|
# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py
|
||||||
|
|
||||||
|
|
||||||
|
class SOAP(optim.Optimizer):
|
||||||
|
"""
|
||||||
|
Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
params (`Iterable[nn.parameter.Parameter]`):
|
||||||
|
Iterable of parameters to optimize or dictionaries defining parameter groups.
|
||||||
|
lr (`float`, *optional*, defaults to 0.003):
|
||||||
|
The learning rate to use.
|
||||||
|
betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
|
||||||
|
Adam's betas parameters (b1, b2).
|
||||||
|
shampoo_beta (`float`, *optional*, defaults to -1):
|
||||||
|
If >= 0, use this beta for the preconditioner (L and R in paper, state["GG"] below) moving average instead of betas[1].
|
||||||
|
eps (`float`, *optional*, defaults to 1e-08):
|
||||||
|
Adam's epsilon for numerical stability.
|
||||||
|
weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
|
||||||
|
precondition_frequency (`int`, *optional*, defaults to 10):
|
||||||
|
How often to update the preconditioner.
|
||||||
|
max_precond_dim (`int`, *optional*, defaults to 10000):
|
||||||
|
Maximum dimension of the preconditioner.
|
||||||
|
Set to 10000, so that we exclude most common vocab sizes while including layers.
|
||||||
|
merge_dims (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to merge dimensions of the preconditioner.
|
||||||
|
precondition_1d (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to precondition 1D gradients.
|
||||||
|
normalize_grads (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to normalize gradients per layer.
|
||||||
|
Helps at large precondition_frequency (~100 in our experiments),
|
||||||
|
but hurts performance at small precondition_frequency (~10 in our experiments).
|
||||||
|
data_format (`str`, *optional*, defaults to `channels_first`):
|
||||||
|
Data format of the input for convolutional layers.
|
||||||
|
Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
|
||||||
|
correct_bias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to use bias correction in Adam.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params,
|
||||||
|
lr: float = 3e-3,
|
||||||
|
betas=(0.95, 0.95),
|
||||||
|
shampoo_beta: float = -1,
|
||||||
|
eps: float = 1e-8,
|
||||||
|
weight_decay: float = 0.01,
|
||||||
|
precondition_frequency: int = 10,
|
||||||
|
max_precond_dim: int = 10000, #
|
||||||
|
merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
|
||||||
|
precondition_1d: bool = False,
|
||||||
|
normalize_grads: bool = False,
|
||||||
|
data_format: str = "channels_first",
|
||||||
|
correct_bias: bool = True,
|
||||||
|
):
|
||||||
|
defaults = {
|
||||||
|
"lr": lr,
|
||||||
|
"betas": betas,
|
||||||
|
"shampoo_beta": shampoo_beta,
|
||||||
|
"eps": eps,
|
||||||
|
"weight_decay": weight_decay,
|
||||||
|
"precondition_frequency": precondition_frequency,
|
||||||
|
"max_precond_dim": max_precond_dim,
|
||||||
|
"merge_dims": merge_dims,
|
||||||
|
"precondition_1d": precondition_1d,
|
||||||
|
"normalize_grads": normalize_grads,
|
||||||
|
"correct_bias": correct_bias,
|
||||||
|
}
|
||||||
|
super().__init__(params, defaults)
|
||||||
|
self._data_format = data_format
|
||||||
|
|
||||||
|
def merge_dims(self, grad, max_precond_dim):
|
||||||
|
"""
|
||||||
|
Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
|
||||||
|
"""
|
||||||
|
assert self._data_format in ["channels_first", "channels_last"]
|
||||||
|
if self._data_format == "channels_last" and grad.dim() == 4:
|
||||||
|
grad = grad.permute(0, 3, 1, 2)
|
||||||
|
shape = grad.shape
|
||||||
|
new_shape = []
|
||||||
|
|
||||||
|
curr_shape = 1
|
||||||
|
for sh in shape:
|
||||||
|
temp_shape = curr_shape * sh
|
||||||
|
if temp_shape > max_precond_dim:
|
||||||
|
if curr_shape > 1:
|
||||||
|
new_shape.append(curr_shape)
|
||||||
|
curr_shape = sh
|
||||||
|
else:
|
||||||
|
new_shape.append(sh)
|
||||||
|
curr_shape = 1
|
||||||
|
else:
|
||||||
|
curr_shape = temp_shape
|
||||||
|
|
||||||
|
if curr_shape > 1 or len(new_shape) == 0:
|
||||||
|
new_shape.append(curr_shape)
|
||||||
|
|
||||||
|
new_grad = grad.reshape(new_shape)
|
||||||
|
return new_grad
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""
|
||||||
|
Performs a single optimization step.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
|
||||||
|
"""
|
||||||
|
if closure is None:
|
||||||
|
loss = None
|
||||||
|
else:
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
for p in group["params"]:
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
grad = p.grad
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
|
||||||
|
if "step" not in state:
|
||||||
|
state["step"] = 0
|
||||||
|
|
||||||
|
# State initialization
|
||||||
|
if "exp_avg" not in state:
|
||||||
|
# Exponential moving average of gradient values
|
||||||
|
state["exp_avg"] = torch.zeros_like(grad)
|
||||||
|
# Exponential moving average of squared gradient values
|
||||||
|
state["exp_avg_sq"] = torch.zeros_like(grad)
|
||||||
|
|
||||||
|
if "Q" not in state:
|
||||||
|
self.init_preconditioner(
|
||||||
|
grad,
|
||||||
|
state,
|
||||||
|
precondition_frequency=group["precondition_frequency"],
|
||||||
|
precondition_1d=group["precondition_1d"],
|
||||||
|
shampoo_beta=(
|
||||||
|
group["shampoo_beta"]
|
||||||
|
if group["shampoo_beta"] >= 0
|
||||||
|
else group["betas"][1]
|
||||||
|
),
|
||||||
|
max_precond_dim=group["max_precond_dim"],
|
||||||
|
merge_dims=group["merge_dims"],
|
||||||
|
)
|
||||||
|
self.update_preconditioner(
|
||||||
|
grad,
|
||||||
|
state,
|
||||||
|
max_precond_dim=group["max_precond_dim"],
|
||||||
|
merge_dims=group["merge_dims"],
|
||||||
|
precondition_1d=group["precondition_1d"],
|
||||||
|
)
|
||||||
|
continue # first step is skipped so that we never use the current gradients in the projection.
|
||||||
|
|
||||||
|
# Projecting gradients to the eigenbases of Shampoo's preconditioner
|
||||||
|
# i.e. projecting to the eigenbases of matrices in state["GG"]
|
||||||
|
grad_projected = self.project(
|
||||||
|
grad,
|
||||||
|
state,
|
||||||
|
merge_dims=group["merge_dims"],
|
||||||
|
max_precond_dim=group["max_precond_dim"],
|
||||||
|
)
|
||||||
|
|
||||||
|
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||||
|
beta1, beta2 = group["betas"]
|
||||||
|
|
||||||
|
state["step"] += 1
|
||||||
|
|
||||||
|
# Decay the first and second moment running average coefficient
|
||||||
|
# In-place operations to update the averages at the same time
|
||||||
|
exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
|
||||||
|
exp_avg_sq.mul_(beta2).add_(
|
||||||
|
grad_projected.square(), alpha=(1.0 - beta2)
|
||||||
|
)
|
||||||
|
|
||||||
|
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
||||||
|
|
||||||
|
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
||||||
|
# i.e. projecting to the eigenbases of matrices in state["GG"]
|
||||||
|
# exp_avg_projected = self.project(
|
||||||
|
# exp_avg,
|
||||||
|
# state,
|
||||||
|
# merge_dims=group["merge_dims"],
|
||||||
|
# max_precond_dim=group["max_precond_dim"],
|
||||||
|
# )
|
||||||
|
exp_avg_projected = exp_avg
|
||||||
|
|
||||||
|
step_size = group["lr"]
|
||||||
|
if group["correct_bias"]:
|
||||||
|
bias_correction1 = 1.0 - beta1 ** (state["step"])
|
||||||
|
bias_correction2 = 1.0 - beta2 ** (state["step"])
|
||||||
|
step_size = step_size * (bias_correction2**0.5) / bias_correction1
|
||||||
|
|
||||||
|
# Projecting back the preconditioned (by Adam) exponential moving average of gradients
|
||||||
|
# to the original space
|
||||||
|
norm_grad = self.project_back(
|
||||||
|
exp_avg_projected / denom,
|
||||||
|
state,
|
||||||
|
merge_dims=group["merge_dims"],
|
||||||
|
max_precond_dim=group["max_precond_dim"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if group["normalize_grads"]:
|
||||||
|
norm_grad = norm_grad / (1e-30 + torch.mean(norm_grad**2) ** 0.5)
|
||||||
|
|
||||||
|
p.add_(norm_grad, alpha=-step_size)
|
||||||
|
|
||||||
|
# From AdamW code: Just adding the square of the weights to the loss function is *not*
|
||||||
|
# the correct way of using L2 regularization/weight decay with Adam,
|
||||||
|
# since that will interact with the m and v parameters in strange ways.
|
||||||
|
#
|
||||||
|
# Instead we want to decay the weights in a manner that doesn't interact
|
||||||
|
# with the m/v parameters. This is equivalent to adding the square
|
||||||
|
# of the weights to the loss with plain (non-momentum) SGD.
|
||||||
|
# Add weight decay at the end (fixed version)
|
||||||
|
if group["weight_decay"] > 0.0:
|
||||||
|
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
|
||||||
|
|
||||||
|
# Update is done after the gradient step to avoid using current gradients in the projection.
|
||||||
|
self.update_preconditioner(
|
||||||
|
grad,
|
||||||
|
state,
|
||||||
|
max_precond_dim=group["max_precond_dim"],
|
||||||
|
merge_dims=group["merge_dims"],
|
||||||
|
precondition_1d=group["precondition_1d"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def init_preconditioner(
|
||||||
|
self,
|
||||||
|
grad,
|
||||||
|
state,
|
||||||
|
precondition_frequency=10,
|
||||||
|
shampoo_beta=0.95,
|
||||||
|
max_precond_dim=10000,
|
||||||
|
precondition_1d=False,
|
||||||
|
merge_dims=False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initializes the preconditioner matrices (L and R in the paper).
|
||||||
|
"""
|
||||||
|
state["GG"] = (
|
||||||
|
[]
|
||||||
|
) # Will hold all the preconditioner matrices (L and R in the paper).
|
||||||
|
if grad.dim() == 1:
|
||||||
|
if not precondition_1d or grad.shape[0] > max_precond_dim:
|
||||||
|
state["GG"].append([])
|
||||||
|
else:
|
||||||
|
state["GG"].append(
|
||||||
|
torch.zeros(grad.shape[0], grad.shape[0], device=grad.device)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if merge_dims:
|
||||||
|
grad = self.merge_dims(grad, max_precond_dim)
|
||||||
|
|
||||||
|
for sh in grad.shape:
|
||||||
|
if sh > max_precond_dim:
|
||||||
|
state["GG"].append([])
|
||||||
|
else:
|
||||||
|
state["GG"].append(torch.zeros(sh, sh, device=grad.device))
|
||||||
|
|
||||||
|
state["Q"] = None # Will hold all the eigenbases of the preconditioner.
|
||||||
|
state["precondition_frequency"] = precondition_frequency
|
||||||
|
state["shampoo_beta"] = shampoo_beta
|
||||||
|
|
||||||
|
def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
|
||||||
|
"""
|
||||||
|
Projects the gradient to the eigenbases of the preconditioner.
|
||||||
|
"""
|
||||||
|
original_shape = grad.shape
|
||||||
|
if merge_dims:
|
||||||
|
if grad.dim() == 4 and self._data_format == "channels_last":
|
||||||
|
permuted_shape = grad.permute(0, 3, 1, 2).shape
|
||||||
|
grad = self.merge_dims(grad, max_precond_dim)
|
||||||
|
|
||||||
|
for mat in state["Q"]:
|
||||||
|
if len(mat) > 0:
|
||||||
|
grad = torch.tensordot(
|
||||||
|
grad,
|
||||||
|
mat,
|
||||||
|
dims=[[0], [0]],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
permute_order = list(range(1, len(grad.shape))) + [0]
|
||||||
|
grad = grad.permute(permute_order)
|
||||||
|
|
||||||
|
if merge_dims:
|
||||||
|
if self._data_format == "channels_last" and len(original_shape) == 4:
|
||||||
|
grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
|
||||||
|
else:
|
||||||
|
grad = grad.reshape(original_shape)
|
||||||
|
return grad
|
||||||
|
|
||||||
|
def update_preconditioner(
|
||||||
|
self,
|
||||||
|
grad,
|
||||||
|
state,
|
||||||
|
max_precond_dim=10000,
|
||||||
|
merge_dims=False,
|
||||||
|
precondition_1d=False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
|
||||||
|
"""
|
||||||
|
if state["Q"] is not None:
|
||||||
|
state["exp_avg"] = self.project_back(
|
||||||
|
state["exp_avg"],
|
||||||
|
state,
|
||||||
|
merge_dims=merge_dims,
|
||||||
|
max_precond_dim=max_precond_dim,
|
||||||
|
)
|
||||||
|
if grad.dim() == 1:
|
||||||
|
if precondition_1d and grad.shape[0] <= max_precond_dim:
|
||||||
|
state["GG"][0].lerp_(
|
||||||
|
grad.unsqueeze(1) @ grad.unsqueeze(0), 1 - state["shampoo_beta"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if merge_dims:
|
||||||
|
new_grad = self.merge_dims(grad, max_precond_dim)
|
||||||
|
for idx, sh in enumerate(new_grad.shape):
|
||||||
|
if sh <= max_precond_dim:
|
||||||
|
outer_product = torch.tensordot(
|
||||||
|
new_grad,
|
||||||
|
new_grad,
|
||||||
|
dims=[
|
||||||
|
[
|
||||||
|
*chain(
|
||||||
|
range(idx), range(idx + 1, len(new_grad.shape))
|
||||||
|
)
|
||||||
|
]
|
||||||
|
]
|
||||||
|
* 2,
|
||||||
|
)
|
||||||
|
state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"])
|
||||||
|
else:
|
||||||
|
for idx, sh in enumerate(grad.shape):
|
||||||
|
if sh <= max_precond_dim:
|
||||||
|
outer_product = torch.tensordot(
|
||||||
|
grad,
|
||||||
|
grad,
|
||||||
|
# Contracts across all dimensions except for k.
|
||||||
|
dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]]
|
||||||
|
* 2,
|
||||||
|
)
|
||||||
|
state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"])
|
||||||
|
|
||||||
|
if state["Q"] is None:
|
||||||
|
state["Q"] = self.get_orthogonal_matrix(state["GG"])
|
||||||
|
if state["step"] > 0 and state["step"] % state["precondition_frequency"] == 0:
|
||||||
|
state["Q"] = self.get_orthogonal_matrix_QR(
|
||||||
|
state, max_precond_dim, merge_dims
|
||||||
|
)
|
||||||
|
# state["Q"] = self.get_fast_QR(state, max_precond_dim, merge_dims)
|
||||||
|
|
||||||
|
if state["step"] > 0:
|
||||||
|
state["exp_avg"] = self.project(
|
||||||
|
state["exp_avg"],
|
||||||
|
state,
|
||||||
|
merge_dims=merge_dims,
|
||||||
|
max_precond_dim=max_precond_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
|
||||||
|
"""
|
||||||
|
Projects the gradient back to the original space.
|
||||||
|
"""
|
||||||
|
original_shape = grad.shape
|
||||||
|
if merge_dims:
|
||||||
|
if self._data_format == "channels_last" and grad.dim() == 4:
|
||||||
|
permuted_shape = grad.permute(0, 3, 1, 2).shape
|
||||||
|
grad = self.merge_dims(grad, max_precond_dim)
|
||||||
|
for mat in state["Q"]:
|
||||||
|
if len(mat) > 0:
|
||||||
|
grad = torch.tensordot(
|
||||||
|
grad,
|
||||||
|
mat,
|
||||||
|
dims=[[0], [1]],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
permute_order = list(range(1, len(grad.shape))) + [0]
|
||||||
|
grad = grad.permute(permute_order)
|
||||||
|
|
||||||
|
if merge_dims:
|
||||||
|
if self._data_format == "channels_last" and len(original_shape) == 4:
|
||||||
|
grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
|
||||||
|
else:
|
||||||
|
grad = grad.reshape(original_shape)
|
||||||
|
return grad
|
||||||
|
|
||||||
|
def get_orthogonal_matrix(self, mat):
|
||||||
|
"""
|
||||||
|
Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
|
||||||
|
"""
|
||||||
|
matrix = []
|
||||||
|
for m in mat:
|
||||||
|
if len(m) == 0:
|
||||||
|
matrix.append([])
|
||||||
|
continue
|
||||||
|
if m.data.dtype != torch.float:
|
||||||
|
float_data = False
|
||||||
|
original_type = m.data.dtype
|
||||||
|
original_device = m.data.device
|
||||||
|
matrix.append(m.data.float())
|
||||||
|
else:
|
||||||
|
float_data = True
|
||||||
|
matrix.append(m.data)
|
||||||
|
|
||||||
|
final = []
|
||||||
|
for m in matrix:
|
||||||
|
if len(m) == 0:
|
||||||
|
final.append([])
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
_, Q = torch.linalg.eigh(
|
||||||
|
m + 1e-30 * torch.eye(m.shape[0], device=m.device)
|
||||||
|
)
|
||||||
|
except: # pylint: disable=bare-except # noqa: E722
|
||||||
|
_, Q = torch.linalg.eigh(
|
||||||
|
m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device)
|
||||||
|
)
|
||||||
|
Q = Q.to(m.dtype)
|
||||||
|
Q = torch.flip(Q, [1])
|
||||||
|
|
||||||
|
if not float_data:
|
||||||
|
Q = Q.to(original_device).type(original_type)
|
||||||
|
final.append(Q)
|
||||||
|
return final
|
||||||
|
|
||||||
|
def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
|
||||||
|
"""
|
||||||
|
Computes the eigenbases of the preconditioner using one round of power iteration
|
||||||
|
followed by torch.linalg.qr decomposition.
|
||||||
|
"""
|
||||||
|
precond_list = state["GG"]
|
||||||
|
orth_list = state["Q"]
|
||||||
|
|
||||||
|
matrix = []
|
||||||
|
orth_matrix = []
|
||||||
|
for m, o in zip(precond_list, orth_list):
|
||||||
|
if len(m) == 0:
|
||||||
|
matrix.append([])
|
||||||
|
orth_matrix.append([])
|
||||||
|
continue
|
||||||
|
if m.data.dtype != torch.float:
|
||||||
|
float_data = False
|
||||||
|
original_type = m.data.dtype
|
||||||
|
original_device = m.data.device
|
||||||
|
matrix.append(m.data.float())
|
||||||
|
orth_matrix.append(o.data.float())
|
||||||
|
else:
|
||||||
|
float_data = True
|
||||||
|
matrix.append(m.data.float())
|
||||||
|
orth_matrix.append(o.data.float())
|
||||||
|
|
||||||
|
orig_shape = state["exp_avg_sq"].shape
|
||||||
|
if self._data_format == "channels_last" and len(orig_shape) == 4:
|
||||||
|
permuted_shape = state["exp_avg_sq"].permute(0, 3, 1, 2).shape
|
||||||
|
if merge_dims:
|
||||||
|
exp_avg_sq = self.merge_dims(state["exp_avg_sq"], max_precond_dim)
|
||||||
|
else:
|
||||||
|
exp_avg_sq = state["exp_avg_sq"]
|
||||||
|
|
||||||
|
final = []
|
||||||
|
for ind, (m, o) in enumerate(zip(matrix, orth_matrix)):
|
||||||
|
if len(m) == 0:
|
||||||
|
final.append([])
|
||||||
|
continue
|
||||||
|
est_eig = torch.diag(o.T @ m @ o)
|
||||||
|
sort_idx = torch.argsort(est_eig, descending=True)
|
||||||
|
exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
|
||||||
|
o = o[:, sort_idx]
|
||||||
|
power_iter = m @ o
|
||||||
|
Q, _ = torch.linalg.qr(power_iter)
|
||||||
|
|
||||||
|
if not float_data:
|
||||||
|
Q = Q.to(original_device).type(original_type)
|
||||||
|
final.append(Q)
|
||||||
|
|
||||||
|
if merge_dims:
|
||||||
|
if self._data_format == "channels_last" and len(orig_shape) == 4:
|
||||||
|
exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
|
||||||
|
else:
|
||||||
|
exp_avg_sq = exp_avg_sq.reshape(orig_shape)
|
||||||
|
|
||||||
|
state["exp_avg_sq"] = exp_avg_sq
|
||||||
|
return final
|
||||||
@@ -52,3 +52,4 @@ class CustomSupportedOptimizers(str, Enum):
|
|||||||
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
||||||
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
||||||
muon = "muon" # pylint: disable=invalid-name
|
muon = "muon" # pylint: disable=invalid-name
|
||||||
|
soap = "soap" # pylint: disable=invalid-name
|
||||||
|
|||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
@@ -11,7 +11,11 @@ import time
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
from datasets import load_dataset
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||||
@@ -25,9 +29,11 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
|
|||||||
except (
|
except (
|
||||||
requests.exceptions.ReadTimeout,
|
requests.exceptions.ReadTimeout,
|
||||||
requests.exceptions.ConnectionError,
|
requests.exceptions.ConnectionError,
|
||||||
|
requests.exceptions.HTTPError,
|
||||||
) as exc:
|
) as exc:
|
||||||
if attempt < max_retries - 1:
|
if attempt < max_retries - 1:
|
||||||
time.sleep(delay)
|
wait = 2**attempt * delay # in seconds
|
||||||
|
time.sleep(wait)
|
||||||
else:
|
else:
|
||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
@@ -37,6 +43,7 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
|
|||||||
|
|
||||||
|
|
||||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||||
|
@disable_hf_offline
|
||||||
def snapshot_download_w_retry(*args, **kwargs):
|
def snapshot_download_w_retry(*args, **kwargs):
|
||||||
return snapshot_download(*args, **kwargs)
|
return snapshot_download(*args, **kwargs)
|
||||||
|
|
||||||
@@ -44,19 +51,19 @@ def snapshot_download_w_retry(*args, **kwargs):
|
|||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_smollm2_135m_model():
|
def download_smollm2_135m_model():
|
||||||
# download the model
|
# download the model
|
||||||
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M")
|
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_llama_68m_random_model():
|
def download_llama_68m_random_model():
|
||||||
# download the model
|
# download the model
|
||||||
snapshot_download_w_retry("JackFram/llama-68m")
|
snapshot_download_w_retry("JackFram/llama-68m", repo_type="model")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_qwen_2_5_half_billion_model():
|
def download_qwen_2_5_half_billion_model():
|
||||||
# download the model
|
# download the model
|
||||||
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B")
|
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
@@ -101,6 +108,37 @@ def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_fozzie_alpaca_dpo_dataset():
|
||||||
|
# download the dataset
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
|
||||||
|
)
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||||
|
repo_type="dataset",
|
||||||
|
revision="ea82cff",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
@disable_hf_offline
|
||||||
|
def dataset_fozzie_alpaca_dpo_dataset(
|
||||||
|
download_fozzie_alpaca_dpo_dataset,
|
||||||
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
|
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
@disable_hf_offline
|
||||||
|
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
||||||
|
download_fozzie_alpaca_dpo_dataset,
|
||||||
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
|
return load_dataset(
|
||||||
|
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
||||||
# download the dataset
|
# download the dataset
|
||||||
@@ -109,10 +147,141 @@ def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_argilla_dpo_pairs_dataset():
|
||||||
|
# download the dataset
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_tiny_shakespeare_dataset():
|
def download_tiny_shakespeare_dataset():
|
||||||
# download the dataset
|
# download the dataset
|
||||||
snapshot_download_w_retry("Trelis/tiny-shakespeare", repo_type="dataset")
|
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_deepseek_model_fixture():
|
||||||
|
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_huggyllama_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"huggyllama/llama-7b",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*", "config.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_llama_1b_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"NousResearch/Llama-3.2-1B",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*", "config.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_llama3_8b_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"NousResearch/Meta-Llama-3-8B", repo_type="model", allow_patterns=["*token*"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_llama3_8b_instruct_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"NousResearch/Meta-Llama-3-8B-Instruct",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_phi_35_mini_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"microsoft/Phi-3.5-mini-instruct", repo_type="model", allow_patterns=["*token*"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_phi_3_medium_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"microsoft/Phi-3-medium-128k-instruct",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_mistral_7b_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"casperhansen/mistral-7b-instruct-v0.1-awq",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*", "config.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_gemma_2b_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"unsloth/gemma-2b-it",
|
||||||
|
revision="703fb4a",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*", "config.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_gemma2_9b_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"mlx-community/gemma-2-9b-it-4bit",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*", "config.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_mlx_mistral_7b_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"mlx-community/Mistral-7B-Instruct-v0.3-4bit",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*", "config.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_llama2_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"NousResearch/Llama-2-7b-hf",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*", "config.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
@enable_hf_offline
|
||||||
|
def tokenizer_huggyllama(
|
||||||
|
download_huggyllama_model_fixture,
|
||||||
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||||
|
tokenizer.pad_token = "</s>"
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -178,3 +347,34 @@ def cleanup_monkeypatches():
|
|||||||
module_globals = module_name_tuple[1]
|
module_globals = module_name_tuple[1]
|
||||||
for module_global in module_globals:
|
for module_global in module_globals:
|
||||||
globals().pop(module_global, None)
|
globals().pop(module_global, None)
|
||||||
|
|
||||||
|
|
||||||
|
# # pylint: disable=redefined-outer-name,unused-argument
|
||||||
|
# def test_load_fixtures(
|
||||||
|
# download_smollm2_135m_model,
|
||||||
|
# download_llama_68m_random_model,
|
||||||
|
# download_qwen_2_5_half_billion_model,
|
||||||
|
# download_tatsu_lab_alpaca_dataset,
|
||||||
|
# download_mhenrichsen_alpaca_2k_dataset,
|
||||||
|
# download_mhenrichsen_alpaca_2k_w_revision_dataset,
|
||||||
|
# download_mlabonne_finetome_100k_dataset,
|
||||||
|
# download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
|
||||||
|
# download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
|
||||||
|
# download_fozzie_alpaca_dpo_dataset,
|
||||||
|
# download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
|
||||||
|
# download_argilla_dpo_pairs_dataset,
|
||||||
|
# download_tiny_shakespeare_dataset,
|
||||||
|
# download_deepseek_model_fixture,
|
||||||
|
# download_huggyllama_model_fixture,
|
||||||
|
# download_llama_1b_model_fixture,
|
||||||
|
# download_llama3_8b_model_fixture,
|
||||||
|
# download_llama3_8b_instruct_model_fixture,
|
||||||
|
# download_phi_35_mini_model_fixture,
|
||||||
|
# download_phi_3_medium_model_fixture,
|
||||||
|
# download_mistral_7b_model_fixture,
|
||||||
|
# download_gemma_2b_model_fixture,
|
||||||
|
# download_gemma2_9b_model_fixture,
|
||||||
|
# download_mlx_mistral_7b_model_fixture,
|
||||||
|
# download_llama2_model_fixture,
|
||||||
|
# ):
|
||||||
|
# pass
|
||||||
|
|||||||
@@ -10,10 +10,13 @@ from transformers import AddedToken, AutoTokenizer
|
|||||||
from axolotl.core.chat.format.chatml import format_message
|
from axolotl.core.chat.format.chatml import format_message
|
||||||
from axolotl.core.chat.messages import ChatFormattedChats, Chats
|
from axolotl.core.chat.messages import ChatFormattedChats, Chats
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline # noqa
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", name="llama_tokenizer")
|
@pytest.fixture(scope="session", name="llama_tokenizer")
|
||||||
|
@enable_hf_offline
|
||||||
def llama_tokenizer_fixture():
|
def llama_tokenizer_fixture():
|
||||||
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")
|
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", name="chatml_tokenizer")
|
@pytest.fixture(scope="session", name="chatml_tokenizer")
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ e2e tests for kd trainer support in Axolotl
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from e2e.utils import check_tensorboard, require_torch_2_5_1
|
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
@@ -13,6 +12,8 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="kd_min_cfg")
|
@pytest.fixture(name="kd_min_cfg")
|
||||||
def min_cfg(temp_dir):
|
def min_cfg(temp_dir):
|
||||||
|
|||||||
@@ -2,15 +2,13 @@
|
|||||||
Simple end-to-end test for Liger integration
|
Simple end-to-end test for Liger integration
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from e2e.utils import require_torch_2_4_1
|
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists
|
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
|
||||||
|
|
||||||
|
|
||||||
class LigerIntegrationTestCase:
|
class LigerIntegrationTestCase:
|
||||||
|
|||||||
@@ -8,11 +8,12 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
from e2e.utils import require_vllm
|
|
||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import require_vllm
|
||||||
|
|
||||||
|
|
||||||
class TestGRPO:
|
class TestGRPO:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -9,12 +9,13 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
from e2e.utils import check_tensorboard
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import check_tensorboard
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|||||||
@@ -9,10 +9,11 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
from e2e.utils import check_tensorboard, require_torch_lt_2_6_0
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
@@ -23,6 +25,7 @@ class TestDeepseekV3:
|
|||||||
Test case for DeepseekV3 models
|
Test case for DeepseekV3 models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"sample_packing",
|
"sample_packing",
|
||||||
[True, False],
|
[True, False],
|
||||||
@@ -80,6 +83,7 @@ class TestDeepseekV3:
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"sample_packing",
|
"sample_packing",
|
||||||
[True, False],
|
[True, False],
|
||||||
|
|||||||
@@ -5,14 +5,14 @@ E2E tests for llama
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from e2e.utils import check_model_output_exists
|
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import check_model_output_exists
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|||||||
@@ -201,3 +201,46 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_soap(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "vicgalle/alpaca-gpt4",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "soap",
|
||||||
|
"adam_beta1": 0.9,
|
||||||
|
"adam_beta2": 0.95,
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
85
tests/hf_offline_utils.py
Normal file
85
tests/hf_offline_utils.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""
|
||||||
|
test utils for helpers and decorators
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from huggingface_hub.utils import reset_sessions
|
||||||
|
|
||||||
|
|
||||||
|
def reload_modules(hf_hub_offline):
|
||||||
|
# Force reload of the modules that check this variable
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
import huggingface_hub.constants
|
||||||
|
|
||||||
|
# Reload the constants module first, as others depend on it
|
||||||
|
importlib.reload(huggingface_hub.constants)
|
||||||
|
huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
|
||||||
|
importlib.reload(datasets.config)
|
||||||
|
setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
|
||||||
|
reset_sessions()
|
||||||
|
|
||||||
|
|
||||||
|
def enable_hf_offline(test_func):
|
||||||
|
"""
|
||||||
|
test decorator that sets HF_HUB_OFFLINE environment variable to True and restores it after the test even if the test fails.
|
||||||
|
:param test_func:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(test_func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
# Save the original value of HF_HUB_OFFLINE environment variable
|
||||||
|
original_hf_offline = os.getenv("HF_HUB_OFFLINE")
|
||||||
|
|
||||||
|
# Set HF_OFFLINE environment variable to True
|
||||||
|
os.environ["HF_HUB_OFFLINE"] = "1"
|
||||||
|
|
||||||
|
reload_modules(True)
|
||||||
|
try:
|
||||||
|
# Run the test function
|
||||||
|
return test_func(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
# Restore the original value of HF_HUB_OFFLINE environment variable
|
||||||
|
if original_hf_offline is not None:
|
||||||
|
os.environ["HF_HUB_OFFLINE"] = original_hf_offline
|
||||||
|
reload_modules(bool(original_hf_offline))
|
||||||
|
else:
|
||||||
|
del os.environ["HF_HUB_OFFLINE"]
|
||||||
|
reload_modules(False)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def disable_hf_offline(test_func):
|
||||||
|
"""
|
||||||
|
test decorator that sets HF_HUB_OFFLINE environment variable to False and restores it after the wrapped func
|
||||||
|
:param test_func:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(test_func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
# Save the original value of HF_HUB_OFFLINE environment variable
|
||||||
|
original_hf_offline = os.getenv("HF_HUB_OFFLINE")
|
||||||
|
|
||||||
|
# Set HF_OFFLINE environment variable to True
|
||||||
|
os.environ["HF_HUB_OFFLINE"] = "0"
|
||||||
|
|
||||||
|
reload_modules(False)
|
||||||
|
try:
|
||||||
|
# Run the test function
|
||||||
|
return test_func(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
# Restore the original value of HF_HUB_OFFLINE environment variable
|
||||||
|
if original_hf_offline is not None:
|
||||||
|
os.environ["HF_HUB_OFFLINE"] = original_hf_offline
|
||||||
|
reload_modules(bool(original_hf_offline))
|
||||||
|
else:
|
||||||
|
del os.environ["HF_HUB_OFFLINE"]
|
||||||
|
reload_modules(False)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
@@ -4,12 +4,13 @@ shared fixtures for prompt strategies tests
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
|
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
|
||||||
from axolotl.utils.chat_templates import _CHAT_TEMPLATES
|
from axolotl.utils.chat_templates import _CHAT_TEMPLATES
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="assistant_dataset")
|
@pytest.fixture(name="assistant_dataset")
|
||||||
def fixture_assistant_dataset():
|
def fixture_assistant_dataset():
|
||||||
@@ -108,31 +109,27 @@ def fixture_toolcalling_dataset():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
|
@pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
|
||||||
def fixture_llama3_tokenizer():
|
@enable_hf_offline
|
||||||
hf_hub_download(
|
def fixture_llama3_tokenizer(
|
||||||
repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
|
download_llama3_8b_instruct_model_fixture,
|
||||||
filename="special_tokens_map.json",
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
)
|
|
||||||
hf_hub_download(
|
|
||||||
repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
|
|
||||||
filename="tokenizer_config.json",
|
|
||||||
)
|
|
||||||
hf_hub_download(
|
|
||||||
repo_id="NousResearch/Meta-Llama-3-8B-Instruct", filename="tokenizer.json"
|
|
||||||
)
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
|
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
|
||||||
|
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
|
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
|
||||||
|
@enable_hf_offline
|
||||||
def fixture_smollm2_tokenizer():
|
def fixture_smollm2_tokenizer():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
|
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
|
||||||
def fixture_mistralv03_tokenizer():
|
@enable_hf_offline
|
||||||
|
def fixture_mistralv03_tokenizer(
|
||||||
|
download_mlx_mistral_7b_model_fixture,
|
||||||
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
"mlx-community/Mistral-7B-Instruct-v0.3-4bit"
|
"mlx-community/Mistral-7B-Instruct-v0.3-4bit"
|
||||||
)
|
)
|
||||||
@@ -140,6 +137,7 @@ def fixture_mistralv03_tokenizer():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="phi35_tokenizer", scope="session", autouse=True)
|
@pytest.fixture(name="phi35_tokenizer", scope="session", autouse=True)
|
||||||
|
@enable_hf_offline
|
||||||
def fixture_phi35_tokenizer():
|
def fixture_phi35_tokenizer():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
|
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ from axolotl.datasets import TokenizedPromptDataset
|
|||||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
||||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="alpaca_dataset")
|
@pytest.fixture(name="alpaca_dataset")
|
||||||
def fixture_alpaca_dataset():
|
def fixture_alpaca_dataset():
|
||||||
@@ -26,6 +28,7 @@ def fixture_alpaca_dataset():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="tokenizer")
|
@pytest.fixture(name="tokenizer")
|
||||||
|
@enable_hf_offline
|
||||||
def fixture_tokenizer():
|
def fixture_tokenizer():
|
||||||
# pylint: disable=all
|
# pylint: disable=all
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
|||||||
@@ -13,8 +13,11 @@ from axolotl.utils.chat_templates import (
|
|||||||
get_chat_template,
|
get_chat_template,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="llama3_tokenizer")
|
@pytest.fixture(name="llama3_tokenizer")
|
||||||
|
@enable_hf_offline
|
||||||
def fixture_llama3_tokenizer():
|
def fixture_llama3_tokenizer():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ from axolotl.prompt_strategies.chat_template import (
|
|||||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||||
from axolotl.utils.chat_templates import get_chat_template
|
from axolotl.utils.chat_templates import get_chat_template
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -30,12 +32,14 @@ PARAMETRIZE_PARAMS = [
|
|||||||
"mistralv03_tokenizer_chat_template_jinja",
|
"mistralv03_tokenizer_chat_template_jinja",
|
||||||
"[/INST]",
|
"[/INST]",
|
||||||
),
|
),
|
||||||
(
|
# TODO: temporarily skip gemma due to gemma3 template
|
||||||
"gemma2_tokenizer",
|
# Re-enable on new chat_template implementation for perf
|
||||||
"jinja",
|
# (
|
||||||
"gemma2_tokenizer_chat_template_jinja",
|
# "gemma2_tokenizer",
|
||||||
"<end_of_turn>",
|
# "jinja",
|
||||||
),
|
# "gemma2_tokenizer_chat_template_jinja",
|
||||||
|
# "<end_of_turn>",
|
||||||
|
# ),
|
||||||
("phi35_tokenizer", "phi_35", None, "<|end|>"),
|
("phi35_tokenizer", "phi_35", None, "<|end|>"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -93,7 +97,11 @@ class TestChatTemplateConfigurations:
|
|||||||
if (
|
if (
|
||||||
turn_idx == 0
|
turn_idx == 0
|
||||||
and turn.get("from") in ["system", "context"]
|
and turn.get("from") in ["system", "context"]
|
||||||
and "mistral" in tokenizer.name_or_path.lower()
|
and (
|
||||||
|
"mistral" in tokenizer.name_or_path.lower()
|
||||||
|
or "gemma"
|
||||||
|
in tokenizer.name_or_path.lower() # temporarily skip gemma due to gemma3 template
|
||||||
|
)
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
start_idx == -1 and end_idx == -1
|
start_idx == -1 and end_idx == -1
|
||||||
@@ -101,6 +109,7 @@ class TestChatTemplateConfigurations:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def test_train_on_inputs_true(
|
def test_train_on_inputs_true(
|
||||||
self,
|
self,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ from transformers import AutoTokenizer
|
|||||||
from axolotl.prompt_strategies.dpo.chat_template import default
|
from axolotl.prompt_strategies.dpo.chat_template import default
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="assistant_dataset")
|
@pytest.fixture(name="assistant_dataset")
|
||||||
def fixture_assistant_dataset():
|
def fixture_assistant_dataset():
|
||||||
@@ -78,15 +80,8 @@ def fixture_custom_assistant_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="llama3_tokenizer")
|
|
||||||
def fixture_llama3_tokenizer():
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
|
||||||
tokenizer.eos_token = "<|eot_id|>"
|
|
||||||
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="phi3_tokenizer")
|
@pytest.fixture(name="phi3_tokenizer")
|
||||||
|
@enable_hf_offline
|
||||||
def fixture_phi3_tokenizer():
|
def fixture_phi3_tokenizer():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
|
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
|
||||||
|
|
||||||
@@ -94,6 +89,7 @@ def fixture_phi3_tokenizer():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="gemma_tokenizer")
|
@pytest.fixture(name="gemma_tokenizer")
|
||||||
|
@enable_hf_offline
|
||||||
def fixture_gemma_tokenizer():
|
def fixture_gemma_tokenizer():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a")
|
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a")
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ from axolotl.prompt_strategies.dpo import load as load_dpo
|
|||||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="minimal_dpo_cfg")
|
@pytest.fixture(name="minimal_dpo_cfg")
|
||||||
def fixture_cfg():
|
def fixture_cfg():
|
||||||
@@ -34,6 +36,8 @@ class TestDPOChatml:
|
|||||||
Test loading DPO preference datasets with chatml formatting
|
Test loading DPO preference datasets with chatml formatting
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||||
|
@enable_hf_offline
|
||||||
def test_default(self, minimal_dpo_cfg):
|
def test_default(self, minimal_dpo_cfg):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -8,12 +8,15 @@ from transformers import LlamaTokenizer
|
|||||||
|
|
||||||
from axolotl.utils.data import encode_pretraining, md5
|
from axolotl.utils.data import encode_pretraining, md5
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
class TestEncodePretraining(unittest.TestCase):
|
class TestEncodePretraining(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
test class for encode pretraining and md5 helper
|
test class for encode pretraining and md5 helper
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
|
self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||||
self.tokenizer.add_special_tokens(
|
self.tokenizer.add_special_tokens(
|
||||||
|
|||||||
@@ -4,31 +4,37 @@ Test dataset loading under various conditions.
|
|||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from conftest import snapshot_download_w_retry
|
import pytest
|
||||||
from constants import (
|
|
||||||
ALPACA_MESSAGES_CONFIG_OG,
|
|
||||||
ALPACA_MESSAGES_CONFIG_REVISION,
|
|
||||||
SPECIAL_TOKENS,
|
|
||||||
)
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from transformers import AutoTokenizer
|
from huggingface_hub import snapshot_download
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
from axolotl.utils.data import load_tokenized_prepared_datasets
|
from axolotl.utils.data import load_tokenized_prepared_datasets
|
||||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.constants import (
|
||||||
|
ALPACA_MESSAGES_CONFIG_OG,
|
||||||
|
ALPACA_MESSAGES_CONFIG_REVISION,
|
||||||
|
SPECIAL_TOKENS,
|
||||||
|
)
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
class TestDatasetPreparation(unittest.TestCase):
|
|
||||||
|
class TestDatasetPreparation:
|
||||||
"""Test a configured dataloader."""
|
"""Test a configured dataloader."""
|
||||||
|
|
||||||
def setUp(self) -> None:
|
@pytest.fixture
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
|
||||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
|
||||||
# Alpaca dataset.
|
yield tokenizer_huggyllama
|
||||||
self.dataset = Dataset.from_list(
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dataset_fixture(self):
|
||||||
|
yield Dataset.from_list(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
|
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
|
||||||
@@ -38,7 +44,9 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_load_hub(self):
|
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||||
|
@enable_hf_offline
|
||||||
|
def test_load_hub(self, tokenizer):
|
||||||
"""Core use case. Verify that processing data from the hub works"""
|
"""Core use case. Verify that processing data from the hub works"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
@@ -55,25 +63,28 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
assert len(dataset) == 2000
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
def test_load_local_hub(self):
|
@enable_hf_offline
|
||||||
|
@pytest.mark.skip("datasets bug with local datasets when offline")
|
||||||
|
def test_load_local_hub(self, tokenizer):
|
||||||
"""Niche use case. Verify that a local copy of a hub dataset can be loaded"""
|
"""Niche use case. Verify that a local copy of a hub dataset can be loaded"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||||
snapshot_download_w_retry(
|
snapshot_path = snapshot_download(
|
||||||
repo_id="mhenrichsen/alpaca_2k_test",
|
repo_id="mhenrichsen/alpaca_2k_test",
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
local_dir=tmp_ds_path,
|
local_dir=tmp_ds_path,
|
||||||
)
|
)
|
||||||
|
# offline mode doesn't actually copy it to local_dir, so we
|
||||||
|
# have to copy all the contents in the dir manually from the returned snapshot_path
|
||||||
|
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
|
||||||
|
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
# Right now a local copy that doesn't fully conform to a dataset
|
# Right now a local copy that doesn't fully conform to a dataset
|
||||||
@@ -96,9 +107,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
assert len(dataset) == 2000
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
@@ -106,11 +115,12 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
shutil.rmtree(tmp_ds_path)
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
def test_load_from_save_to_disk(self):
|
@enable_hf_offline
|
||||||
|
def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):
|
||||||
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
|
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
|
tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
|
||||||
self.dataset.save_to_disk(str(tmp_ds_name))
|
dataset_fixture.save_to_disk(str(tmp_ds_name))
|
||||||
|
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -126,22 +136,21 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 1
|
assert len(dataset) == 1
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
def test_load_from_dir_of_parquet(self):
|
@enable_hf_offline
|
||||||
|
def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
|
||||||
"""Usual use case. Verify a directory of parquet files can be loaded."""
|
"""Usual use case. Verify a directory of parquet files can be loaded."""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
|
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
|
||||||
tmp_ds_dir.mkdir()
|
tmp_ds_dir.mkdir()
|
||||||
tmp_ds_path = tmp_ds_dir / "shard1.parquet"
|
tmp_ds_path = tmp_ds_dir / "shard1.parquet"
|
||||||
self.dataset.to_parquet(tmp_ds_path)
|
dataset_fixture.to_parquet(tmp_ds_path)
|
||||||
|
|
||||||
prepared_path: Path = Path(tmp_dir) / "prepared"
|
prepared_path: Path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -162,22 +171,21 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 1
|
assert len(dataset) == 1
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
def test_load_from_dir_of_json(self):
|
@enable_hf_offline
|
||||||
|
def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):
|
||||||
"""Standard use case. Verify a directory of json files can be loaded."""
|
"""Standard use case. Verify a directory of json files can be loaded."""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
|
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
|
||||||
tmp_ds_dir.mkdir()
|
tmp_ds_dir.mkdir()
|
||||||
tmp_ds_path = tmp_ds_dir / "shard1.json"
|
tmp_ds_path = tmp_ds_dir / "shard1.json"
|
||||||
self.dataset.to_json(tmp_ds_path)
|
dataset_fixture.to_json(tmp_ds_path)
|
||||||
|
|
||||||
prepared_path: Path = Path(tmp_dir) / "prepared"
|
prepared_path: Path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -198,20 +206,19 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 1
|
assert len(dataset) == 1
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
def test_load_from_single_parquet(self):
|
@enable_hf_offline
|
||||||
|
def test_load_from_single_parquet(self, tokenizer, dataset_fixture):
|
||||||
"""Standard use case. Verify a single parquet file can be loaded."""
|
"""Standard use case. Verify a single parquet file can be loaded."""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
|
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
|
||||||
self.dataset.to_parquet(tmp_ds_path)
|
dataset_fixture.to_parquet(tmp_ds_path)
|
||||||
|
|
||||||
prepared_path: Path = Path(tmp_dir) / "prepared"
|
prepared_path: Path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -228,20 +235,19 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 1
|
assert len(dataset) == 1
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
def test_load_from_single_json(self):
|
@enable_hf_offline
|
||||||
|
def test_load_from_single_json(self, tokenizer, dataset_fixture):
|
||||||
"""Standard use case. Verify a single json file can be loaded."""
|
"""Standard use case. Verify a single json file can be loaded."""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
|
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
|
||||||
self.dataset.to_json(tmp_ds_path)
|
dataset_fixture.to_json(tmp_ds_path)
|
||||||
|
|
||||||
prepared_path: Path = Path(tmp_dir) / "prepared"
|
prepared_path: Path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -258,15 +264,15 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 1
|
assert len(dataset) == 1
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
|
||||||
|
@enable_hf_offline
|
||||||
def test_load_hub_with_dpo(self):
|
def test_load_hub_with_dpo(self):
|
||||||
"""Verify that processing dpo data from the hub works"""
|
"""Verify that processing dpo data from the hub works"""
|
||||||
|
|
||||||
@@ -285,7 +291,9 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
assert len(train_dataset) == 1800
|
assert len(train_dataset) == 1800
|
||||||
assert "conversation" in train_dataset.features
|
assert "conversation" in train_dataset.features
|
||||||
|
|
||||||
def test_load_hub_with_revision(self):
|
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||||
|
@enable_hf_offline
|
||||||
|
def test_load_hub_with_revision(self, tokenizer):
|
||||||
"""Verify that processing data from the hub works with a specific revision"""
|
"""Verify that processing data from the hub works with a specific revision"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
@@ -307,16 +315,17 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
assert len(dataset) == 2000
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
def test_load_hub_with_revision_with_dpo(self):
|
@enable_hf_offline
|
||||||
|
def test_load_hub_with_revision_with_dpo(
|
||||||
|
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
|
||||||
|
):
|
||||||
"""Verify that processing dpo data from the hub works with a specific revision"""
|
"""Verify that processing dpo data from the hub works with a specific revision"""
|
||||||
|
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -329,22 +338,34 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
# pylint: disable=duplicate-code
|
||||||
|
with patch(
|
||||||
|
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||||
|
) as mock_load_dataset:
|
||||||
|
# Set up the mock to return different values on successive calls
|
||||||
|
mock_load_dataset.return_value = (
|
||||||
|
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
|
||||||
|
)
|
||||||
|
|
||||||
assert len(train_dataset) == 1800
|
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||||
assert "conversation" in train_dataset.features
|
|
||||||
|
|
||||||
def test_load_local_hub_with_revision(self):
|
assert len(train_dataset) == 1800
|
||||||
|
assert "conversation" in train_dataset.features
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
|
@pytest.mark.skip("datasets bug with local datasets when offline")
|
||||||
|
def test_load_local_hub_with_revision(self, tokenizer):
|
||||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||||
snapshot_download_w_retry(
|
snapshot_path = snapshot_download(
|
||||||
repo_id="mhenrichsen/alpaca_2k_test",
|
repo_id="mhenrichsen/alpaca_2k_test",
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
local_dir=tmp_ds_path,
|
local_dir=tmp_ds_path,
|
||||||
revision="d05c1cb",
|
revision="d05c1cb",
|
||||||
)
|
)
|
||||||
|
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
|
||||||
|
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -365,9 +386,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
assert len(dataset) == 2000
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
@@ -375,17 +394,19 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
shutil.rmtree(tmp_ds_path)
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
def test_loading_local_dataset_folder(self):
|
@enable_hf_offline
|
||||||
|
def test_loading_local_dataset_folder(self, tokenizer):
|
||||||
"""Verify that a dataset downloaded to a local folder can be loaded"""
|
"""Verify that a dataset downloaded to a local folder can be loaded"""
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||||
snapshot_download_w_retry(
|
snapshot_path = snapshot_download(
|
||||||
repo_id="mhenrichsen/alpaca_2k_test",
|
repo_id="mhenrichsen/alpaca_2k_test",
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
local_dir=tmp_ds_path,
|
local_dir=tmp_ds_path,
|
||||||
)
|
)
|
||||||
|
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
|
||||||
|
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -401,16 +422,10 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
assert len(dataset) == 2000
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
shutil.rmtree(tmp_ds_path)
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
|
|||||||
@@ -8,9 +8,8 @@ import hashlib
|
|||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
|
import pytest
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
@@ -19,6 +18,9 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_processor, load_tokenizer
|
from axolotl.utils.models import load_processor, load_tokenizer
|
||||||
|
|
||||||
|
from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
|
def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
|
||||||
"""
|
"""
|
||||||
@@ -214,13 +216,12 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
|||||||
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
|
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
|
||||||
|
|
||||||
|
|
||||||
class TestDeduplicateRLDataset(unittest.TestCase):
|
class TestDeduplicateRLDataset:
|
||||||
"""Test a configured dataloader with deduplication."""
|
"""Test a configured dataloader with deduplication."""
|
||||||
|
|
||||||
def setUp(self) -> None:
|
@pytest.fixture
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
def cfg(self):
|
||||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
fixture = DictDefault(
|
||||||
self.cfg = DictDefault(
|
|
||||||
{
|
{
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
@@ -233,34 +234,66 @@ class TestDeduplicateRLDataset(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
yield fixture
|
||||||
|
|
||||||
def test_load_with_deduplication(self):
|
@enable_hf_offline
|
||||||
|
def test_load_with_deduplication(
|
||||||
|
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
|
||||||
|
):
|
||||||
"""Verify that loading with deduplication removes duplicates."""
|
"""Verify that loading with deduplication removes duplicates."""
|
||||||
|
|
||||||
# Load the dataset using the deduplication setting
|
# pylint: disable=duplicate-code
|
||||||
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
|
with (
|
||||||
|
patch(
|
||||||
|
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||||
|
) as mock_load_dataset,
|
||||||
|
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||||
|
):
|
||||||
|
# Set up the mock to return different values on successive calls
|
||||||
|
mock_load_dataset.side_effect = [
|
||||||
|
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||||
|
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||||
|
]
|
||||||
|
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||||
|
|
||||||
# Verify that the dataset has been deduplicated
|
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||||
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
|
|
||||||
|
|
||||||
def test_load_without_deduplication(self):
|
# Verify that the dataset has been deduplicated
|
||||||
"""Verify that loading without deduplication retains duplicates."""
|
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
|
||||||
self.cfg.dataset_exact_deduplication = False
|
|
||||||
# Load the dataset without deduplication
|
|
||||||
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
|
|
||||||
|
|
||||||
# Verify that the dataset retains duplicates
|
@enable_hf_offline
|
||||||
assert (
|
def test_load_without_deduplication(
|
||||||
len(train_dataset) == 1800 * 2
|
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
|
||||||
), "Dataset deduplication occurred when it should not have"
|
):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||||
|
) as mock_load_dataset,
|
||||||
|
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||||
|
):
|
||||||
|
# Set up the mock to return different values on successive calls
|
||||||
|
mock_load_dataset.side_effect = [
|
||||||
|
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||||
|
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||||
|
]
|
||||||
|
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||||
|
|
||||||
|
cfg.dataset_exact_deduplication = False
|
||||||
|
# Load the dataset without deduplication
|
||||||
|
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||||
|
|
||||||
|
# Verify that the dataset retains duplicates
|
||||||
|
assert (
|
||||||
|
len(train_dataset) == 1800 * 2
|
||||||
|
), "Dataset deduplication occurred when it should not have"
|
||||||
|
|
||||||
|
|
||||||
class TestDeduplicateNonRL(unittest.TestCase):
|
class TestDeduplicateNonRL(unittest.TestCase):
|
||||||
"""Test prepare_dataset function with different configurations."""
|
"""Test prepare_dataset function with different configurations."""
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
|
||||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
|
||||||
self.cfg_1 = DictDefault(
|
self.cfg_1 = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "huggyllama/llama-7b",
|
"base_model": "huggyllama/llama-7b",
|
||||||
@@ -286,6 +319,8 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
normalize_config(self.cfg_1)
|
normalize_config(self.cfg_1)
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||||
|
@enable_hf_offline
|
||||||
def test_prepare_dataset_with_deduplication_train(self):
|
def test_prepare_dataset_with_deduplication_train(self):
|
||||||
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
|
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
|
||||||
self.cfg_1.dataset_exact_deduplication = True
|
self.cfg_1.dataset_exact_deduplication = True
|
||||||
@@ -311,6 +346,8 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
"Train dataset should have 2000 samples after deduplication.",
|
"Train dataset should have 2000 samples after deduplication.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||||
|
@enable_hf_offline
|
||||||
def test_prepare_dataset_with_deduplication_eval(self):
|
def test_prepare_dataset_with_deduplication_eval(self):
|
||||||
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
|
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
|
||||||
self.cfg_1.dataset_exact_deduplication = True
|
self.cfg_1.dataset_exact_deduplication = True
|
||||||
@@ -336,6 +373,8 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
"Eval dataset should have 2000 samples after deduplication.",
|
"Eval dataset should have 2000 samples after deduplication.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||||
|
@enable_hf_offline
|
||||||
def test_prepare_dataset_without_deduplication(self):
|
def test_prepare_dataset_without_deduplication(self):
|
||||||
"""Verify that prepare_dataset function processes the dataset correctly without deduplication."""
|
"""Verify that prepare_dataset function processes the dataset correctly without deduplication."""
|
||||||
self.cfg_1.dataset_exact_deduplication = False
|
self.cfg_1.dataset_exact_deduplication = False
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ from axolotl.utils.data.utils import drop_long_seq_in_dataset
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="tokenizer")
|
@pytest.fixture(name="tokenizer")
|
||||||
def fixture_tokenizer():
|
def fixture_tokenizer():
|
||||||
@@ -25,6 +27,7 @@ class TestBatchedSamplerPacking:
|
|||||||
Test class for packing streaming dataset sequences
|
Test class for packing streaming dataset sequences
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"batch_size, num_workers",
|
"batch_size, num_workers",
|
||||||
[
|
[
|
||||||
@@ -35,11 +38,12 @@ class TestBatchedSamplerPacking:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("max_seq_length", [4096, 512])
|
@pytest.mark.parametrize("max_seq_length", [4096, 512])
|
||||||
|
@enable_hf_offline
|
||||||
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
||||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||||
|
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
"Trelis/tiny-shakespeare",
|
"winglian/tiny-shakespeare",
|
||||||
split="train",
|
split="train",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -10,12 +10,15 @@ from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
|||||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
||||||
from axolotl.prompters import AlpacaPrompter
|
from axolotl.prompters import AlpacaPrompter
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
class TestPacking(unittest.TestCase):
|
class TestPacking(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Test class for packing dataset sequences
|
Test class for packing dataset sequences
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||||
|
|||||||
@@ -1,43 +1,60 @@
|
|||||||
"""Module for testing streaming dataset sequence packing"""
|
"""Module for testing streaming dataset sequence packing"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import unittest
|
import random
|
||||||
|
import string
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import IterableDataset
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
|
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
class TestPretrainingPacking(unittest.TestCase):
|
class TestPretrainingPacking:
|
||||||
"""
|
"""
|
||||||
Test class for packing streaming dataset sequences
|
Test class for packing streaming dataset sequences
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def setUp(self) -> None:
|
@pytest.fixture
|
||||||
# pylint: disable=duplicate-code
|
def random_text(self):
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
# seed with random.seed(0) for reproducibility
|
||||||
self.tokenizer.pad_token = "</s>"
|
random.seed(0)
|
||||||
|
|
||||||
@pytest.mark.flaky(retries=3, delay=5)
|
# generate row of random text with "words" of between 2 and 10 characters and
|
||||||
def test_packing_stream_dataset(self):
|
# between 400 to 1200 characters per line
|
||||||
# pylint: disable=duplicate-code
|
def rand_txt():
|
||||||
dataset = load_dataset(
|
return " ".join(
|
||||||
"allenai/c4",
|
[
|
||||||
"en",
|
"".join(
|
||||||
streaming=True,
|
random.choices(string.ascii_lowercase, k=random.randint(2, 10))
|
||||||
)["train"]
|
)
|
||||||
|
for _ in range(random.randint(50, 200))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a list of 2000 random texts rather than just using it within the
|
||||||
|
# generator so the test runs faster
|
||||||
|
data = [rand_txt() for _ in range(500)]
|
||||||
|
|
||||||
|
# Create an IterableDataset
|
||||||
|
def generator():
|
||||||
|
for row in data:
|
||||||
|
yield {"text": row}
|
||||||
|
|
||||||
|
return IterableDataset.from_generator(generator)
|
||||||
|
|
||||||
|
@pytest.mark.flaky(retries=1, delay=5)
|
||||||
|
def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
|
||||||
|
dataset = random_text
|
||||||
|
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"pretraining_dataset": [
|
"pretraining_dataset": [
|
||||||
{
|
{
|
||||||
"path": "allenai/c4",
|
"path": "winglian/tiny-shakespeare",
|
||||||
"name": "en",
|
|
||||||
"type": "pretrain",
|
"type": "pretrain",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -54,15 +71,16 @@ class TestPretrainingPacking(unittest.TestCase):
|
|||||||
ds_wrapper_partial = functools.partial(
|
ds_wrapper_partial = functools.partial(
|
||||||
get_dataset_wrapper,
|
get_dataset_wrapper,
|
||||||
cfg.pretraining_dataset[0],
|
cfg.pretraining_dataset[0],
|
||||||
self.tokenizer,
|
tokenizer_huggyllama,
|
||||||
cfg,
|
cfg,
|
||||||
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
original_bsz = cfg.micro_batch_size
|
original_bsz = cfg.micro_batch_size
|
||||||
train_dataset = wrap_pretraining_dataset(
|
train_dataset = wrap_pretraining_dataset(
|
||||||
dataset,
|
dataset,
|
||||||
self.tokenizer,
|
tokenizer_huggyllama,
|
||||||
cfg,
|
cfg,
|
||||||
ds_wrapper_partial,
|
ds_wrapper_partial,
|
||||||
max_tokens=cfg.sequence_len,
|
max_tokens=cfg.sequence_len,
|
||||||
@@ -78,7 +96,7 @@ class TestPretrainingPacking(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
idx = 0
|
idx = 0
|
||||||
for data in trainer_loader:
|
for data in trainer_loader:
|
||||||
if idx > 10:
|
if idx > 3:
|
||||||
break
|
break
|
||||||
assert data["input_ids"].shape == torch.Size(
|
assert data["input_ids"].shape == torch.Size(
|
||||||
[1, original_bsz * cfg.sequence_len]
|
[1, original_bsz * cfg.sequence_len]
|
||||||
@@ -95,7 +113,3 @@ class TestPretrainingPacking(unittest.TestCase):
|
|||||||
# [1, original_bsz * cfg.sequence_len]
|
# [1, original_bsz * cfg.sequence_len]
|
||||||
# )
|
# )
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import logging
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
|
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
|
||||||
|
|
||||||
@@ -22,6 +23,8 @@ from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
|||||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
test_data = {
|
test_data = {
|
||||||
@@ -63,6 +66,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
Test class for prompt tokenization strategies.
|
Test class for prompt tokenization strategies.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||||
@@ -119,6 +123,7 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
|||||||
Test class for prompt tokenization strategies with sys prompt from the dataset
|
Test class for prompt tokenization strategies with sys prompt from the dataset
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||||
@@ -160,6 +165,7 @@ class Llama2ChatTokenizationTest(unittest.TestCase):
|
|||||||
Test class for prompt tokenization strategies with sys prompt from the dataset
|
Test class for prompt tokenization strategies with sys prompt from the dataset
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
||||||
@@ -238,6 +244,7 @@ If a question does not make any sense, or is not factually coherent, explain why
|
|||||||
class OrpoTokenizationTest(unittest.TestCase):
|
class OrpoTokenizationTest(unittest.TestCase):
|
||||||
"""test case for the ORPO tokenization"""
|
"""test case for the ORPO tokenization"""
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(
|
tokenizer = LlamaTokenizer.from_pretrained(
|
||||||
@@ -262,6 +269,7 @@ class OrpoTokenizationTest(unittest.TestCase):
|
|||||||
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
|
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
|
||||||
).select([0])
|
).select([0])
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||||
def test_orpo_integration(self):
|
def test_orpo_integration(self):
|
||||||
strat = load(
|
strat = load(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
|
|||||||
@@ -9,12 +9,15 @@ import pytest
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
class TestTokenizers:
|
class TestTokenizers:
|
||||||
"""
|
"""
|
||||||
test class for the load_tokenizer fn
|
test class for the load_tokenizer fn
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def test_default_use_fast(self):
|
def test_default_use_fast(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -24,6 +27,7 @@ class TestTokenizers:
|
|||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
assert "Fast" in tokenizer.__class__.__name__
|
assert "Fast" in tokenizer.__class__.__name__
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def test_dont_use_fast(self):
|
def test_dont_use_fast(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -34,6 +38,7 @@ class TestTokenizers:
|
|||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
assert "Fast" not in tokenizer.__class__.__name__
|
assert "Fast" not in tokenizer.__class__.__name__
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def test_special_tokens_modules_to_save(self):
|
def test_special_tokens_modules_to_save(self):
|
||||||
# setting special_tokens to new token
|
# setting special_tokens to new token
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -68,6 +73,7 @@ class TestTokenizers:
|
|||||||
)
|
)
|
||||||
load_tokenizer(cfg)
|
load_tokenizer(cfg)
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def test_add_additional_special_tokens(self):
|
def test_add_additional_special_tokens(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -83,6 +89,7 @@ class TestTokenizers:
|
|||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
assert len(tokenizer) == 32001
|
assert len(tokenizer) == 32001
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def test_added_tokens_overrides(self, temp_dir):
|
def test_added_tokens_overrides(self, temp_dir):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -104,11 +111,12 @@ class TestTokenizers:
|
|||||||
128042
|
128042
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
|
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
# use with tokenizer that has reserved_tokens in added_tokens
|
# use with tokenizer that has reserved_tokens in added_tokens
|
||||||
"tokenizer_config": "NousResearch/Llama-3.2-1B",
|
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
|
"added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
}
|
}
|
||||||
|
|||||||
0
tests/utils/__init__.py
Normal file
0
tests/utils/__init__.py
Normal file
Reference in New Issue
Block a user