Compare commits

..

3 Commits

Author SHA1 Message Date
NanoCode012
87e0fd6b52 feat: add glm 4.7 flash 2026-02-10 18:57:20 +07:00
NanoCode012
2d44432e6c chore: update trinity docs 2026-02-04 18:10:33 +07:00
NanoCode012
57377814e9 feat: update cce for afmoe 2026-02-04 18:00:23 +07:00
14 changed files with 117 additions and 289 deletions

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e39ca1d\""
]
},
{

View File

@@ -0,0 +1,40 @@
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model.
This guide shows how to fine-tune it with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
3. Run the finetuning example:
```bash
axolotl train examples/glm4.7-flash/glm4.7-flash-qlora.yaml
```
This config uses about X GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official Z.ai team recommends `top_p: 0.95`, `temperature: 1.0`, and `max_new_tokens: 131072`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,63 @@
base_model: zai-org/GLM-4.7-Flash
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project: glm-4.7-flash
wandb_entity:
wandb_watch:
wandb_name: qlora
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -8,13 +8,15 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Run the finetuning example:
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
```
This config uses about 24.9 GiB VRAM.
This config uses about 24.9 GiB VRAM (w/o CCE).
Let us know how it goes. Happy finetuning! 🚀
@@ -29,10 +31,6 @@ Let us know how it goes. Happy finetuning! 🚀
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
## Related Resources
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)

View File

@@ -1,13 +1,11 @@
base_model: arcee-ai/Trinity-Nano-Preview
trust_remote_code: true
revision_of_model: 2ee94b0
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# CCE - N/A as of now
# plugins:
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e39ca1d"'
)

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e39ca1d"
```
## Usage
@@ -31,6 +31,7 @@ plugins:
## Supported Models
- afmoe
- apertus
- arcee
- cohere

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e39ca1d"`'
)

View File

@@ -39,10 +39,7 @@ class KDPlugin(BasePlugin):
def get_trainer_cls(self, cfg):
if cfg.kd_trainer:
from .trainer import AxolotlKDTrainer, AxolotlOnlineKDTrainer
if cfg.kd_online_server_base_url:
return AxolotlOnlineKDTrainer
from .trainer import AxolotlKDTrainer
return AxolotlKDTrainer
return None

View File

@@ -53,9 +53,7 @@ class KDArgs(BaseModel):
kd_online_server: InferenceServerType | None = Field(
default_factory=lambda: InferenceServerType.vllm
)
kd_online_server_model: str | None = None
kd_online_timeout: int | None = 120
kd_online_max_new_tokens: int | None = 2048
kd_temperature_min: float | None = (
None # kd temperature scheduling during online kd
)
@@ -76,4 +74,3 @@ class KDTrainingArgsMixin:
kd_normalize_topk: float | None = (
None # whether to normalize student logits during KD
)
kd_online_max_new_tokens: int | None = None

View File

@@ -1,47 +0,0 @@
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.logging import get_logger
# Configure the logger
LOG = get_logger(__name__)
LOG.setLevel("INFO")
class ChatTemplateStrategyWithOnlineKD(ChatTemplateStrategy):
@property
def supports_batched(self) -> bool:
# batching doesn't work well for logprob data
return False
def _get_messages(self, prompt):
input_prompt = prompt.get("problem")
return [
{"role": "user", "content": input_prompt},
]
def _tokenize_single_prompt(self, prompt):
turns = self.get_conversation_thread(prompt)
tools = self._get_tools(prompt)
input_ids = self.prompter.build_prompt(
turns, tools=tools, add_generation_prompt=True
) # type: ignore
labels = [IGNORE_TOKEN_ID] * len(input_ids)
return {
"input_ids": input_ids,
"prompts": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
}
class OnlineKDStrategyLoader(StrategyLoader):
"""
Load ChatTemplateStrategy with KD support using StrategyLoader.
"""
def _get_strategy_cls(self, cfg):
return ChatTemplateStrategyWithOnlineKD
load = OnlineKDStrategyLoader()

View File

