Compare commits

...

2 Commits

Author SHA1 Message Date
Dan Saunders
c3db6dd307 remove hardcode 2025-08-19 15:41:32 +00:00
Dan Saunders
9a6e9d8d15 no sequence length support 2025-08-19 10:25:37 -04:00
5 changed files with 12 additions and 6 deletions

View File

@@ -12,7 +12,7 @@ output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sequence_len:
sample_packing: true
eval_sample_packing: true

View File

@@ -268,7 +268,10 @@ class ModelLoader:
hasattr(self.model, "config")
and hasattr(self.model.config, "max_position_embeddings")
and self.model.config.max_position_embeddings
and (
self.cfg.sequence_len is not None
and self.cfg.sequence_len > self.model.config.max_position_embeddings
)
):
LOG.warning(
"increasing model.config.max_position_embeddings from "

View File

@@ -91,7 +91,7 @@ class PromptTokenizingStrategy(abc.ABC):
if (
result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.max_length
and (self.max_length is None or len(result["input_ids"]) < self.max_length)
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)

View File

@@ -408,7 +408,7 @@ class AxolotlInputConfig(
unfrozen_parameters: list[str] | None = None
sequence_len: int = Field(
sequence_len: int | None = Field(
default=512,
json_schema_extra={
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"

View File

@@ -229,7 +229,10 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
results = []
for seq in input_ids:
length = len(seq)
if sequence_len is not None:
results.append(min_sequence_len <= length <= sequence_len)
else:
results.append(min_sequence_len <= length)
return results
@@ -405,7 +408,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
if update:
cfg.total_num_tokens = total_num_tokens
skip_estimates = cfg.model_config_type == "mamba"
skip_estimates = cfg.sequence_len is None or cfg.model_config_type == "mamba"
if (
not skip_estimates