diff --git a/examples/eaft/eaft-example.yml b/examples/eaft/eaft-example.yml
new file mode 100644
index 000000000..fed4179d2
--- /dev/null
+++ b/examples/eaft/eaft-example.yml
@@ -0,0 +1,77 @@
+base_model: google/gemma-3-1b-it
+
+model_type: Gemma3ForCausalLM
+cls_model_config: Gemma3TextConfig
+
+# gemma3 doesn't seem to play nice with ddp
+ddp_find_unused_parameters: true
+
+chat_template: gemma3
+eot_tokens:
+  - <end_of_turn>
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: cgato/SlimOrcaDedupCleaned
+    type: chat_template
+    field_messages: conversations
+    message_property_mappings:
+      role: from
+      content: value
+
+dataset_prepared_path:
+val_set_size: 0
+output_dir: ./outputs/eaft-gemma-3-1b
+
+use_eaft: true
+eaft_alpha: 1.0
+eaft_k: 20
+
+sequence_len: 1024
+sample_packing: false
+
+adapter:
+lora_model_dir:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+eval_batch_size: 1
+max_steps: 1000
+evaluation_strategy: "no"
+optimizer: adamw_torch_fused
+lr_scheduler: cosine
+learning_rate: 5e-5
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_ratio: 0.1
+weight_decay: 0.0
+debug:
+deepspeed:
+fsdp:
+fsdp_config:
+special_tokens:
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index 3a9f8ba1b..09bcff450 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -373,6 +373,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
             data_collator_kwargs["pad_to_multiple_of"] = multiple
 
+        if self.cfg.use_eaft:
+            from functools import partial
+
+            from axolotl.monkeypatch.loss.eaft import eaft_loss
+
+            configured_eaft_loss = partial(
+                eaft_loss,
+                alpha=self.cfg.eaft_alpha if self.cfg.eaft_alpha is not None else 1.0,
+                k=self.cfg.eaft_k if self.cfg.eaft_k is not None else 20,
+            )
+            trainer_kwargs["compute_loss_func"] = configured_eaft_loss
+
         trainer_cls = self._get_trainer_cls()
 
         trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
diff --git a/src/axolotl/monkeypatch/loss/eaft.py b/src/axolotl/monkeypatch/loss/eaft.py
new file mode 100644
index 000000000..150d4a005
--- /dev/null
+++ b/src/axolotl/monkeypatch/loss/eaft.py
@@ -0,0 +1,51 @@
+"""
+EAFT (Entropy-Aware Focal Training) loss implementation.
+Weights tokens by an entropy approximation computed from the top-k logits.
+
+Reference: https://github.com/ymxyll/LlamaFactory-EAFT/blob/e2ce19e8efcc226450ee8f2b81dfe4e69f1f945d/src/llamafactory/train/trainer_utils.py
+"""
+
+import torch
+import torch.nn.functional as F
+
+
+def eaft_loss(outputs, labels, num_items_in_batch=None, alpha=1.0, k=20):
+    """
+    Compute the EAFT loss with per-token entropy weighting.
+
+    Args:
+        outputs: model outputs containing logits
+        labels: target labels for computing loss
+        num_items_in_batch: token count for loss normalization across gradient accumulation steps
+        alpha: exponent for entropy weighting (default 1.0)
+        k: number of top logits for entropy approximation (default 20)
+    """
+    logits = outputs.logits
+
+    shift_logits = logits[..., :-1, :].contiguous()
+    shift_labels = labels[..., 1:].contiguous()
+
+    vocab_size = shift_logits.size(-1)
+    shift_logits_view = shift_logits.view(-1, vocab_size)
+    shift_labels_view = shift_labels.view(-1)
+
+    mask = shift_labels_view != -100
+
+    with torch.no_grad():
+        top_k_logits, _ = torch.topk(
+            shift_logits_view[mask].float(), k=min(k, vocab_size), dim=-1
+        )
+        top_k_probs = F.softmax(top_k_logits, dim=-1)
+        entropy = -(top_k_probs * torch.log(top_k_probs + 1e-10)).sum(dim=-1)
+        weights = torch.pow(entropy, alpha)
+
+    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+    per_token_loss = loss_fct(shift_logits_view[mask], shift_labels_view[mask])
+    weighted_loss = per_token_loss * weights
+
+    if num_items_in_batch is not None:
+        loss = weighted_loss.sum() / num_items_in_batch
+    else:
+        loss = weighted_loss.mean()
+
+    return loss
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index da21df7aa..3621c0d89 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -676,6 +676,24 @@ class AxolotlInputConfig(
             "description": "Number of chunks to use for chunked cross entropy loss"
         },
     )
+    use_eaft: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Enable Entropy-Aware Focal Training loss (EAFT)"
+        },
+    )
+    eaft_alpha: float | None = Field(
+        default=1.0,
+        json_schema_extra={
+            "description": "Exponent for entropy weighting in EAFT (default: 1.0)"
+        },
+    )
+    eaft_k: int | None = Field(
+        default=20,
+        json_schema_extra={
+            "description": "Number of top logits for entropy approximation (default: 20)"
+        },
+    )
     tiled_mlp: bool | None = Field(
         default=None,
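
A minimal usage sketch of the new loss path, binding `eaft_loss` with `functools.partial` the same way `HFCausalTrainerBuilder` does above and calling it on toy tensors. The shapes and the `SimpleNamespace` stand-in for the model output are illustrative assumptions, not part of the diff; `eaft_loss` only reads `outputs.logits`:

```python
# sketch: call eaft_loss the way the trainer would, on toy data.
from functools import partial
from types import SimpleNamespace

import torch

from axolotl.monkeypatch.loss.eaft import eaft_loss

# mirrors the builder wiring: alpha/k come from cfg.eaft_alpha / cfg.eaft_k
configured_eaft_loss = partial(eaft_loss, alpha=1.0, k=20)

batch, seq_len, vocab = 2, 8, 256  # toy sizes, assumptions for illustration
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
labels[:, :3] = -100  # prompt tokens masked out, as with train_on_inputs: false

loss = configured_eaft_loss(SimpleNamespace(logits=logits), labels)
print(loss)  # scalar tensor: mean entropy-weighted cross entropy
```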
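And a self-contained sketch of the weighting rule itself, on assumed toy logits: a near-deterministic next-token distribution gets a weight near zero, while a maximally uncertain one gets roughly `log(k) ** alpha`, so the gradient signal concentrates on high-entropy tokens:

```python
# sketch: the per-token weight from eaft_loss for two toy distributions.
import torch
import torch.nn.functional as F

k, alpha = 20, 1.0

confident = torch.zeros(100)
confident[0] = 12.0  # one dominant logit -> near-deterministic prediction
uncertain = torch.zeros(100)  # uniform logits -> maximal uncertainty

for name, logits in [("confident", confident), ("uncertain", uncertain)]:
    top_k_logits, _ = torch.topk(logits, k=k)
    top_k_probs = F.softmax(top_k_logits, dim=-1)
    entropy = -(top_k_probs * torch.log(top_k_probs + 1e-10)).sum()
    weight = torch.pow(entropy, alpha)
    print(f"{name}: weight = {weight.item():.4f}")

# prints roughly: confident ~0.0015, uncertain ~2.9957 (= log 20)
```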