Compare commits

...

4 Commits

Author SHA1 Message Date
Wing Lian
3b5a9d1d88 update create_optimizer for updated api 2026-02-19 23:49:32 -05:00
Wing Lian
eb59070040 fix labels 2026-02-19 23:44:46 -05:00
Wing Lian
9722aaf7d8 fix for tokenizers change 2026-02-19 21:52:44 -05:00
Wing Lian
c5d20bbd79 integration branch for transformers#44041 2026-02-19 18:34:13 -05:00
4 changed files with 14 additions and 10 deletions

View File

@@ -12,7 +12,7 @@ packaging==26.0
huggingface_hub>=1.1.7 huggingface_hub>=1.1.7
peft>=0.18.1 peft>=0.18.1
tokenizers>=0.22.1 tokenizers>=0.22.1
transformers==5.2.0 transformers @ git+https://github.com/winglian/transformers.git@refactor-inner-training-loop-reorder-only
accelerate==1.12.0 accelerate==1.12.0
datasets==4.5.0 datasets==4.5.0
deepspeed>=0.18.3 deepspeed>=0.18.3

View File

@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
return optimizer_grouped_parameters return optimizer_grouped_parameters
def create_optimizer(self): def create_optimizer(self, model=None):
if ( if (
self.args.loraplus_lr_ratio is None self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None and self.args.embedding_lr_scale is None
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
and self.args.lr_groups is None and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None and self.optimizer_cls_and_kwargs is None
): ):
return super().create_optimizer() return super().create_optimizer(model=model)
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model opt_model = self.model if model is None else model
if ( if (
not self.optimizer not self.optimizer

View File

@@ -115,6 +115,9 @@ class TestAssistantChatTemplateLlama3:
def test_phi35(self, phi35_tokenizer, assistant_dataset): def test_phi35(self, phi35_tokenizer, assistant_dataset):
LOG.info("Testing phi-3.5 with assistant dataset") LOG.info("Testing phi-3.5 with assistant dataset")
assert "LlamaTokenizer" in phi35_tokenizer.__class__.__name__, (
"phi35 tokenizer should be a LlamaTokenizer"
)
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
phi35_tokenizer, phi35_tokenizer,
@@ -140,13 +143,13 @@ class TestAssistantChatTemplateLlama3:
# fmt: off # fmt: off
expected_input_ids = [ expected_input_ids = [
32010, # user 32010, # user
22172, 32007, # user eot 12199, 32007, # user eot
32001, # assistant 32001, # assistant
22172, 32007, # assistant eot 12199, 32007, # assistant eot
32010, # user 32010, # user
1781, 26966, 32007, # user eot 16773, 26966, 32007, # user eot
32001, # assistant 32001, # assistant
1781, 26966, 32007, # assistant eot 16773, 26966, 32007, # assistant eot
] ]
expected_labels = [ expected_labels = [
-100, # user -100, # user
@@ -156,7 +159,7 @@ class TestAssistantChatTemplateLlama3:
-100, # user -100, # user
-100, -100, -100, # user eot -100, -100, -100, # user eot
-100, # assistant -100, # assistant
1781, 26966, 32007, # assistant eot 16773, 26966, 32007, # assistant eot
] ]
# fmt: on # fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}") LOG.debug(f"Expected input_ids: {expected_input_ids}")

View File

@@ -84,7 +84,8 @@ class TestTokenizers:
} }
) )
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404] assert "LlamaTokenizer" in tokenizer.__class__.__name__
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1792]
assert len(tokenizer) == 32001 assert len(tokenizer) == 32001
# ensure reloading the tokenizer again from cfg results in same vocab length # ensure reloading the tokenizer again from cfg results in same vocab length