Compare commits
2 Commits
pytest-eac
...
pretrain-d
| Author | SHA1 | Date | |
|---|---|---|---|
| | 9eaae5925a | | |
| | d000851eeb | | |
@@ -8,8 +8,3 @@ pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /worksp
|
||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/
|
||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
|
||||
tests=$(pytest --collect-only -q tests/e2e/each)
|
||||
for t in $tests; do
|
||||
pytest $t
|
||||
done
|
||||
|
||||
@@ -41,6 +41,7 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
|
||||
seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
|
||||
]
|
||||
res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
|
||||
res["labels"] = res["input_ids"].copy()
|
||||
|
||||
return res
|
||||
|
||||
@@ -49,12 +50,16 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
    """Build the pretraining tokenization strategy for the given config.

    The dataset list is taken from ``cfg.pretraining_dataset`` when it is
    truthy and from ``cfg.datasets`` otherwise. Only the first entry of that
    list is consulted, and only for its ``"text_column"`` value.

    Args:
        tokenizer: Tokenizer passed through to the strategy unchanged.
        cfg: Config object providing ``pretraining_dataset``, ``datasets``,
            ``train_on_inputs`` and ``sequence_len``.

    Returns:
        A ``PretrainTokenizationStrategy`` instance.
    """
    # Fall back to the regular dataset list so the strategy can be loaded
    # even when no pretraining dataset is configured.
    cfg_ds = cfg.pretraining_dataset or cfg.datasets

    # NOTE(review): assumes cfg_ds is a non-empty sequence whose entries
    # return a falsy value (rather than raising) when "text_column" is
    # unset -- confirm against the config container used by callers.
    text_column = cfg_ds[0]["text_column"] or "text"

    return PretrainTokenizationStrategy(
        PretrainTokenizer(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
        text_column=text_column,
        # NOTE(review): the 64x multiplier is carried over as-is; its
        # rationale is not visible here.
        max_length=cfg.sequence_len * 64,
    )
|
||||
|
||||
Reference in New Issue
Block a user