Compare commits
12 Commits
sp-restore
...
fix/gemma3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8eba033dc4 | ||
|
|
a9c0f43202 | ||
|
|
a1a740608d | ||
|
|
ec15a7a691 | ||
|
|
0a7a216b60 | ||
|
|
d8280d45c1 | ||
|
|
24f2887e87 | ||
|
|
29289a4de9 | ||
|
|
a24957fa04 | ||
|
|
927bf530bc | ||
|
|
18954ba100 | ||
|
|
d8cf66edbd |
@@ -19,7 +19,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
- repo: https://github.com/PyCQA/flake8
|
- repo: https://github.com/PyCQA/flake8
|
||||||
rev: 7.2.0
|
rev: 7.3.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
- repo: https://github.com/pylint-dev/pylint
|
- repo: https://github.com/pylint-dev/pylint
|
||||||
@@ -27,7 +27,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: pylint
|
- id: pylint
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v1.16.0
|
rev: v1.16.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
@@ -36,7 +36,7 @@ repos:
|
|||||||
'pydantic>=2.5.3',
|
'pydantic>=2.5.3',
|
||||||
]
|
]
|
||||||
- repo: https://github.com/PyCQA/bandit
|
- repo: https://github.com/PyCQA/bandit
|
||||||
rev: 1.8.3
|
rev: 1.8.5
|
||||||
hooks:
|
hooks:
|
||||||
- id: bandit
|
- id: bandit
|
||||||
args: [
|
args: [
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ Features:
|
|||||||
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
|
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
|
||||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||||
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
|
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
|
||||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!
|
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
|
||||||
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
||||||
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ order: 3
|
|||||||
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. It supports using the tokenizer's template, a supported template, or a custom jinja2 template.
|
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. It supports using the tokenizer's template, a supported template, or a custom jinja2 template.
|
||||||
|
|
||||||
```{.json filename="data.jsonl"}
|
```{.json filename="data.jsonl"}
|
||||||
{"conversations": [{"role": "...", "content": "..."}]}
|
{"messages": [{"role": "...", "content": "..."}, {"role": "...", "content": "..."}, ...]}
|
||||||
```
|
```
|
||||||
|
|
||||||
See [configs](../config-reference.qmd) for full configs and supported templates.
|
See [configs](../config-reference.qmd) for full configs and supported templates.
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -75,13 +75,17 @@ def load_datasets(
|
|||||||
|
|
||||||
num_examples = cli_args.debug_num_examples if cli_args else 1
|
num_examples = cli_args.debug_num_examples if cli_args else 1
|
||||||
text_only = cli_args.debug_text_only if cli_args else False
|
text_only = cli_args.debug_text_only if cli_args else False
|
||||||
train_samples = sample_dataset(train_dataset, num_examples)
|
try:
|
||||||
check_dataset_labels(
|
train_samples = sample_dataset(train_dataset, num_examples)
|
||||||
train_samples,
|
check_dataset_labels(
|
||||||
tokenizer,
|
train_samples,
|
||||||
num_examples=num_examples,
|
tokenizer,
|
||||||
text_only=text_only,
|
num_examples=num_examples,
|
||||||
)
|
text_only=text_only,
|
||||||
|
)
|
||||||
|
except AttributeError:
|
||||||
|
# can't sample iterable datasets
|
||||||
|
pass
|
||||||
|
|
||||||
LOG.info("printing prompters...")
|
LOG.info("printing prompters...")
|
||||||
for prompter in prompters:
|
for prompter in prompters:
|
||||||
|
|||||||
@@ -413,7 +413,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
or self.cfg.micro_batch_size > 1
|
or self.cfg.micro_batch_size > 1
|
||||||
):
|
):
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
|
||||||
|
return None
|
||||||
|
|
||||||
if self.cfg.model_config_type == "mamba":
|
if self.cfg.model_config_type == "mamba":
|
||||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ class AxolotlTrainer(
|
|||||||
sequential=self.args.sample_packing_sequentially,
|
sequential=self.args.sample_packing_sequentially,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
num_processes=self.args.dataset_num_proc,
|
num_processes=self.args.dataset_num_proc,
|
||||||
|
mp_start_method=self.args.sample_packing_mp_start_method or "fork",
|
||||||
)
|
)
|
||||||
|
|
||||||
len(sampler)
|
len(sampler)
|
||||||
|
|||||||
@@ -38,6 +38,10 @@ class AxolotlTrainingMixins:
|
|||||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
sample_packing_mp_start_method: str | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The multiprocessing start method to use."},
|
||||||
|
)
|
||||||
multipack_real_batches: bool = field(
|
multipack_real_batches: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use real batches for efficient training."},
|
metadata={"help": "Use real batches for efficient training."},
|
||||||
|
|||||||
@@ -776,6 +776,9 @@ class ModelLoader:
|
|||||||
dist_dtype: torch.dtype,
|
dist_dtype: torch.dtype,
|
||||||
before_kbit_train_or_finetune: bool,
|
before_kbit_train_or_finetune: bool,
|
||||||
):
|
):
|
||||||
|
dest = {"dtype": dist_dtype}
|
||||||
|
if self.cfg.lora_on_cpu:
|
||||||
|
dest["device"] = "cpu"
|
||||||
for name, module in self.model.named_modules():
|
for name, module in self.model.named_modules():
|
||||||
if "norm" in name:
|
if "norm" in name:
|
||||||
module.to(dist_dtype)
|
module.to(dist_dtype)
|
||||||
@@ -786,4 +789,4 @@ class ModelLoader:
|
|||||||
# don't upcast lm_head for btlm
|
# don't upcast lm_head for btlm
|
||||||
continue
|
continue
|
||||||
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
||||||
module.to(dist_dtype)
|
module.to(**dest)
|
||||||
|
|||||||
@@ -156,8 +156,12 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
|||||||
model_cls_prefix = "".join(
|
model_cls_prefix = "".join(
|
||||||
[part.capitalize() for part in model_type.split("_")]
|
[part.capitalize() for part in model_type.split("_")]
|
||||||
)
|
)
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
|
if model_type == "gemma3n":
|
||||||
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"])
|
||||||
|
attention_cls = getattr(module, f"{model_cls_prefix}TextAttention")
|
||||||
|
else:
|
||||||
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
|
||||||
|
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
|
||||||
|
|
||||||
return attention_cls
|
return attention_cls
|
||||||
except (ImportError, AttributeError) as e:
|
except (ImportError, AttributeError) as e:
|
||||||
|
|||||||
@@ -42,6 +42,10 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
|||||||
if has_remote_code:
|
if has_remote_code:
|
||||||
patch_remote(model_name)
|
patch_remote(model_name)
|
||||||
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||||
|
# sanity check in case upstream api changes on this
|
||||||
|
assert hasattr(
|
||||||
|
transformers.modeling_flash_attention_utils, "_get_unpad_data"
|
||||||
|
), "transformers api changed for _get_unpad_data for flash attention"
|
||||||
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
chat_template_kwargs = {
|
chat_template_kwargs = {
|
||||||
"chat_template": self.chat_template,
|
"chat_template": self.chat_template,
|
||||||
"add_generation_prompt": add_generation_prompt,
|
"add_generation_prompt": add_generation_prompt,
|
||||||
|
**self.chat_template_kwargs,
|
||||||
}
|
}
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
|
|||||||
@@ -223,6 +223,8 @@ def execute_training(
|
|||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
|
if cfg.bf16:
|
||||||
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -224,10 +224,10 @@ def wrap_pretraining_dataset(
|
|||||||
remove_columns = []
|
remove_columns = []
|
||||||
if dataset.features is None:
|
if dataset.features is None:
|
||||||
for first_row in dataset:
|
for first_row in dataset:
|
||||||
remove_columns = first_row.keys()
|
remove_columns = list(first_row.keys())
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
remove_columns = dataset.features.keys()
|
remove_columns = list(dataset.features.keys())
|
||||||
|
|
||||||
dataset = dataset.map(
|
dataset = dataset.map(
|
||||||
encode,
|
encode,
|
||||||
@@ -267,6 +267,7 @@ def encode_packed_pretraining(
|
|||||||
batch_size=1,
|
batch_size=1,
|
||||||
batch_max_len=batch_size * max_seq_length,
|
batch_max_len=batch_size * max_seq_length,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
|
num_processes=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
chunked_data = defaultdict(list)
|
chunked_data = defaultdict(list)
|
||||||
|
|||||||
@@ -334,7 +334,10 @@ def _load_raw_datasets(
|
|||||||
dataset = merge_datasets(datasets, cfg)
|
dataset = merge_datasets(datasets, cfg)
|
||||||
|
|
||||||
if not cfg.skip_prepare_dataset:
|
if not cfg.skip_prepare_dataset:
|
||||||
dataset = drop_long_seq_in_dataset(dataset, cfg)
|
if split == "test" and cfg.eval_sequence_len:
|
||||||
|
dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
||||||
|
else:
|
||||||
|
dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||||
|
|
||||||
|
|||||||
@@ -148,11 +148,14 @@ def deduplicate_and_log_datasets(
|
|||||||
return dataset, other_dataset
|
return dataset, other_dataset
|
||||||
|
|
||||||
|
|
||||||
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
def drop_long_seq_in_dataset(
|
||||||
|
dataset: Dataset, sequence_len: int, cfg: DictDefault
|
||||||
|
) -> Dataset:
|
||||||
"""Remove sequences longer than configured maximum from dataset.
|
"""Remove sequences longer than configured maximum from dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: Dataset to filter.
|
dataset: Dataset to filter.
|
||||||
|
sequence_len: Maximum length for sequences to keep
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -167,7 +170,7 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
|||||||
|
|
||||||
drop_long = functools.partial(
|
drop_long = functools.partial(
|
||||||
drop_long_seq,
|
drop_long_seq,
|
||||||
sequence_len=cfg.sequence_len,
|
sequence_len=sequence_len,
|
||||||
min_sequence_len=cfg.min_sample_len,
|
min_sequence_len=cfg.min_sample_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -187,7 +190,7 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
|||||||
|
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
|
||||||
|
|
||||||
dataset = dataset.filter(
|
dataset = dataset.filter(
|
||||||
drop_long,
|
drop_long,
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ def pack_parallel(
|
|||||||
bin_size: int,
|
bin_size: int,
|
||||||
num_processes: int | None = None,
|
num_processes: int | None = None,
|
||||||
safe_mode: bool = True,
|
safe_mode: bool = True,
|
||||||
mp_start_method: str | None = "spawn",
|
mp_start_method: str | None = "fork",
|
||||||
) -> list[list[int]]:
|
) -> list[list[int]]:
|
||||||
"""Pack sequences into bins using parallel processing.
|
"""Pack sequences into bins using parallel processing.
|
||||||
|
|
||||||
@@ -260,12 +260,13 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
lengths: np.ndarray, # Sequence lengths
|
lengths: np.ndarray, # Sequence lengths
|
||||||
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
||||||
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
|
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
|
||||||
num_count_samples: int = 8, # Number of times to estimate batch count
|
num_count_samples: int = 4, # Number of times to estimate batch count
|
||||||
sequential: bool = False, # Whether to use sequential packing
|
sequential: bool = False, # Whether to use sequential packing
|
||||||
group_size: int = 100_000, # Size of groups for parallel packing
|
group_size: int = 100_000, # Size of groups for parallel packing
|
||||||
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
||||||
num_processes: int | None = None, # Number of processes for parallel packing
|
num_processes: int | None = None, # Number of processes for parallel packing
|
||||||
safe_mode: bool = True, # Conservative packing to prevent training instability
|
safe_mode: bool = True, # Conservative packing to prevent training instability
|
||||||
|
mp_start_method: str = "fork",
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
super().__init__(sampler, batch_size, drop_last)
|
super().__init__(sampler, batch_size, drop_last)
|
||||||
@@ -278,6 +279,7 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
self.bin_size = bin_size
|
self.bin_size = bin_size
|
||||||
self.num_processes = num_processes
|
self.num_processes = num_processes
|
||||||
self.safe_mode = safe_mode
|
self.safe_mode = safe_mode
|
||||||
|
self.mp_start_method = mp_start_method
|
||||||
|
|
||||||
assert isinstance(self.lengths, np.ndarray)
|
assert isinstance(self.lengths, np.ndarray)
|
||||||
|
|
||||||
@@ -333,13 +335,15 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
|
bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
|
||||||
else:
|
else:
|
||||||
# Use parallel packing
|
# Use parallel packing
|
||||||
|
num_processes = self.num_processes or 1
|
||||||
all_bins = pack_parallel(
|
all_bins = pack_parallel(
|
||||||
lengths,
|
lengths,
|
||||||
bin_capacity=self.batch_max_len,
|
bin_capacity=self.batch_max_len,
|
||||||
group_size=self.group_size,
|
group_size=self.group_size,
|
||||||
bin_size=self.bin_size,
|
bin_size=self.bin_size,
|
||||||
num_processes=self.num_processes,
|
num_processes=min(4, num_processes) if num_processes else 4,
|
||||||
safe_mode=self.safe_mode,
|
safe_mode=self.safe_mode,
|
||||||
|
mp_start_method=self.mp_start_method,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Map bin indices back to original indices
|
# Map bin indices back to original indices
|
||||||
|
|||||||
@@ -366,6 +366,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
eval_sequence_len: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len"
|
||||||
|
},
|
||||||
|
)
|
||||||
min_sample_len: int | None = None
|
min_sample_len: int | None = None
|
||||||
max_prompt_len: int = Field(
|
max_prompt_len: int = Field(
|
||||||
default=512,
|
default=512,
|
||||||
@@ -393,6 +399,12 @@ class AxolotlInputConfig(
|
|||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Whether to pack samples sequentially"},
|
json_schema_extra={"description": "Whether to pack samples sequentially"},
|
||||||
)
|
)
|
||||||
|
sample_packing_mp_start_method: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
|
||||||
|
},
|
||||||
|
)
|
||||||
eval_sample_packing: bool | None = Field(
|
eval_sample_packing: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -772,6 +784,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null."
|
"description": "Custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Additional kwargs to pass to the chat template. This is useful for customizing the chat template. For example, you can pass `thinking=False` to add a generation prompt to the chat template."
|
||||||
|
},
|
||||||
|
)
|
||||||
eot_tokens: list[str] | None = Field(
|
eot_tokens: list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
@@ -462,6 +462,20 @@ class TrainingValidationMixin:
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def pretrain_with_tps(cls, data):
|
||||||
|
if data.get("pretraining_dataset") and data.get(
|
||||||
|
"include_tokens_per_second", False
|
||||||
|
):
|
||||||
|
# combining these would raise `TypeError: cannot pickle 'dict_keys' object`
|
||||||
|
# due to trying to count the number of tokens total in the dataset
|
||||||
|
raise ValueError(
|
||||||
|
"pretraining_dataset and include_tokens_per_second cannot be used together."
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class LoRAValidationMixin:
|
class LoRAValidationMixin:
|
||||||
"""Validation methods related to LoRA/QLoRA configuration."""
|
"""Validation methods related to LoRA/QLoRA configuration."""
|
||||||
|
|||||||
@@ -381,6 +381,7 @@ def process_pretraining_datasets_for_packing(
|
|||||||
if not skip_position_ids:
|
if not skip_position_ids:
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
|
batched=True,
|
||||||
desc="Add position_id column (Pretraining Sample Packing)",
|
desc="Add position_id column (Pretraining Sample Packing)",
|
||||||
)
|
)
|
||||||
if drop_attention_mask:
|
if drop_attention_mask:
|
||||||
@@ -467,6 +468,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
sequential=cfg.sample_packing_sequentially,
|
sequential=cfg.sample_packing_sequentially,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
num_processes=cfg.dataset_processes,
|
num_processes=cfg.dataset_processes,
|
||||||
|
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
|
||||||
)
|
)
|
||||||
|
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
|
|||||||
)
|
)
|
||||||
train_dataset = concatenate_datasets([dataset_wrapper])
|
train_dataset = concatenate_datasets([dataset_wrapper])
|
||||||
|
|
||||||
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
|
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
|
||||||
|
|
||||||
lengths = get_dataset_lengths(train_dataset)
|
lengths = get_dataset_lengths(train_dataset)
|
||||||
batch_sampler = MultipackBatchSampler(
|
batch_sampler = MultipackBatchSampler(
|
||||||
|
|||||||
Reference in New Issue
Block a user