@@ -16,14 +16,6 @@
KD trainer
"""
import os
from typing import Any, Optional, Union
import requests
import torch
from torch import nn
from transformers import GenerationConfig
from trl.models import unwrap_model_for_generation
from typing_extensions import override
from axolotl.core.trainers.base import AxolotlTrainer
@@ -109,214 +101,3 @@ class AxolotlKDTrainer(AxolotlTrainer):
loss = outputs.loss if hasattr(outputs, "loss") else outputs
return (loss, outputs) if return_outputs else loss
class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.generation_config = GenerationConfig(
max_new_tokens=kwargs.get("kd_online_max_new_tokens"),
temperature=1.0,
do_sample=True,
top_k=0,
use_cache=False if kwargs.get("gradient_checkpointing") else True,
pad_token_id=self.processing_class.pad_token_id,
)
# Set custom EOS tokens if they are specified by the model's generation
# config. This is important for models with the Llama 3 chat template,
# which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
# turns or messages.
if (
hasattr(self.model.generation_config, "eos_token_id")
and self.model.generation_config.eos_token_id is not None
):
self.generation_config.eos_token_id = (
self.model.generation_config.eos_token_id
)
@staticmethod
def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
# Generate output with respect to the prompt-only
generated_outputs = model.generate(
input_ids=inputs["prompts"],
attention_mask=inputs.get("prompt_attention_mask", None),
generation_config=generation_config,
return_dict_in_generate=True,
)
# Get the generated token IDs
generated_tokens = generated_outputs.sequences
# Calculate new attention mask
new_attention_mask = torch.ones_like(generated_tokens)
new_labels = generated_tokens.clone()
# If there's pad_token_id, set attention mask to 0 for padding tokens
if pad_token_id is not None:
new_labels[new_labels == pad_token_id] = -100
new_attention_mask[generated_tokens == pad_token_id] = 0
return generated_tokens, new_attention_mask, new_labels
def training_step(
self,
model: nn.Module,
inputs: dict[str, Union[torch.Tensor, Any]],
num_items_in_batch: Optional[int] = None,
) -> torch.Tensor:
"""
Perform a training step for the Generalized Knowledge Distillation (GKD) model.
This method implements the on-policy learning approach described in the GKD paper. With probability
`self.lmbda`, it generates new responses using the student model, which are then used for training instead of
the original inputs.
"""
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
new_input_ids, new_attention_mask, new_labels = (
self.generate_on_policy_outputs(
unwrapped_model,
inputs,
self.generation_config,
self.processing_class.pad_token_id,
)
)
inputs["input_ids"] = new_input_ids
inputs["attention_mask"] = new_attention_mask
inputs["labels"] = new_labels
target_token_ids, target_logprobs, target_mask = self.get_teacher_logprobs(
inputs["input_ids"], inputs["labels"]
)
inputs["target_token_ids"] = target_token_ids
inputs["target_logprobs"] = target_logprobs
inputs["target_mask"] = target_mask
loss = super().training_step(model, inputs, num_items_in_batch)
return loss
def get_teacher_logprobs(self, input_ids, labels):
request_body = {
"model": self.axolotl_cfg.kd_online_server_model,
"prompt": input_ids,
"logprobs": self.axolotl_cfg.kd_online_topk,
"echo": True,
"skip_special_tokens": False,
"n": 1,
"max_tokens": 0,
"temperature": 1.0,
}
base_url = self.args.kd_online_server_base_url
api_url = f"{base_url}/v1/completions"
bearer_token = os.getenv("OPENAI_API_KEY")
headers = {"Authorization": f"Bearer {bearer_token}"}
response = requests.post(
api_url, json=request_body, headers=headers, timeout=30
)
prompt_logprobs = response.choices[0].logprobs.top_logprobs[
1:
] # prune first null position
return self.transform_logprobs(input_ids, labels, prompt_logprobs)
def transform_logprobs(self, input_ids, labels, logprobs):
"""
Transform logprobs to target format for KD training
"""
target_seq_len = len(logprobs)
input_seq_len = len(input_ids)
input_padding_len = input_seq_len - target_seq_len
# get non-zero top-k (prune None logprobs from vllm data step)
top_k_vals = [
len(logprobs[i])
for i in range(len(logprobs))
if logprobs[i] is not None and len(logprobs[i])
]
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
top_k = min(max_top_k, min_top_k)
if top_k == 0:
raise ValueError("No non-zero top-k logprobs found.")
target_logprobs = []
target_token_ids = []
target_mask = []
if input_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
# target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k
logprobs = [row[:top_k] for row in logprobs]
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
# we shift for causal models in the trainer, so start the range from 0
for _ in range(0, input_padding_len):
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for position in range(input_padding_len, input_seq_len):
if labels[position] == -100:
target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids
position_logprobs = []
position_token_ids = []
# Process each token probability entry
for entry in token_pos_logprobs:
# Extract logprob value
logprob = entry["logprob"]
# Parse token_id from the "token_id:###" format
token_id = int(entry["token"].split(":")[1])
# Append to our collections
position_logprobs.append(logprob)
position_token_ids.append(token_id)
# Convert to a tensor for easier manipulation
position_logprobs_tensor = torch.tensor(
position_logprobs, dtype=torch.float
)
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# normalize probabilities to sum to 1 in case they aren't already
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
if teacher_probs_t1_sum > 1e-9:
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
if self.kd_temperature != self.gen_temperature:
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent
else:
teacher_probs_t2 = teacher_probs_t1
# Re-normalize
# teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
# dim=0, keepdim=True
# )
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
position_logprobs_scaled = position_logprobs_tensor.tolist()
target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids)
# Update sample with transformed logprobs
return target_token_ids, target_logprobs, target_mask

View File

@@ -320,7 +320,7 @@ class PatchManager:
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is not None:
if has_remote_code and self.cfg.trust_remote_code is False:
# If explicitly set in YAML, prefer that
has_remote_code = self.cfg.trust_remote_code

View File

@@ -179,7 +179,7 @@ def check_tensorboard(
tag: str,
lt_val: float,
assertion_err: str,
rtol: float = 0.05,
rtol: float = 0.02,
) -> None:
"""
helper function to parse and check tensorboard logs