Compare commits
online-top...feat/glmfl
3 commits:

| Author | SHA1 | Date |
|---|---|---|
| | 87e0fd6b52 | |
| | 2d44432e6c | |
| | 57377814e9 | |
@@ -40,7 +40,7 @@
 "%%capture\n",
 "# This step can take ~5-10 minutes to install dependencies\n",
 "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712\""
+"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e39ca1d\""
 ]
 },
 {
examples/glm4.7-flash/README.md (new file, 40 lines)
@@ -0,0 +1,40 @@
# Finetune Z.ai's GLM-4.7-Flash with Axolotl

[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model.

This guide shows how to fine-tune it with Axolotl.

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.

3. Run the finetuning example:

```bash
axolotl train examples/glm4.7-flash/glm4.7-flash-qlora.yaml
```

This config uses about X GiB VRAM.

Let us know how it goes. Happy finetuning! 🚀

### Tips

- For inference, the official Z.ai team recommends `top_p: 0.95`, `temperature: 1.0`, and `max_new_tokens: 131072`.
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).

## Optimization Guides

Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).

## Related Resources

- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
examples/glm4.7-flash/glm4.7-flash-qlora.yaml (new file, 63 lines)
@@ -0,0 +1,63 @@
base_model: zai-org/GLM-4.7-Flash

# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_4bit: true

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project: glm-4.7-flash
wandb_entity:
wandb_watch:
wandb_name: qlora
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

# save_first_step: true  # uncomment this to validate checkpoint saving works with your config
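For reference, the batch settings in the config (`gradient_accumulation_steps: 4`, `micro_batch_size: 2`) combine into an effective batch size; the sketch below assumes a hypothetical single-GPU run (`world_size` is not part of the config):

```python
# Effective batch size implied by the config's batch settings.
gradient_accumulation_steps = 4
micro_batch_size = 2
world_size = 1  # assumption: single GPU; scale up for multi-GPU runs

effective_batch_size = gradient_accumulation_steps * micro_batch_size * world_size
print(effective_batch_size)  # → 8
```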
@@ -8,13 +8,15 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
 1. Install Axolotl (main branch) following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).

-2. Run the finetuning example:
+2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
+
+3. Run the finetuning example:

 ```bash
 axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
 ```

-This config uses about 24.9 GiB VRAM.
+This config uses about 24.9 GiB VRAM (w/o CCE).

 Let us know how it goes. Happy finetuning! 🚀
@@ -29,10 +31,6 @@ Let us know how it goes. Happy finetuning! 🚀

 Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).

-## Limitations
-
-**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
-
 ## Related Resources

 - [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
@@ -1,13 +1,11 @@
 base_model: arcee-ai/Trinity-Nano-Preview
 trust_remote_code: true
 revision_of_model: 2ee94b0

 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name

-# CCE - N/A as of now
-# plugins:
-#   - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

 load_in_8bit: false
 load_in_4bit: true
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""

 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e39ca1d"'
 )
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh

 - If you are installing from pip:

 ```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e39ca1d"
 ```

 ## Usage
@@ -31,6 +31,7 @@ plugins:

 ## Supported Models

+- afmoe
 - apertus
 - arcee
 - cohere
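As the docs' Usage section referenced in the hunk header shows, the plugin is enabled via the `plugins` list in an Axolotl config (the same snippet that appears in the example configs in this PR):

```yaml
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
```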
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)

 _CCE_INSTALL_MESSAGE = (
     "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
-    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"`'
+    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e39ca1d"`'
 )
@@ -39,10 +39,7 @@ class KDPlugin(BasePlugin):

     def get_trainer_cls(self, cfg):
         if cfg.kd_trainer:
-            from .trainer import AxolotlKDTrainer, AxolotlOnlineKDTrainer
-
-            if cfg.kd_online_server_base_url:
-                return AxolotlOnlineKDTrainer
+            from .trainer import AxolotlKDTrainer

             return AxolotlKDTrainer
         return None
@@ -53,9 +53,7 @@ class KDArgs(BaseModel):
     kd_online_server: InferenceServerType | None = Field(
         default_factory=lambda: InferenceServerType.vllm
     )
     kd_online_server_model: str | None = None
-    kd_online_timeout: int | None = 120
-    kd_online_max_new_tokens: int | None = 2048
     kd_temperature_min: float | None = (
         None  # kd temperature scheduling during online kd
     )
@@ -76,4 +74,3 @@ class KDTrainingArgsMixin:
     kd_normalize_topk: float | None = (
         None  # whether to normalize student logits during KD
     )
-    kd_online_max_new_tokens: int | None = None
@@ -1,47 +0,0 @@
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.logging import get_logger

# Configure the logger
LOG = get_logger(__name__)
LOG.setLevel("INFO")


class ChatTemplateStrategyWithOnlineKD(ChatTemplateStrategy):
    @property
    def supports_batched(self) -> bool:
        # batching doesn't work well for logprob data
        return False

    def _get_messages(self, prompt):
        input_prompt = prompt.get("problem")
        return [
            {"role": "user", "content": input_prompt},
        ]

    def _tokenize_single_prompt(self, prompt):
        turns = self.get_conversation_thread(prompt)
        tools = self._get_tools(prompt)
        input_ids = self.prompter.build_prompt(
            turns, tools=tools, add_generation_prompt=True
        )  # type: ignore
        labels = [IGNORE_TOKEN_ID] * len(input_ids)

        return {
            "input_ids": input_ids,
            "prompts": input_ids,
            "labels": labels,
            "attention_mask": [1] * len(input_ids),
        }


class OnlineKDStrategyLoader(StrategyLoader):
    """
    Load ChatTemplateStrategy with KD support using StrategyLoader.
    """

    def _get_strategy_cls(self, cfg):
        return ChatTemplateStrategyWithOnlineKD


load = OnlineKDStrategyLoader()
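The prompt-only tokenization in the removed strategy can be sketched in isolation; `IGNORE_TOKEN_ID` is `-100` in `axolotl.prompters`, and `prompt_only_sample` is a hypothetical stand-in for `_tokenize_single_prompt` once the prompt is already tokenized:

```python
IGNORE_TOKEN_ID = -100  # value of axolotl.prompters.IGNORE_TOKEN_ID

def prompt_only_sample(input_ids):
    # Mirrors _tokenize_single_prompt above: labels are fully masked because
    # the completion is generated later by the online KD trainer, not taken
    # from the dataset; "prompts" keeps the raw prompt for generation.
    return {
        "input_ids": input_ids,
        "prompts": input_ids,
        "labels": [IGNORE_TOKEN_ID] * len(input_ids),
        "attention_mask": [1] * len(input_ids),
    }

sample = prompt_only_sample([101, 2023, 102])
print(sample["labels"])  # → [-100, -100, -100]
```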
@@ -16,14 +16,6 @@
 KD trainer
 """

-import os
 from typing import Any, Optional, Union

-import requests
 import torch
 from torch import nn
-from transformers import GenerationConfig
-from trl.models import unwrap_model_for_generation
 from typing_extensions import override

 from axolotl.core.trainers.base import AxolotlTrainer
@@ -109,214 +101,3 @@ class AxolotlKDTrainer(AxolotlTrainer):
        loss = outputs.loss if hasattr(outputs, "loss") else outputs

        return (loss, outputs) if return_outputs else loss


class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.generation_config = GenerationConfig(
            max_new_tokens=kwargs.get("kd_online_max_new_tokens"),
            temperature=1.0,
            do_sample=True,
            top_k=0,
            use_cache=False if kwargs.get("gradient_checkpointing") else True,
            pad_token_id=self.processing_class.pad_token_id,
        )
        # Set custom EOS tokens if they are specified by the model's generation
        # config. This is important for models with the Llama 3 chat template,
        # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
        # turns or messages.
        if (
            hasattr(self.model.generation_config, "eos_token_id")
            and self.model.generation_config.eos_token_id is not None
        ):
            self.generation_config.eos_token_id = (
                self.model.generation_config.eos_token_id
            )

    @staticmethod
    def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
        # Generate output with respect to the prompt-only
        generated_outputs = model.generate(
            input_ids=inputs["prompts"],
            attention_mask=inputs.get("prompt_attention_mask", None),
            generation_config=generation_config,
            return_dict_in_generate=True,
        )

        # Get the generated token IDs
        generated_tokens = generated_outputs.sequences
        # Calculate new attention mask
        new_attention_mask = torch.ones_like(generated_tokens)
        new_labels = generated_tokens.clone()

        # If there's pad_token_id, set attention mask to 0 for padding tokens
        if pad_token_id is not None:
            new_labels[new_labels == pad_token_id] = -100
            new_attention_mask[generated_tokens == pad_token_id] = 0

        return generated_tokens, new_attention_mask, new_labels

    def training_step(
        self,
        model: nn.Module,
        inputs: dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Perform a training step for the Generalized Knowledge Distillation (GKD) model.

        This method implements the on-policy learning approach described in the GKD
        paper. With probability `self.lmbda`, it generates new responses using the
        student model, which are then used for training instead of the original inputs.
        """
        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
            new_input_ids, new_attention_mask, new_labels = (
                self.generate_on_policy_outputs(
                    unwrapped_model,
                    inputs,
                    self.generation_config,
                    self.processing_class.pad_token_id,
                )
            )
        inputs["input_ids"] = new_input_ids
        inputs["attention_mask"] = new_attention_mask
        inputs["labels"] = new_labels

        target_token_ids, target_logprobs, target_mask = self.get_teacher_logprobs(
            inputs["input_ids"], inputs["labels"]
        )
        inputs["target_token_ids"] = target_token_ids
        inputs["target_logprobs"] = target_logprobs
        inputs["target_mask"] = target_mask

        loss = super().training_step(model, inputs, num_items_in_batch)
        return loss

    def get_teacher_logprobs(self, input_ids, labels):
        request_body = {
            "model": self.axolotl_cfg.kd_online_server_model,
            "prompt": input_ids,
            "logprobs": self.axolotl_cfg.kd_online_topk,
            "echo": True,
            "skip_special_tokens": False,
            "n": 1,
            "max_tokens": 0,
            "temperature": 1.0,
        }
        base_url = self.args.kd_online_server_base_url
        api_url = f"{base_url}/v1/completions"
        bearer_token = os.getenv("OPENAI_API_KEY")

        headers = {"Authorization": f"Bearer {bearer_token}"}
        response = requests.post(
            api_url, json=request_body, headers=headers, timeout=30
        )
        # requests returns a raw HTTP response; parse the JSON body before
        # indexing into the OpenAI-style completion payload
        data = response.json()
        prompt_logprobs = data["choices"][0]["logprobs"]["top_logprobs"][
            1:
        ]  # prune first null position
        return self.transform_logprobs(input_ids, labels, prompt_logprobs)

    def transform_logprobs(self, input_ids, labels, logprobs):
        """
        Transform logprobs to target format for KD training
        """

        target_seq_len = len(logprobs)
        input_seq_len = len(input_ids)
        input_padding_len = input_seq_len - target_seq_len
        # get non-zero top-k (prune None logprobs from vllm data step)
        top_k_vals = [
            len(logprobs[i])
            for i in range(len(logprobs))
            if logprobs[i] is not None and len(logprobs[i])
        ]
        max_top_k = max(set(top_k_vals), key=top_k_vals.count)
        min_top_k = min(set(top_k_vals), key=top_k_vals.count)
        top_k = min(max_top_k, min_top_k)
        if top_k == 0:
            raise ValueError("No non-zero top-k logprobs found.")

        target_logprobs = []
        target_token_ids = []
        target_mask = []

        if input_padding_len < 0:
            # logprobs is longer than input_seq_len, so keep only the last
            # input_seq_len positions (drop from the left/beginning)
            logprobs = logprobs[-input_seq_len:]
            input_padding_len = 0

        # truncate the second dimension of the logprobs to top_k
        logprobs = [row[:top_k] for row in logprobs]

        # fill with -inf for padding_len tokens for top_k tokens
        # extend target_logprobs with a padding_len x top_k 2D list filled with -inf
        # we shift for causal models in the trainer, so start the range from 0
        for _ in range(0, input_padding_len):
            target_logprobs.append([-float("inf")] * top_k)
            target_token_ids.append(list(range(top_k)))
            target_mask.append([0] * top_k)

        for position in range(input_padding_len, input_seq_len):
            if labels[position] == -100:
                target_mask.append([0] * top_k)
            else:
                target_mask.append([1] * top_k)

        for _, token_pos_logprobs in enumerate(logprobs):
            # Initialize collections for logprobs and token_ids
            position_logprobs = []
            position_token_ids = []

            # Process each token probability entry
            for entry in token_pos_logprobs:
                # Extract logprob value
                logprob = entry["logprob"]

                # Parse token_id from the "token_id:###" format
                token_id = int(entry["token"].split(":")[1])

                # Append to our collections
                position_logprobs.append(logprob)
                position_token_ids.append(token_id)

            # Convert to a tensor for easier manipulation
            position_logprobs_tensor = torch.tensor(
                position_logprobs, dtype=torch.float
            )

            # Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
            # Next, re-scale to T2 = self.kd_temperature via exponent-based trick
            #   p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
            #
            # Convert from log to probability
            teacher_probs_t1 = position_logprobs_tensor.exp()
            # normalize probabilities to sum to 1 in case they aren't already
            teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
            if teacher_probs_t1_sum > 1e-9:
                teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
            if self.kd_temperature != self.gen_temperature:
                # Exponentiate by factor (T1 / T2)
                exponent = self.gen_temperature / self.kd_temperature
                teacher_probs_t2 = teacher_probs_t1**exponent
            else:
                teacher_probs_t2 = teacher_probs_t1
            # Re-normalize
            # teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
            #     dim=0, keepdim=True
            # )
            # Convert back to log
            position_logprobs_tensor = torch.log(teacher_probs_t2)

            # Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
            position_logprobs_scaled = position_logprobs_tensor.tolist()

            target_logprobs.append(position_logprobs_scaled)
            target_token_ids.append(position_token_ids)

        # Update sample with transformed logprobs
        return target_token_ids, target_logprobs, target_mask
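The temperature-rescaling step in `transform_logprobs` (p_T2(k) = [p_T1(k)]^(T1/T2) / Z) can be sketched standalone with plain Python; unlike the removed code, this sketch re-normalizes after exponentiation (the commented-out Z step):

```python
import math

def rescale_logprobs(logprobs, gen_temperature, kd_temperature):
    """Re-scale top-k teacher logprobs sampled at T1 = gen_temperature to
    T2 = kd_temperature via p_T2(k) = [p_T1(k)]^(T1/T2) / Z."""
    probs = [math.exp(lp) for lp in logprobs]
    total = sum(probs)
    probs = [p / total for p in probs]  # normalize: top-k mass may be < 1
    if kd_temperature != gen_temperature:
        exponent = gen_temperature / kd_temperature
        probs = [p ** exponent for p in probs]
        total = sum(probs)
        probs = [p / total for p in probs]  # the Z re-normalization
    return [math.log(p) for p in probs]

out = rescale_logprobs([-0.1, -2.3, -4.6], gen_temperature=1.0, kd_temperature=2.0)
# a higher KD temperature flattens the distribution but preserves ordering
```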
|
||||
@@ -320,7 +320,7 @@ class PatchManager:
|
||||
else:
|
||||
has_remote_code = False
|
||||
|
||||
if has_remote_code and self.cfg.trust_remote_code is not None:
|
||||
if has_remote_code and self.cfg.trust_remote_code is False:
|
||||
# If explicitly set in YAML, prefer that
|
||||
has_remote_code = self.cfg.trust_remote_code
|
||||
|
||||
|
||||
@@ -179,7 +179,7 @@ def check_tensorboard(
     tag: str,
     lt_val: float,
     assertion_err: str,
-    rtol: float = 0.05,
+    rtol: float = 0.02,
 ) -> None:
     """
     helper function to parse and check tensorboard logs
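The `rtol` tightening above (0.05 to 0.02) narrows the relative tolerance on the less-than check. A hypothetical sketch of that comparison (assuming the helper accepts values up to `lt_val * (1 + rtol)`; the real implementation is not shown in this diff):

```python
def within_tolerance(val: float, lt_val: float, rtol: float = 0.02) -> bool:
    # Relaxed less-than: accept values up to rtol above the threshold.
    return val < lt_val * (1 + rtol)

print(within_tolerance(1.01, 1.0))  # → True  (within the 2% band)
print(within_tolerance(1.05, 1.0))  # → False (would have passed at rtol=0.05)
```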