Compare commits

..

1 Commits

Author SHA1 Message Date
Dan Saunders
d657ff9c94 Update README.md
Quick fix. Local `base_model` paths need to have a trailing `/`.
2024-12-12 15:01:29 -05:00
3 changed files with 3 additions and 8 deletions

View File

@@ -478,7 +478,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- model - model
```yaml ```yaml
base_model: ./llama-7b-hf # local or huggingface repo base_model: ./llama-7b-hf/ # local or huggingface repo
``` ```
Note: The code will load the right architecture. Note: The code will load the right architecture.

View File

@@ -12,7 +12,7 @@ liger-kernel==0.4.2
packaging==23.2 packaging==23.2
peft==0.14.0 peft==0.14.0
transformers==4.47.0 transformers>=4.46.3
tokenizers>=0.20.1 tokenizers>=0.20.1
accelerate==1.2.0 accelerate==1.2.0
datasets==3.1.0 datasets==3.1.0

View File

@@ -41,7 +41,6 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"] seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
] ]
res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]] res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
res["labels"] = res["input_ids"].copy()
return res return res
@@ -50,16 +49,12 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
def load(tokenizer, cfg): def load(tokenizer, cfg):
if cfg.pretraining_dataset:
cfg_ds = cfg.pretraining_dataset
else:
cfg_ds = cfg.datasets
strat = PretrainTokenizationStrategy( strat = PretrainTokenizationStrategy(
PretrainTokenizer(), PretrainTokenizer(),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
text_column=cfg_ds[0]["text_column"] or "text", text_column=cfg.pretraining_dataset[0]["text_column"] or "text",
max_length=cfg.sequence_len * 64, max_length=cfg.sequence_len * 64,
) )
return strat return strat