Compare commits

2 Commits: `tool-mpm ... no-seq-len`

The `no-seq-len` branch makes `sequence_len` optional (nullable) end to end: the example config, model loading, prompt tokenization, sequence filtering, and training-step estimation.
| Author | SHA1 | Date |
|---|---|---|
|  | c3db6dd307 |  |
|  | 9a6e9d8d15 |  |
```diff
@@ -12,7 +12,7 @@ output_dir: ./outputs/lora-out
 adapter: lora
 lora_model_dir:
 
-sequence_len: 2048
+sequence_len:
 sample_packing: true
 eval_sample_packing: true
 
```
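The empty value above is meaningful: YAML parses a bare `sequence_len:` as null, so the loaded config carries `None` rather than a number. A quick illustration, assuming PyYAML:

```python
# Illustrative only, not part of the diff. A bare "sequence_len:" in YAML
# loads as None, which the downstream code changed below must now tolerate.
import yaml

cfg = yaml.safe_load("sequence_len:\n")
assert cfg["sequence_len"] is None
```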
```diff
@@ -268,7 +268,10 @@ class ModelLoader:
             hasattr(self.model, "config")
             and hasattr(self.model.config, "max_position_embeddings")
             and self.model.config.max_position_embeddings
-            and self.cfg.sequence_len > self.model.config.max_position_embeddings
+            and (
+                self.cfg.sequence_len is not None
+                and self.cfg.sequence_len > self.model.config.max_position_embeddings
+            )
         ):
             LOG.warning(
                 "increasing model.config.max_position_embeddings from "
```
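The added `is not None` guard matters because Python 3 refuses to order `None` against an `int`. A minimal sketch, with plain values standing in for `self.cfg` and `self.model.config`:

```python
# Illustrative only. Without the guard, `None > 4096` raises TypeError;
# with it, the comparison short-circuits and the warning path is skipped.
sequence_len = None             # stand-in for self.cfg.sequence_len
max_position_embeddings = 4096  # stand-in for model.config.max_position_embeddings

needs_extension = (
    sequence_len is not None
    and sequence_len > max_position_embeddings
)
assert needs_extension is False
```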
```diff
@@ -91,7 +91,7 @@ class PromptTokenizingStrategy(abc.ABC):
 
         if (
             result["input_ids"][-1] != self.tokenizer.eos_token_id
-            and len(result["input_ids"]) < self.max_length
+            and (self.max_length is None or len(result["input_ids"]) < self.max_length)
             and add_eos_token
         ):
             result["input_ids"].append(self.tokenizer.eos_token_id)
```
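With `max_length=None`, the length cap is skipped and the EOS token is appended whenever it is missing. A self-contained sketch of the rewritten condition (`maybe_append_eos` is a hypothetical stand-in, not the repository's API):

```python
# Illustrative only. Mirrors the condition in the hunk above.
def maybe_append_eos(input_ids, eos_token_id, max_length, add_eos_token=True):
    if (
        input_ids[-1] != eos_token_id
        and (max_length is None or len(input_ids) < max_length)
        and add_eos_token
    ):
        input_ids.append(eos_token_id)
    return input_ids

assert maybe_append_eos([1, 2, 3], eos_token_id=0, max_length=None) == [1, 2, 3, 0]
assert maybe_append_eos([1, 2, 3], eos_token_id=0, max_length=3) == [1, 2, 3]
```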
```diff
@@ -408,7 +408,7 @@ class AxolotlInputConfig(
 
     unfrozen_parameters: list[str] | None = None
 
-    sequence_len: int = Field(
+    sequence_len: int | None = Field(
         default=512,
         json_schema_extra={
             "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
```
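Widening the annotation to `int | None` is what lets the pydantic schema accept a null `sequence_len` from the YAML config while keeping the old default. A minimal sketch with a throwaway model (`Cfg` is illustrative, assuming pydantic v2):

```python
# Illustrative only. With `int | None`, an explicit null validates;
# omitting the field still yields the default of 512.
from pydantic import BaseModel, Field

class Cfg(BaseModel):
    sequence_len: int | None = Field(default=512)

assert Cfg().sequence_len == 512
assert Cfg(sequence_len=None).sequence_len is None
```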
```diff
@@ -229,7 +229,10 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
 
     results = []
     for seq in input_ids:
         length = len(seq)
-        results.append(min_sequence_len <= length <= sequence_len)
+        if sequence_len is not None:
+            results.append(min_sequence_len <= length <= sequence_len)
+        else:
+            results.append(min_sequence_len <= length)
     return results
 
```
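The new branch drops the upper bound when no `sequence_len` is configured, so only the minimum-length check filters samples. A standalone sketch (`keep_mask` is a hypothetical stand-in for `drop_long_seq`):

```python
# Illustrative only. With sequence_len=None, long sequences are kept.
def keep_mask(input_ids, sequence_len=2048, min_sequence_len=2):
    results = []
    for seq in input_ids:
        length = len(seq)
        if sequence_len is not None:
            results.append(min_sequence_len <= length <= sequence_len)
        else:
            results.append(min_sequence_len <= length)
    return results

assert keep_mask([[1], [1, 2], [1, 2, 3]], sequence_len=2) == [False, True, False]
assert keep_mask([[1], [1, 2], [1, 2, 3]], sequence_len=None) == [False, True, True]
```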
```diff
@@ -405,7 +408,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
         if update:
             cfg.total_num_tokens = total_num_tokens
 
-    skip_estimates = cfg.model_config_type == "mamba"
+    skip_estimates = cfg.sequence_len is None or cfg.model_config_type == "mamba"
 
     if (
         not skip_estimates
```
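Because `or` short-circuits, a missing `sequence_len` now disables the step estimates on its own, independent of the existing mamba special case. Plain values standing in for the `cfg` fields:

```python
# Illustrative only. Either condition alone is enough to skip estimates.
sequence_len = None           # stand-in for cfg.sequence_len
model_config_type = "llama"   # stand-in for cfg.model_config_type

skip_estimates = sequence_len is None or model_config_type == "mamba"
assert skip_estimates is True
```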