Compare commits
4 Commits
swe-rebenc
...
transforme
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b5a9d1d88 | ||
|
|
eb59070040 | ||
|
|
9722aaf7d8 | ||
|
|
c5d20bbd79 |
@@ -12,7 +12,7 @@ packaging==26.0
|
||||
huggingface_hub>=1.1.7
|
||||
peft>=0.18.1
|
||||
tokenizers>=0.22.1
|
||||
transformers==5.2.0
|
||||
transformers @ git+https://github.com/winglian/transformers.git@refactor-inner-training-loop-reorder-only
|
||||
accelerate==1.12.0
|
||||
datasets==4.5.0
|
||||
deepspeed>=0.18.3
|
||||
|
||||
@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
|
||||
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
def create_optimizer(self):
|
||||
def create_optimizer(self, model=None):
|
||||
if (
|
||||
self.args.loraplus_lr_ratio is None
|
||||
and self.args.embedding_lr_scale is None
|
||||
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
|
||||
and self.args.lr_groups is None
|
||||
and self.optimizer_cls_and_kwargs is None
|
||||
):
|
||||
return super().create_optimizer()
|
||||
return super().create_optimizer(model=model)
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
opt_model = self.model if model is None else model
|
||||
|
||||
if (
|
||||
not self.optimizer
|
||||
|
||||
@@ -115,6 +115,9 @@ class TestAssistantChatTemplateLlama3:
|
||||
|
||||
def test_phi35(self, phi35_tokenizer, assistant_dataset):
|
||||
LOG.info("Testing phi-3.5 with assistant dataset")
|
||||
assert "LlamaTokenizer" in phi35_tokenizer.__class__.__name__, (
|
||||
"phi35 tokenizer should be a LlamaTokenizer"
|
||||
)
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
phi35_tokenizer,
|
||||
@@ -140,13 +143,13 @@ class TestAssistantChatTemplateLlama3:
|
||||
# fmt: off
|
||||
expected_input_ids = [
|
||||
32010, # user
|
||||
22172, 32007, # user eot
|
||||
12199, 32007, # user eot
|
||||
32001, # assistant
|
||||
22172, 32007, # assistant eot
|
||||
12199, 32007, # assistant eot
|
||||
32010, # user
|
||||
1781, 26966, 32007, # user eot
|
||||
16773, 26966, 32007, # user eot
|
||||
32001, # assistant
|
||||
1781, 26966, 32007, # assistant eot
|
||||
16773, 26966, 32007, # assistant eot
|
||||
]
|
||||
expected_labels = [
|
||||
-100, # user
|
||||
@@ -156,7 +159,7 @@ class TestAssistantChatTemplateLlama3:
|
||||
-100, # user
|
||||
-100, -100, -100, # user eot
|
||||
-100, # assistant
|
||||
1781, 26966, 32007, # assistant eot
|
||||
16773, 26966, 32007, # assistant eot
|
||||
]
|
||||
# fmt: on
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
|
||||
@@ -84,7 +84,8 @@ class TestTokenizers:
|
||||
}
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
|
||||
assert "LlamaTokenizer" in tokenizer.__class__.__name__
|
||||
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1792]
|
||||
assert len(tokenizer) == 32001
|
||||
|
||||
# ensure reloading the tokenizer again from cfg results in same vocab length
|
||||
|
||||
Reference in New Issue
Block a user