EAFT (#3366) [skip ci]
* wip eaft * fix eaft loss fn * adding ref --------- Co-authored-by: Salman Mohammadi <“salman.mohammadi@outlook.com”>
This commit is contained in:
77
examples/eaft/eaft-example.yml
Normal file
77
examples/eaft/eaft-example.yml
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
base_model: google/gemma-3-1b-it
|
||||||
|
|
||||||
|
model_type: Gemma3ForCausalLM
|
||||||
|
cls_model_config: Gemma3TextConfig
|
||||||
|
|
||||||
|
# gemma3 doesn't seem to play nice with ddp
|
||||||
|
ddp_find_unused_parameters: true
|
||||||
|
|
||||||
|
chat_template: gemma3
|
||||||
|
eot_tokens:
|
||||||
|
- <end_of_turn>
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0
|
||||||
|
output_dir: ./outputs/eaft-gemma-3-1b
|
||||||
|
|
||||||
|
use_eaft: true
|
||||||
|
eaft_alpha: 1.0
|
||||||
|
eaft_k: 20
|
||||||
|
|
||||||
|
sequence_len: 1024
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
adapter:
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
eval_batch_size: 1
|
||||||
|
max_steps: 1000
|
||||||
|
evaluation_strategy: "no"
|
||||||
|
optimizer: adamw_torch_fused
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 5e-5
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
weight_decay: 0.0
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
@@ -373,6 +373,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
||||||
|
|
||||||
|
if self.cfg.use_eaft:
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.loss.eaft import eaft_loss
|
||||||
|
|
||||||
|
configured_eaft_loss = partial(
|
||||||
|
eaft_loss,
|
||||||
|
alpha=self.cfg.eaft_alpha if self.cfg.eaft_alpha is not None else 1.0,
|
||||||
|
k=self.cfg.eaft_k if self.cfg.eaft_k is not None else 20,
|
||||||
|
)
|
||||||
|
trainer_kwargs["compute_loss_func"] = configured_eaft_loss
|
||||||
|
|
||||||
trainer_cls = self._get_trainer_cls()
|
trainer_cls = self._get_trainer_cls()
|
||||||
|
|
||||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||||
|
|||||||
51
src/axolotl/monkeypatch/loss/eaft.py
Normal file
51
src/axolotl/monkeypatch/loss/eaft.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""
|
||||||
|
eaft (entropy-aware focal training) loss implementation
|
||||||
|
weights examples by entropy approximation from top-k logits
|
||||||
|
|
||||||
|
Reference: https://github.com/ymxyll/LlamaFactory-EAFT/blob/e2ce19e8efcc226450ee8f2b81dfe4e69f1f945d/src/llamafactory/train/trainer_utils.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def eaft_loss(outputs, labels, num_items_in_batch=None, alpha=1.0, k=20):
    """
    compute eaft loss with entropy weighting

    the cross-entropy of every supervised token is multiplied by
    ``entropy ** alpha``, where the entropy is approximated from a softmax
    over that position's top-k logits. the weights are computed under
    ``torch.no_grad()``, so gradients flow only through the cross-entropy term.

    args:
        outputs: model outputs containing logits (attribute access, or
            ``outputs["logits"]`` when a dict is passed)
        labels: target labels for computing loss; positions equal to -100
            are excluded
        num_items_in_batch: for sample packing support; when given, the summed
            weighted loss is divided by it instead of being averaged over the
            supervised tokens
        alpha: exponent for entropy weighting (default 1.0)
        k: number of top logits for entropy approximation (default 20),
            capped at the vocabulary size

    returns:
        scalar loss tensor; zero (never NaN) when no label is supervised
    """
    # Accept both attribute-style outputs and plain dict outputs.
    logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits

    # Next-token objective: position t predicts label t + 1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    vocab_size = shift_logits.size(-1)
    shift_labels_view = shift_labels.view(-1)
    mask = shift_labels_view != -100

    # Gather the supervised rows once; each boolean-index gather materialises
    # a (num_tokens, vocab) copy, so reuse it for both weights and loss.
    active_logits = shift_logits.view(-1, vocab_size)[mask]
    active_labels = shift_labels_view[mask]

    with torch.no_grad():
        top_k_logits, _ = torch.topk(
            active_logits.float(), k=min(k, vocab_size), dim=-1
        )
        top_k_probs = F.softmax(top_k_logits, dim=-1)
        # The epsilon keeps log() finite when a probability underflows to 0.
        entropy = -(top_k_probs * torch.log(top_k_probs + 1e-10)).sum(dim=-1)
        weights = torch.pow(entropy, alpha)

    per_token_loss = F.cross_entropy(active_logits, active_labels, reduction="none")
    weighted_loss = per_token_loss * weights

    if num_items_in_batch is not None:
        return weighted_loss.sum() / num_items_in_batch

    if weighted_loss.numel() == 0:
        # mean() of an empty tensor is NaN; a fully masked batch should
        # contribute a graph-connected zero instead.
        return weighted_loss.sum()

    return weighted_loss.mean()
|
||||||
@@ -676,6 +676,24 @@ class AxolotlInputConfig(
|
|||||||
"description": "Number of chunks to use for chunked cross entropy loss"
|
"description": "Number of chunks to use for chunked cross entropy loss"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
use_eaft: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Enable Entropy-Aware Focal Training loss (EAFT)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
eaft_alpha: float | None = Field(
|
||||||
|
default=1.0,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Exponent for entropy weighting in EAFT (default: 1.0)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
eaft_k: int | None = Field(
|
||||||
|
default=20,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of top logits for entropy approximation (default: 20)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
tiled_mlp: bool | None = Field(
|
tiled_mlp: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
Reference in New Issue
Block a user