Compare commits

...

5 Commits

Author SHA1 Message Date
Wing Lian
9eaae5925a set labels and fix datasets block 2024-12-13 13:04:24 -05:00
Wing Lian
d000851eeb allow pretrain to be used with sft 2024-12-13 12:58:37 -05:00
Wing Lian
effc4dc409 pin to 4.47.0 (#2180) 2024-12-12 20:17:12 -05:00
Wing Lian
02629c7cdf parity for nightly ci - make sure to install setuptools (#2176) [skip ci] 2024-12-11 20:14:55 -05:00
Wing Lian
78a4aa86d6 evaluation_strategy was fully deprecated in recent release (#2169) [skip ci] 2024-12-11 20:14:24 -05:00
4 changed files with 13 additions and 6 deletions

View File

@@ -44,6 +44,11 @@ jobs:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
+ - name: upgrade pip
+ run: |
+ pip3 install --upgrade pip
+ pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu

View File

@@ -12,7 +12,7 @@ liger-kernel==0.4.2
packaging==23.2
peft==0.14.0
- transformers>=4.46.3
+ transformers==4.47.0
tokenizers>=0.20.1
accelerate==1.2.0
datasets==3.1.0

View File

@@ -41,6 +41,7 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
]
res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
+ res["labels"] = res["input_ids"].copy()
return res
@@ -49,12 +50,16 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
def load(tokenizer, cfg):
+ if cfg.pretraining_dataset:
+ cfg_ds = cfg.pretraining_dataset
+ else:
+ cfg_ds = cfg.datasets
strat = PretrainTokenizationStrategy(
PretrainTokenizer(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
- text_column=cfg.pretraining_dataset[0]["text_column"] or "text",
+ text_column=cfg_ds[0]["text_column"] or "text",
max_length=cfg.sequence_len * 64,
)
return strat

View File

@@ -66,10 +66,7 @@ class EvalFirstStepCallback(
control: TrainerControl,
**kwargs,
):
- if (
- args.evaluation_strategy == IntervalStrategy.STEPS
- and state.global_step == 1
- ):
+ if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
control.should_evaluate = True
return control