Compare commits
2 Commits
textui
...
no-seq-len
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c3db6dd307 | ||
|
|
9a6e9d8d15 |
@@ -12,7 +12,7 @@ output_dir: ./outputs/lora-out
|
|||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len:
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
eval_sample_packing: true
|
eval_sample_packing: true
|
||||||
|
|
||||||
|
|||||||
@@ -268,7 +268,10 @@ class ModelLoader:
|
|||||||
hasattr(self.model, "config")
|
hasattr(self.model, "config")
|
||||||
and hasattr(self.model.config, "max_position_embeddings")
|
and hasattr(self.model.config, "max_position_embeddings")
|
||||||
and self.model.config.max_position_embeddings
|
and self.model.config.max_position_embeddings
|
||||||
and self.cfg.sequence_len > self.model.config.max_position_embeddings
|
and (
|
||||||
|
self.cfg.sequence_len is not None
|
||||||
|
and self.cfg.sequence_len > self.model.config.max_position_embeddings
|
||||||
|
)
|
||||||
):
|
):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"increasing model.config.max_position_embeddings from "
|
"increasing model.config.max_position_embeddings from "
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
result["input_ids"][-1] != self.tokenizer.eos_token_id
|
result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||||
and len(result["input_ids"]) < self.max_length
|
and (self.max_length is None or len(result["input_ids"]) < self.max_length)
|
||||||
and add_eos_token
|
and add_eos_token
|
||||||
):
|
):
|
||||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||||
|
|||||||
@@ -408,7 +408,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
unfrozen_parameters: list[str] | None = None
|
unfrozen_parameters: list[str] | None = None
|
||||||
|
|
||||||
sequence_len: int = Field(
|
sequence_len: int | None = Field(
|
||||||
default=512,
|
default=512,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
||||||
|
|||||||
@@ -229,7 +229,10 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
|||||||
results = []
|
results = []
|
||||||
for seq in input_ids:
|
for seq in input_ids:
|
||||||
length = len(seq)
|
length = len(seq)
|
||||||
results.append(min_sequence_len <= length <= sequence_len)
|
if sequence_len is not None:
|
||||||
|
results.append(min_sequence_len <= length <= sequence_len)
|
||||||
|
else:
|
||||||
|
results.append(min_sequence_len <= length)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@@ -405,7 +408,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
if update:
|
if update:
|
||||||
cfg.total_num_tokens = total_num_tokens
|
cfg.total_num_tokens = total_num_tokens
|
||||||
|
|
||||||
skip_estimates = cfg.model_config_type == "mamba"
|
skip_estimates = cfg.sequence_len is None or cfg.model_config_type == "mamba"
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not skip_estimates
|
not skip_estimates
|
||||||
|
|||||||
Reference in New Issue
Block a user