feat: support unmask only assistant region (gemma3n for now)

This commit is contained in:
NanoCode012
2025-07-21 17:43:26 +07:00
parent 312832e1fe
commit 213446e078

View File

@@ -5,7 +5,7 @@ from typing import Optional
from PIL import Image, ImageOps from PIL import Image, ImageOps
from PIL.Image import Resampling from PIL.Image import Resampling
from torch import Tensor from torch import Tensor, zeros_like
from transformers import ProcessorMixin from transformers import ProcessorMixin
from transformers.image_utils import load_image from transformers.image_utils import load_image
@@ -208,9 +208,18 @@ class ProcessingStrategy:
return processed_examples return processed_examples
def _mask_non_assistant(self, labels: Tensor) -> Tensor:
"""
Mask non assistant regions to -100.
To be implemented per subclass.
"""
return labels
def process_labels(self, input_ids: Tensor) -> Tensor: def process_labels(self, input_ids: Tensor) -> Tensor:
labels = input_ids.clone() labels = input_ids.clone()
labels = self._mask_non_assistant(labels)
# The labels are the input_ids, and we mask the padding tokens in the loss computation # The labels are the input_ids, and we mask the padding tokens in the loss computation
labels[labels == self.processor.tokenizer.pad_token_id] = -100 labels[labels == self.processor.tokenizer.pad_token_id] = -100
@@ -267,8 +276,81 @@ class Gemma3ProcessingStrategy(ProcessingStrategy):
class Gemma3nProcessingStrategy(ProcessingStrategy): class Gemma3nProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Gemma3n""" """Processing Strategy class for Gemma3n"""
def _mask_non_assistant(self, labels: Tensor) -> Tensor:
def _find_token_sequence(label, start_pos, token_sequence):
"""Check if token_sequence appears at start_pos in label"""
if start_pos + len(token_sequence) > len(label):
return False
if label[start_pos] != token_sequence[0]:
return False
return (
label[start_pos : start_pos + len(token_sequence)].tolist()
== token_sequence
)
def _find_assistant_end(label, start_pos, assistant_end_tok, mask, i):
"""
Find the end of assistant response and update mask accordingly
Returns new position to continue from and whether the end seq is found
"""
k = start_pos
while k < len(label):
if not _find_token_sequence(label, k, assistant_end_tok):
mask[i][k] = 1
k += 1
continue
return k + len(assistant_end_tok), True
return k, False
mask = zeros_like(labels)
assistant_start_str = "<start_of_turn>model"
assistant_end_str = "<end_of_turn>"
include_assistant_start_tok = False
include_assistant_end_tok = True
# str to tokens
assistant_start_tok = self.processor.tokenizer.encode(
assistant_start_str, add_special_tokens=False
)
assistant_end_tok = self.processor.tokenizer.encode(
assistant_end_str, add_special_tokens=False
)
for i, label in enumerate(labels):
j = 0
# while loop through each tok index in labels[i]
while j < len(label):
# Check until match start seq
if not _find_token_sequence(label, j, assistant_start_tok):
j += 1
continue
if include_assistant_start_tok:
mask[i][j : j + len(assistant_start_tok)] = 1
# Find where the assistant response ends
start_of_content = j + len(assistant_start_tok)
end_pos, found_end_seq = _find_assistant_end(
label, start_of_content, assistant_end_tok, mask, i
)
# Include end token if requested
if include_assistant_end_tok and found_end_seq:
mask[i][end_pos - len(assistant_end_tok) : end_pos] = 1
j = end_pos
labels[i][mask[i] == 0] = -100
return labels
def process_labels(self, input_ids): def process_labels(self, input_ids):
labels = input_ids.clone() labels = input_ids.clone()
labels = self._mask_non_assistant(labels)
# Follows https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/fine_tune_gemma3n_on_t4.ipynb # Follows https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/fine_tune_gemma3n_on_t4.ipynb
labels[labels == self.processor.tokenizer.pad_token_id] = -100 labels[labels == self.processor.tokenizer.pad_token_id] = -100