diff --git a/.nojekyll b/.nojekyll
index 0b18537b2..310a4948e 100644
--- a/.nojekyll
+++ b/.nojekyll
@@ -1 +1 @@
-fc509c68
\ No newline at end of file
+d475bb60
\ No newline at end of file
diff --git a/FAQS.html b/FAQS.html
index 2b056b155..5740e0f93 100644
--- a/FAQS.html
+++ b/FAQS.html
@@ -143,7 +143,7 @@ ul.task-list li input[type="checkbox"] {
@@ -407,7 +413,9 @@ ul.task-list li input[type="checkbox"] {
Can you train StableLM with this? Yes, but only with a single GPU at the moment. Multi-GPU support is coming soon! Just waiting on this PR
Will this work with Deepspeed? That’s still a WIP, but setting export ACCELERATE_USE_DEEPSPEED=true should work in some cases
-
Error invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598: arrow::fs::FinalizeS3 was not called even though S3 was initialized. This could lead to a segmentation fault at exit. Try reinstalling bitsandbytes and transformers from source.
+
Error invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c
+/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598: arrow::fs::FinalizeS3 was not called even though S3 was initialized.
+This could lead to a segmentation fault at exit. Try reinstalling bitsandbytes and transformers from source.
diff --git a/TODO.html b/TODO.html
index e93f0ca95..a9f4a1825 100644
--- a/TODO.html
+++ b/TODO.html
@@ -143,7 +143,7 @@ ul.task-list li input[type="checkbox"] {
xformers appears to be incompatible with ROCm. Apply the following workarounds: - Edit $HOME/packages/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py modifying the code to always return False for SwiGLU availability from xformers. - Edit $HOME/miniforge3/lib/python3.10/site-packages/xformers/ops/swiglu_op.py replacing the “SwiGLU” function with a pass statement.
+
xformers appears to be incompatible with ROCm. Apply the following workarounds:
+- Edit $HOME/packages/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py modifying the code to always return False for SwiGLU availability from xformers.
+- Edit $HOME/miniforge3/lib/python3.10/site-packages/xformers/ops/swiglu_op.py replacing the “SwiGLU” function with a pass statement.
Registers the plugins for the given configuration.

check_remote_config

cli.config.check_remote_config(config)

First, determines if the passed config is a valid HTTPS URL. Then, attempts to query for it and parse its content, first as JSON, then as YAML (YAML is preferred). Finally, the parsed content is written to a local file and its path is returned.

Parameters:
- config (Union[str, Path]): HTTPS URL to a YAML or JSON file. Required.

Returns:
- Union[str, Path]: Either the original config if it’s not a valid HTTPS URL, or the path to the downloaded remote config.

Raises:
- ValueError: If the remote configuration is neither valid JSON nor YAML.
- RuntimeError: If a request-related exception occurs during the file download.
- Exception: Catch-all for any other exception.
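A minimal sketch of the behavior described above (assuming the requests and PyYAML libraries; the function and local file names are hypothetical):

import json
from pathlib import Path
from urllib.parse import urlparse

import requests
import yaml

def check_remote_config_sketch(config):
    # Return the config unchanged if it is not a valid HTTPS URL.
    if urlparse(str(config)).scheme != "https":
        return config
    response = requests.get(str(config), timeout=30)
    response.raise_for_status()
    try:
        # Try JSON first, then fall back to YAML.
        parsed = json.loads(response.text)
    except json.JSONDecodeError:
        try:
            parsed = yaml.safe_load(response.text)
        except yaml.YAMLError as err:
            raise ValueError("Remote config is neither valid JSON nor YAML") from err
    # Write the parsed content to a local YAML file and return its path.
    local_path = Path("remote_config.yaml")  # hypothetical file name
    local_path.write_text(yaml.safe_dump(parsed), encoding="utf-8")
    return local_path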
choose_config

cli.config.choose_config(path)

Helper method for choosing an axolotl config YAML file (considering only files ending with .yml or .yaml). If more than one config file exists in the passed path, the user is prompted to choose one.

Parameters:
- path (Path): Directory in which config file(s) are stored. Required.

Returns:
- str: Path to either (1) the sole YAML file, or (2) if more than one YAML file exists, the file selected by the user.
Parses axolotl config, CLI args, and calls do_evaluate.

Parameters:
- config (Union[Path, str]): Path to axolotl config YAML file. Default: Path('examples/').
- kwargs: Additional keyword arguments to override config file values. Default: {}.

do_evaluate

cli.evaluate.do_evaluate(cfg, cli_args)

Evaluates a transformers model by first loading the dataset(s) specified in the axolotl config, and then calling axolotl.evaluate.evaluate, which computes evaluation metrics on the given dataset(s) and writes them to disk.
Parses axolotl config, CLI args, and calls do_inference or do_inference_gradio.

Parameters:
- config (Union[Path, str]): Path to axolotl config YAML file. Default: Path('examples/').
- kwargs: Additional keyword arguments to override config file values. Default: {}.

do_inference

cli.inference.do_inference(cfg, cli_args)

Runs inference on the command line in a loop. User input is accepted, a chat template is (optionally) applied, and the model specified in the axolotl config is used to generate completions according to a default generation config.

Parameters:
- cfg (DictDefault): Dictionary mapping axolotl config keys to values. Required.
- cli_args (InferenceCliArgs): Inference-specific CLI arguments. Required.

do_inference_gradio

cli.inference.do_inference_gradio(cfg, cli_args)

Runs inference in a Gradio interface. User input is accepted, a chat template is (optionally) applied, and the model specified in the axolotl config is used to generate completions according to a default generation config.

Parameters:
- cfg (DictDefault): Dictionary mapping axolotl config keys to values. Required.
- cli_args (InferenceCliArgs): Inference-specific CLI arguments. Required.

get_multi_line_input

cli.inference.get_multi_line_input()

Gets multi-line input from the terminal.

Returns:
- str: Possibly multi-line, possibly empty stdin input as a string.
Parses axolotl config, CLI args, and calls do_merge_lora. Note that various config values will be overwritten to allow the LoRA merge logic to work as expected (load_in_8bit=False, load_in_4bit=False, flash_attention=False, etc.).

Parameters:
- config (Union[Path, str]): Path to axolotl config YAML file. Default: Path('examples/').
- kwargs: Additional keyword arguments to override config file values. Default: {}.

Raises:
- ValueError: If the target directory for the LoRA-merged model does not exist.

do_merge_lora

cli.merge_lora.do_merge_lora(cfg)

Calls transformers’ merge_and_unload on the model given in the axolotl config along with the LoRA adapters to combine them into a single base model.
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if SHARDED_STATE_DICT was used for the model. Weights will be saved to {output_path}/model.safetensors if safe_serialization, else to {output_path}/pytorch_model.bin.

Note: this is a CPU-bound process.

Parameters:
- checkpoint_dir (str): The directory containing the FSDP checkpoints (can be either the model or optimizer). Required.
- output_path (str): The path to save the merged checkpoint. Required.
- safe_serialization (bool, optional): Whether to save the merged weights with safetensors (recommended). Default: False.
- remove_checkpoint_dir (bool, optional): Whether to remove the checkpoint directory after merging. Default: False.

Raises:
- ValueError: If torch version < 2.3.0, or if checkpoint_dir does not exist.
Recursively generates all possible configurations by applying sweeps to the base config.

Parameters:
- base_config (dict): The original configuration dictionary. Required.
- sweeps_config (dict): Dictionary where keys are parameters and values are either lists of values to sweep independently or, for paired values, a list of dicts under the '_' key. Required.
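The real implementation is recursive; the following non-recursive sketch produces the same cross-product expansion under the conventions above (function name hypothetical):

import itertools

def generate_sweep_configs_sketch(base_config, sweeps_config):
    paired = sweeps_config.get("_", [{}])  # paired values sweep together
    keys = [k for k in sweeps_config if k != "_"]
    value_lists = [sweeps_config[k] for k in keys]
    configs = []
    for combo in itertools.product(*value_lists):
        for pair in paired:
            cfg = dict(base_config)             # copy the base config
            cfg.update(dict(zip(keys, combo)))  # apply independent sweeps
            cfg.update(pair)                    # apply one paired combination
            configs.append(cfg)
    return configs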
Parses axolotl config, CLI args, and calls do_train.

Parameters:
- config (Union[Path, str]): Path to axolotl config YAML file. Default: Path('examples/').
- kwargs: Additional keyword arguments to override config file values. Default: {}.

do_train

cli.train.do_train(cfg, cli_args)

Trains a transformers model by first loading the dataset(s) specified in the axolotl config, and then calling axolotl.train.train. Also runs the plugin manager’s post_train_unload once training completes.
Loads one or more training or evaluation datasets for RL training using paired preference data, calling axolotl.utils.data.rl.load_prepare_preference_datasets. Optionally, logs out debug information.

Parameters:
- cfg (DictDefault): Dictionary mapping axolotl config keys to values. Required.
- cli_args (Union[PreprocessCliArgs, TrainerCliArgs]): Command-specific CLI arguments. Required.

Returns:
- TrainDatasetMeta: Dataclass with fields for training and evaluation datasets and the computed total number of steps.
Overwrite the push_to_hub method in order to force-add the tags when pushing the model on the Hub. Please refer to ~transformers.Trainer.push_to_hub for more details.
Iterable dataset that returns constant-length chunks of tokens from a stream of text files.

Args:
tokenizer (Tokenizer): The processor used for processing the data.
dataset (dataset.Dataset): Dataset with text files.
seq_length (int): Length of token sequences to return.

Dataset that returns tokenized prompts from a stream of text files.

Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
dataset (dataset.Dataset): Dataset with text files.
process_count (int): Number of processes to use for tokenizing.
keep_in_memory (bool): Whether to keep the tokenized dataset in memory.
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl. Plugins can be used to integrate third-party models, modify the training process, or add new features.

To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods, as sketched below.
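For example, a minimal hypothetical plugin that only logs after the model is loaded could look like this (import path assumed from the integrations.base module documented below):

from axolotl.integrations.base import BasePlugin

class LoggingPlugin(BasePlugin):
    # Hypothetical example: log once the model has been loaded.
    def post_model_load(self, cfg, model):
        print(f"Loaded model of type {type(model).__name__}")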
The PluginManager class is responsible for loading and managing plugins.

BaseOptimizerFactory

integrations.base.BaseOptimizerFactory()

Base class for factories to create custom optimizers.

BasePlugin

integrations.base.BasePlugin(self)

Base class for all plugins. Defines the interface for plugin methods.

Attributes:
None

Methods:
register(cfg): Registers the plugin with the given configuration.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_load(cfg, model): Performs actions after the model is loaded.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.

Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.

Returns:
object: The created learning rate scheduler.

Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.

Returns:
None

pre_model_load

integrations.base.BasePlugin.pre_model_load(cfg)

Performs actions before the model is loaded.

Parameters:
cfg (dict): The configuration for the plugin.

Returns:
None

register

integrations.base.BasePlugin.register(cfg)

Registers the plugin with the given configuration.

Parameters:
cfg (dict): The configuration for the plugin.

Returns:
None

PluginManager

integrations.base.PluginManager()

The PluginManager class is responsible for loading and managing plugins. It should be a singleton so it can be accessed from anywhere in the codebase.

Attributes:
plugins (List[BasePlugin]): A list of loaded plugins.

Methods:
get_instance(): Static method to get the singleton instance of PluginManager.
register(plugin_name: str): Registers a new plugin by its name.
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.

Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.

Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.

Returns:
object: The created learning rate scheduler, or None if none was found.

The plugin name should be in the format “module_name.class_name”. This function splits the plugin name into module and class, imports the module, retrieves the class from the module, and creates an instance of the class.

Parameters:
plugin_name (str): The name of the plugin to be loaded, in the format “module_name.class_name”.

Returns:
BasePlugin: An instance of the loaded plugin.

Raises:
ImportError: If the plugin module cannot be imported.
Fast NF4 dequantization using bitsandbytes CUDA kernels.

Performs efficient dequantization of weights from NF4 format using bitsandbytes’ optimized CUDA implementations. Supports both the legacy list and new QuantState formats.

Parameters:
- W (torch.Tensor): Quantized weight tensor to dequantize. Required.
- quant_state (QuantState | list | None): Quantization state containing metadata needed for dequantization. Can be either a QuantState object or the legacy list format. If None, returns W unchanged. Default: None.
- out (torch.Tensor | None): Optional output tensor for storing dequantized results. Must match expected shape and dtype if provided. Default: None.

Returns:
- torch.Tensor: Dequantized tensor in the specified dtype (fp16 or bf16). Will be transposed if the input W was transposed.

Raises:
- AssertionError: If the provided output tensor doesn’t match the expected shape / dtype.

Note:
Uses CUDA streams for better performance when available in newer bitsandbytes versions (>0.43.3).
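As a hedged usage sketch, a round trip through bitsandbytes’ NF4 quantization (quantize_nf4 and dequantize_nf4 from bitsandbytes.functional) looks roughly like:

import torch
import bitsandbytes.functional as bnbF

# Quantize an fp16 weight to NF4, then dequantize it back (requires a CUDA device).
W = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
W_nf4, quant_state = bnbF.quantize_nf4(W)
W_restored = bnbF.dequantize_nf4(W_nf4, quant_state)  # approximately W, in fp16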
Mllama flash cross-attention module. This module inherits from MllamaTextCrossAttention and implements the forward pass using Flash Attention for improved performance.

Mllama flash self-attention module. This module inherits from MllamaTextSelfAttention and implements the forward pass using Flash Attention for improved performance.

Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments.
Applies optimized Triton kernel patches to a PEFT model.

Patches a PEFT model with optimized implementations for MLP and attention computations. The optimizations include custom Triton kernels for activation functions and specialized autograd functions for LoRA computations.

Parameters:
- model (PeftModelForCausalLM): A PEFT model to be patched with optimized kernels. Required.
- cfg (DictDefault): Dictionary mapping axolotl config keys to values. Required.

Returns:
- PeftModelForCausalLM: The patched model with optimized kernels.

Raises:
- TypeError: If the provided model is not a PeftModelForCausalLM.
- NotImplementedError: If the model type is not supported.
- AssertionError: If multiple adapters are active (currently unsupported).

Note:
The optimizations require LoRA adapters with no dropout and no bias terms. The function will skip patching if these conditions aren’t met.
Get the appropriate attention class by inspecting the model config. Uses dynamic import to support any model architecture that follows the standard transformers naming convention.

Parameters:
- cfg (DictDefault): Dictionary mapping axolotl config keys to values. Required.

Returns:
- Type[nn.Module]: The appropriate attention class for the model.

Raises:
- ValueError: If base_model is not specified or the attention class cannot be imported.
- ImportError: If the model module or attention class doesn’t exist.
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
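The standard implementation, as found in transformers’ modeling code, uses expand and reshape rather than an actual repeat_interleave call:

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seqlen, head_dim)
    #   -> (batch, num_key_value_heads * n_rep, seqlen, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)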
PEFT usually casts the layer norms to float32 for training stability, so the input hidden states may be silently cast to float32. We therefore need to cast them back to float16 / bfloat16 to make sure everything works as expected. This might slow down training and inference, so it is recommended not to cast the LayerNorms!

Parameters:
- query (torch.Tensor): Input query states to be passed to the Flash Attention API. Required.
- key (torch.Tensor): Input key states to be passed to the Flash Attention API. Required.
- value (torch.Tensor): Input value states to be passed to the Flash Attention API. Required.
- target_dtype (torch.dtype, optional): The dtype to convert the attention tensors to. Conversion can be skipped by not providing a target dtype. Default: None.
- preferred_dtype (torch.dtype, optional): The preferred dtype to convert the attention tensors to, regardless of the target dtype.
Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len]. This expansion handles packed sequences: positions share the same attention-mask integer value when they may attend to each other within the same sequence. The mask is also transformed to lower-triangular form to prevent attending to future positions.
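A small sketch of this expansion under the stated convention (integer sample ids in attention_mask, 0 assumed to mean padding; helper name hypothetical):

import torch

def expand_packed_mask_sketch(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: [bsz, seq_len] of sample ids, e.g. [[1, 1, 2, 2, 0]]
    bsz, seq_len = attention_mask.shape
    # Positions may attend to each other only within the same packed sample.
    same_sample = attention_mask[:, :, None] == attention_mask[:, None, :]
    not_padding = (attention_mask != 0)[:, :, None]
    # Lower-triangular form prevents attending to future positions.
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask.device)
    )
    mask = same_sample & not_padding & causal  # [bsz, tgt, src]
    return mask[:, None, :, :]                 # [bsz, 1, tgt_seq_len, src_seq_len]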
Prompt strategy for finetuning Llama2 chat models; see https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for a reference implementation.

This implementation is based on the Vicuna PR and the fastchat repo; see also: https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847

Use dataset type “llama2_chat” in config.yml to use this prompt style.

The data should look like

{'conversations': [{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am Vicuna"}, ...]}

in a jsonl file. The first message should be from the human, the second from gpt. For a custom system message, the first “from” can be “system” (followed by alternating “human” and “gpt” turns).

Important: Don’t use “special_tokens:” in your config.yml if you are not sure what you are doing!
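In the config YAML this might look like the following (dataset path hypothetical):

datasets:
  - path: data/vicuna_conversations.jsonl
    type: llama2_chat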
Tokenizing strategy for Llama2 prompts, adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py

Llama2ChatConversation

prompt_strategies.llama2_chat.Llama2ChatConversation(
    self,
    name='llama2',
    system="[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
    roles=('[INST]', '[/INST]'),
    messages=list(),
    offset=0,
)

A class that manages prompt templates and keeps all conversation history. Copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
Module for stepwise datasets, typically including a prompt and reasoning traces, and (optionally) per-step or per-prompt-trace labels for reward modelling.

Tokenizing strategy for supervised stepwise datasets, typically used for CoT reasoning. These datasets should include the following columns:
- prompt: the prompt text
- completions: a list of n completion steps
- labels: a list of n labels indicating the “correctness” of each step
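For illustration, a single hypothetical row of such a dataset might look like:

{"prompt": "What is 2 + 2?", "completions": ["2 + 2 means adding 2 and 2.", "The result is 4."], "labels": [true, true]}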
check_cuda_device

utils.bench.check_cuda_device(default_value)

Wraps a function and returns the default value instead of running the wrapped function if CUDA isn’t available or the device is “auto”.
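A minimal sketch of such a decorator, assuming it simply guards on CUDA availability (the real helper also checks whether the device is “auto”):

import functools
import torch

def check_cuda_device(default_value):
    # Return default_value instead of calling fn when CUDA is unavailable.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if not torch.cuda.is_available():
                return default_value
            return fn(*args, **kwargs)
        return wrapper
    return decorator

@check_cuda_device(0.0)
def gpu_memory_usage_gb():
    return torch.cuda.memory_allocated() / 1024**3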
Calculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity. This is a custom variant that doesn’t re-tokenize the input or re-load the model.

This module provides functionality for selecting chat templates based on user choices. These templates are used for formatting messages in a conversation.
Data collator that will dynamically pad the inputs received, as well as the labels and position_ids.

Parameters:
- tokenizer ([PreTrainedTokenizer] or [PreTrainedTokenizerFast]): The tokenizer used for encoding the data. Required.
- model ([PreTrainedModel]): The model that is being trained. If set and it has the prepare_decoder_input_ids_from_labels method, it is used to prepare the decoder_input_ids. This is useful when using label smoothing to avoid calculating the loss twice. Default: None.
- padding (bool, str or [~utils.PaddingStrategy], optional, defaults to True): Select a strategy to pad the returned sequences (according to the model’s padding side and padding index) among: True or 'longest' (default): pad to the longest sequence in the batch (or no padding if only a single sequence is provided); 'max_length': pad to a maximum length specified with the argument max_length, or to the maximum acceptable input length for the model if that argument is not provided; False or 'do_not_pad': no padding (i.e., can output a batch with sequences of different lengths).
- max_length (int, optional): Maximum length of the returned list and optionally padding length (see above). Default: None.
- pad_to_multiple_of (int, optional): If set, will pad the sequence to a multiple of the provided value. This is especially useful for enabling Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). Default: None.
- label_pad_token_id (int, optional, defaults to -100): The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
- return_tensors (str): The type of Tensor to return. Allowable values are “np”, “pt” and “tf”. Default: 'pt'.
- sequence_parallel_degree (int): The degree of sequence parallelism. Defaults to 1 for no sequence parallelism.
Adjust position IDs for a sliced sequence to maintain proper relative positions. This handles the case where position IDs might not be contiguous due to sample packing.
Context manager that only runs the enclosed block on the main rank.

barrier

utils.distributed.barrier()

Acts as a barrier to wait for all processes. This ensures that all processes reach the barrier before proceeding further.

compute_and_broadcast

utils.distributed.compute_and_broadcast(fn)

Compute a value using the function ‘fn’ only on the specified rank (default is 0). The value is then broadcast to all other ranks.

Args:
- fn (callable): A function that computes the value. This should not have any side effects.
- rank (int, optional): The rank that computes the value. Default is 0.

Run a callable ‘fn’ on all ranks and gather the results on the specified rank.

Args:
- fn (callable): A function that computes the value. This should not have any side effects.
- rank (int, optional): The rank that gathers the values. Default is 0.
- world_size (int, optional): Total number of processes in the current distributed setup.

Returns:
- A list of computed values from all ranks if on the gathering rank, otherwise None.

is_distributed

utils.distributed.is_distributed()

Check if distributed training is initialized.

is_main_process

utils.distributed.is_main_process()

Check if the current process is the main process. If not in distributed mode, always return True.

reduce_and_broadcast

utils.distributed.reduce_and_broadcast(fn1, fn2)

Run a callable ‘fn1’ on all ranks, gather the results, reduce them using ‘fn2’, and then broadcast the reduced result to all ranks.

Args:
- fn1 (callable): A function that computes the value on each rank.
- fn2 (callable): A reduction function that takes a list of values and returns a single value.
- world_size (int, optional): Total number of processes in the current distributed setup.

Returns:
- The reduced and broadcast value.

zero_first

utils.distributed.zero_first(is_main)

Runs the wrapped context so that rank 0 runs first before other ranks.

zero_only

utils.distributed.zero_only()

Context manager that only runs the enclosed block on the main rank.
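A hedged usage sketch of the helpers above (import path assumed from the module name; the computations are stand-ins):

from axolotl.utils.distributed import (
    compute_and_broadcast,
    is_main_process,
    zero_first,
)

# Compute a value once on rank 0 and broadcast it to all other ranks.
total_len = compute_and_broadcast(lambda: 12345)  # stand-in computation

# Let rank 0 run first, e.g. to download and cache a dataset before
# the other ranks read it from the cache.
with zero_first(is_main_process()):
    pass  # prepare or load cached artifacts here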
Freezes all layers of the given model except for the layers that match the given regex patterns. Periods in the patterns are treated as literal periods, not as wildcard characters.

Parameters:
- model (nn.Module): The PyTorch model to be modified.
- regex_patterns (list of str): List of regex patterns matching the layer names to keep unfrozen. Note that you cannot use a dot as a wildcard character in the patterns, since it is reserved for separating layer names. Also, to match the entire layer name, the pattern should start with “^” and end with “$”; otherwise it will match any part of the layer name. The range pattern part is optional, and it is not compiled as a regex pattern, which means you must put “$” before the range pattern if you want to match the entire layer name. E.g., [“^model.embed_tokens.weight$[:32000]”, “layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$”]
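In an axolotl config, such patterns are typically supplied as a list (the key name unfrozen_parameters is assumed here):

unfrozen_parameters:
  - ^model.embed_tokens.weight$[:32000]
  - layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$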
Set self.auto_model_loader. Defaults to transformers.AutoModelForCausalLM (set at __init__). When using a multimodal model, self.auto_model_loader should be set according to the type of the model.
ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024). Taniguchi, Shohei; Harada, Keno; Minegishi, Gouki; Oshima, Yuta; Jeong, Seong Cheol; Nagahara, Go; Iiyama, Tomoshi; Suzuki, Masahiro; Iwasawa, Yusuke; Matsuo, Yutaka.
Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer and 0, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.

Parameters:
- optimizer ([~torch.optim.Optimizer]): The optimizer for which to schedule the learning rate. Required.
- num_warmup_steps (int): The number of steps for the warmup phase. Required.
- num_training_steps (int): The total number of training steps. Required.
- num_cycles (float, optional, defaults to 0.5): The number of waves in the cosine schedule (the default is to just decrease from the max value to 0 following a half-cosine).
- last_epoch (int, optional, defaults to -1): The index of the last epoch when resuming training.

Return:
torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
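Usage mirrors transformers’ get_cosine_schedule_with_warmup, which takes the same parameters; a short sketch:

import torch
from transformers import get_cosine_schedule_with_warmup

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
    num_cycles=0.5,
)
for _ in range(1000):
    optimizer.step()   # linear warmup for 100 steps, then half-cosine decay to 0
    scheduler.step()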
Implementation of “Continual Pre-Training of Large Language Models: How to (re)warm your model?” (https://arxiv.org/pdf/2308.04014.pdf). Create a schedule with a learning rate that decreases following the values of the cosine function from the initial lr set in the optimizer down to the minimum lr until num_training_steps * constant_lr_ratio steps, after which it stays constant at that minimum rate; this follows a warmup period during which the lr increases linearly between 0 and the initial lr set in the optimizer.

Parameters:
- optimizer ([~torch.optim.Optimizer]): The optimizer for which to schedule the learning rate. Required.
- num_warmup_steps (int): The number of steps for the warmup phase. Required.
- num_training_steps (int): The total number of training steps. Required.
- constant_lr_ratio (float): The ratio of num_training_steps at which the cosine decrease ends. Required.
- min_lr_ratio (float): The ratio of the maximum learning rate to which the cosine function decays as the minimum learning rate. Required.
- num_cycles (float, optional, defaults to 0.5): The number of waves in the cosine schedule (the default is to just decrease from the max value to 0 following a half-cosine).
- last_epoch (int, optional, defaults to -1): The index of the last epoch when resuming training.

Return:
torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
Handle backwards compatibility between the legacy message field mapping and the new property mapping system.

Previously, the config only supported mapping ‘role’ and ‘content’ fields via dedicated config options:
- message_field_role: Mapped to the role field
- message_field_content: Mapped to the content field

The new system uses message_property_mappings to support arbitrary field mappings:

message_property_mappings:
  role: source_role_field
  content: source_content_field
  additional_field: source_field
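For example, a legacy configuration and its new-style equivalent (source field names are hypothetical, matching the ShareGPT-style “from”/“value” keys used elsewhere in these docs):

# legacy options
message_field_role: from
message_field_content: value

# equivalent new-style mapping
message_property_mappings:
  role: from
  content: value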
Parameters:
- data (dict): Dictionary containing configuration data. Required.

Returns:
- dict: Updated dictionary with message field mappings consolidated.

Raises:
- ValueError: If there are conflicts between legacy and new mappings.
Use the PoSE technique to extend the context length by randomly skipping positions in the context. We only want to skip right before tokens in the split_on_token_ids list. We should attempt to randomly distribute the skips, but we don’t need the final position_ids to reach the full context_len. There may be multiple turns in the context, so we want to make sure we take into account the maximum possible number of skips remaining in each sample.

add_position_ids

utils.trainer.add_position_ids(sample)

Handles both single-example and batched data:
- single example: sample['input_ids'] is a list[int]
- batched data: sample['input_ids'] is a list[list[int]]
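A minimal sketch of the two branches described above (ignoring the packing-specific position adjustments; function name hypothetical):

def add_position_ids_sketch(sample):
    input_ids = sample["input_ids"]
    if input_ids and isinstance(input_ids[0], list):
        # batched data: list[list[int]]
        sample["position_ids"] = [list(range(len(ids))) for ids in input_ids]
    else:
        # single example: list[int]
        sample["position_ids"] = list(range(len(input_ids)))
    return sample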
@@ -424,7 +430,11 @@ ul.task-list li input[type="checkbox"] {
Memory Consumption with Batch Size: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption.
Gradient Accumulation: With gradient accumulation, you’re effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you’re only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch.
Example 1: Micro batch size: 3, Gradient accumulation steps: 2, Number of GPUs: 3. Total batch size = 3 * 2 * 3 = 18.
Train with long sequences split across multiple GPUs.

Sequence Parallelism

Sequence parallelism is a technique that splits sequences across multiple GPUs, allowing you to train with very long sequences that wouldn’t fit on a single GPU. Each GPU processes a different portion of the sequence, and the results are aggregated through a ring communication pattern.
When to Use Sequence Parallelism

Use sequence parallelism when:

You need to train with sequence lengths that don’t fit into a single GPU’s memory
You have multiple GPUs available
You’re experiencing OOM (Out Of Memory) errors with long sequences

Configuration

To enable sequence parallelism, add the following to your configuration file:

# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4  # Split sequences across 4 GPUs

The sequence_parallel_degree should be a divisor of the total number of GPUs. For example:

With 8 GPUs, valid values would be 2, 4, or 8
With 4 GPUs, valid values would be 2 or 4
Implementation Details

When sequence parallelism is enabled:

Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
The trainer uses special ring communication patterns for attention operations

Requirements

To use sequence parallelism, you need:

Multiple GPUs (at least 2)
The ring-flash-attn package. Install with:

pip install axolotl[ring-flash-attn] (preferred)
pip install "ring-flash-attn>=0.1.4"
Limitations

Flash attention must be enabled for this to work (flash_attention: true in the config YAML)
May have a small performance overhead due to communication between GPUs

Example

# Example config with sequence parallelism
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192
sequence_parallel_degree: 2  # Split each sequence into 2 parts
flash_attention: true  # Required with sequence parallelism
...

This will train the Llama 3 8B model with an 8K context length, with each sequence split into 2 subsequences of length 4096 across 2 GPUs.
Sample Packing with Sequence Parallelism

Sequence parallelism is compatible with Axolotl’s sample packing functionality. When using both features together:

Samples are first packed together
The packed sequences are then divided across GPUs in the sequence parallel group
Position IDs are automatically adjusted to maintain proper relative positions

Effect on Batch Size

When using sequence parallelism, your effective global batch size is divided by the sequence_parallel_degree. This happens because:

Each group of sequence_parallel_degree GPUs works on the same batch (just different parts of each sequence)
The number of batches processed per step decreases

For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and sequence_parallel_degree=4: only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU micro_batch_size is 2, the global batch size decreases from 16 to 4