Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
9eaae5925a set labels and fix datasets block 2024-12-13 13:04:24 -05:00
Wing Lian
d000851eeb allow pretrain to be used with sft 2024-12-13 12:58:37 -05:00
7 changed files with 6 additions and 6 deletions

View File

@@ -8,8 +8,3 @@ pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /worksp
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
- tests=$(pytest --collect-only -q tests/e2e/each)
- for t in $tests; do
- pytest $t
- done

View File

@@ -41,6 +41,7 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
]
res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
+ res["labels"] = res["input_ids"].copy()
return res
@@ -49,12 +50,16 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
def load(tokenizer, cfg):
+ if cfg.pretraining_dataset:
+ cfg_ds = cfg.pretraining_dataset
+ else:
+ cfg_ds = cfg.datasets
strat = PretrainTokenizationStrategy(
PretrainTokenizer(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
- text_column=cfg.pretraining_dataset[0]["text_column"] or "text",
+ text_column=cfg_ds[0]["text_column"] or "text",
max_length=cfg.sequence_len * 64,
)
return strat