diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py
index f22e601d9..1cb297406 100644
--- a/src/axolotl/processing_strategies.py
+++ b/src/axolotl/processing_strategies.py
@@ -5,7 +5,7 @@ from typing import Optional
 
 from PIL import Image, ImageOps
 from PIL.Image import Resampling
-from torch import Tensor
+from torch import Tensor, zeros_like
 from transformers import ProcessorMixin
 from transformers.image_utils import load_image
 
@@ -208,9 +208,18 @@ class ProcessingStrategy:
 
         return processed_examples
 
+    def _mask_non_assistant(self, labels: Tensor) -> Tensor:
+        """
+        Mask non assistant regions to -100.
+        To be implemented per subclass.
+        """
+        return labels
+
     def process_labels(self, input_ids: Tensor) -> Tensor:
         labels = input_ids.clone()
 
+        labels = self._mask_non_assistant(labels)
+
         # The labels are the input_ids, and we mask the padding tokens in the loss computation
         labels[labels == self.processor.tokenizer.pad_token_id] = -100
 
@@ -267,8 +276,95 @@ class Gemma3ProcessingStrategy(ProcessingStrategy):
 class Gemma3nProcessingStrategy(ProcessingStrategy):
     """Processing Strategy class for Gemma3n"""
 
+    def _mask_non_assistant(self, labels: Tensor) -> Tensor:
+        """
+        Mask every token outside an assistant turn to -100.
+
+        An assistant turn opens with the tokens of ``<start_of_turn>model`` and
+        closes with ``<end_of_turn>``. The opening marker stays masked; the
+        content and the closing marker are kept. A turn without a closing marker
+        (e.g. a truncated sample) is kept through the end of the row.
+
+        ``labels`` is modified in place and returned.
+        """
+
+        def _find_token_sequence(label, start_pos, token_sequence):
+            """Check if token_sequence appears at start_pos in label"""
+            # An empty sequence never matches; without this guard an empty
+            # marker would raise IndexError on token_sequence[0] below.
+            if not token_sequence or start_pos + len(token_sequence) > len(label):
+                return False
+            # Cheap first-token test before paying for the slice + tolist()
+            if label[start_pos] != token_sequence[0]:
+                return False
+            return (
+                label[start_pos : start_pos + len(token_sequence)].tolist()
+                == token_sequence
+            )
+
+        def _find_assistant_end(label, start_pos, assistant_end_tok, mask, i):
+            """
+            Find the end of assistant response and update mask accordingly
+
+            Returns new position to continue from and whether the end seq is found
+            """
+            k = start_pos
+            while k < len(label):
+                if not _find_token_sequence(label, k, assistant_end_tok):
+                    mask[i][k] = 1
+                    k += 1
+                    continue
+
+                return k + len(assistant_end_tok), True
+
+            return k, False
+
+        mask = zeros_like(labels)
+
+        assistant_start_str = "<start_of_turn>model"
+        assistant_end_str = "<end_of_turn>"
+        include_assistant_start_tok = False
+        include_assistant_end_tok = True
+
+        # str to tokens; add_special_tokens=False so no BOS/EOS joins the pattern
+        assistant_start_tok = self.processor.tokenizer.encode(
+            assistant_start_str, add_special_tokens=False
+        )
+        assistant_end_tok = self.processor.tokenizer.encode(
+            assistant_end_str, add_special_tokens=False
+        )
+
+        for i, label in enumerate(labels):
+            j = 0
+            # while loop through each tok index in labels[i]
+            while j < len(label):
+                # Check until match start seq
+                if not _find_token_sequence(label, j, assistant_start_tok):
+                    j += 1
+                    continue
+
+                if include_assistant_start_tok:
+                    mask[i][j : j + len(assistant_start_tok)] = 1
+
+                # Find where the assistant response ends
+                start_of_content = j + len(assistant_start_tok)
+                end_pos, found_end_seq = _find_assistant_end(
+                    label, start_of_content, assistant_end_tok, mask, i
+                )
+
+                # Include end token if requested
+                if include_assistant_end_tok and found_end_seq:
+                    mask[i][end_pos - len(assistant_end_tok) : end_pos] = 1
+
+                j = end_pos
+
+            labels[i][mask[i] == 0] = -100
+
+        return labels
+
     def process_labels(self, input_ids):
         labels = input_ids.clone()
+        labels = self._mask_non_assistant(labels)
 
         # Follows https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/fine_tune_gemma3n_on_t4.ipynb
         labels[labels == self.processor.tokenizer.pad_token_id] = -100