Compare commits

..

1 Commits

Author SHA1 Message Date
Dan Saunders
02efd7e83d quick formatting fix for LoRA optims doc 2025-02-19 14:17:20 +00:00
11 changed files with 23 additions and 165 deletions

View File

@@ -407,10 +407,7 @@ save_total_limit: # Checkpoints saved at a time
max_steps: max_steps:
# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time. # bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second: # Optional[bool] include_tokens_per_second:
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
auto_find_batch_size: # Optional[bool]
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128

View File

@@ -13,12 +13,12 @@ liger-kernel==0.5.2
packaging==23.2 packaging==23.2
peft==0.14.0 peft==0.14.0
transformers==4.49.0 transformers==4.48.3
tokenizers>=0.21.0 tokenizers>=0.21.0
accelerate==1.3.0 accelerate==1.3.0
datasets==3.2.0 datasets==3.2.0
deepspeed==0.16.1 deepspeed==0.16.1
trl==0.15.1 trl==0.15.0
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer

View File

@@ -831,9 +831,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if "max_length" in kwargs: if "max_length" in kwargs:
kwargs.pop("max_length") kwargs.pop("max_length")
elif use_batch_sampler_collator: elif use_batch_sampler_collator:
if self.cfg.flex_attention is True: if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq collator = V2BatchSamplerDataCollatorForSeq2Seq
elif ( elif (
self.cfg.model_config_type in ["llama"] self.cfg.model_config_type in ["llama"]

View File

@@ -78,6 +78,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
if is_peft_model(unwrapped_model): if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter() unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict() state_dict = unwrapped_model.state_dict()
unwrapped_model.unmerge_adapter()
# Remove base_model and base_layer prefixes # Remove base_model and base_layer prefixes
state_dict = { state_dict = {
k.removeprefix("base_model.model.") k.removeprefix("base_model.model.")
@@ -99,10 +100,8 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
} }
else: else:
state_dict = unwrapped_model.state_dict() state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process: if self.accelerator.is_main_process:
llm_model = ( llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model self.llm.llm_engine.model_executor.driver_worker.model_runner.model
) )
llm_model.load_weights(state_dict.items()) llm_model.load_weights(state_dict.items())
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()

View File

@@ -127,8 +127,6 @@ class ReLoRACallback(TrainerCallback):
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
**_kwargs, **_kwargs,
): ):
if not optimizer:
optimizer = state.optimizer
if state.global_step > 0 and state.global_step % self.relora_steps == 0: if state.global_step > 0 and state.global_step % self.relora_steps == 0:
checkpoint_folder = os.path.join( checkpoint_folder = os.path.join(
args.output_dir, args.output_dir,

View File

@@ -95,103 +95,6 @@ def get_cu_seqlens(attn_mask):
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def get_packed_mask_from_pos_ids(position_ids):
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
device = position_ids.device
results = []
for i, row in enumerate(position_ids):
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()
# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
torch.nonzero(seq_starts).unbind(dim=1)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths
doc_mask = torch.ones(len(row), dtype=torch.int32, device=device)
for i, seq_len in enumerate(seq_lengths):
start_id = start_indices[i]
doc_mask[start_id : start_id + seq_len] = (
(i+1) * doc_mask[start_id : start_id + seq_len]
)
if padding_length:
doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]
results.append(doc_mask)
return torch.stack(results)
def get_seqlens_from_pos_ids(position_ids):
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
max_seq_len = position_ids.shape[1]
device = position_ids.device
results = []
totalseqlens = []
for row in position_ids:
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()
# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
torch.nonzero(seq_starts).unbind(dim=1)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths
if padding_length:
seq_lengths = torch.cat(
[
seq_lengths,
torch.tensor(
[len(row) - torch.sum(seq_lengths)],
dtype=torch.int32,
device=device,
),
]
)
results.append(seq_lengths)
totalseqlens.append(len(adjusted_row))
return results, torch.tensor(totalseqlens, dtype=torch.int32, device=device)
def get_cu_seqlens_from_pos_ids(position_ids): def get_cu_seqlens_from_pos_ids(position_ids):
"""generate a cumulative sequence length mask for flash attention using pos ids""" """generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1: if len(position_ids.shape) == 1:
@@ -273,10 +176,7 @@ def mask_2d_to_4d(
when they attend to each other within that sequence. when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking. This expansion transforms the mask to lower triangular form to prevent future peeking.
""" """
bsz, src_len = mask.size()
if len(mask.size()) == 4:
return mask
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
tgt_len = tgt_len if tgt_len is not None else src_len tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2) mask = mask.unsqueeze(1).unsqueeze(2)

View File

@@ -272,7 +272,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
dict(zip(feature_names, row)) dict(zip(feature_names, row))
) )
for key, val in tokenized_prompt.items(): for key, val in tokenized_prompt.items():
res[key].append(val) for i in range(0, len(val), self.sequence_len):
res[key].append(val[i : i + self.sequence_len])
# If there are no examples left, return an empty dictionary # If there are no examples left, return an empty dictionary
if not res: if not res:

View File

@@ -342,7 +342,6 @@ class LoraConfig(BaseModel):
peft_use_dora: Optional[bool] = None peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None peft_layer_replication: Optional[List[Tuple[int, int]]] = None
peft_init_lora_weights: Optional[Union[bool, str]] = None
qlora_sharded_model_loading: Optional[bool] = Field( qlora_sharded_model_loading: Optional[bool] = Field(
default=False, default=False,
@@ -823,7 +822,6 @@ class AxolotlInputConfig(
xformers_attention: Optional[bool] = None xformers_attention: Optional[bool] = None
sdp_attention: Optional[bool] = None sdp_attention: Optional[bool] = None
s2_attention: Optional[bool] = None s2_attention: Optional[bool] = None
flex_attention: Optional[bool] = None
flash_attention: Optional[bool] = None flash_attention: Optional[bool] = None
flash_attn_cross_entropy: Optional[bool] = None flash_attn_cross_entropy: Optional[bool] = None
flash_attn_rms_norm: Optional[bool] = None flash_attn_rms_norm: Optional[bool] = None
@@ -1790,26 +1788,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_flex_torch_version(cls, data):
if (data.get("flex_attention") is not None) and (
data.get("flex_attention") is True
):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if version.parse(torch_version) < version.parse("2.5.1"):
raise ValueError(
"Flex attention is not supported on torch version < 2.5.1"
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_torch_compile_auto(cls, data): def check_torch_compile_auto(cls, data):

View File

@@ -172,11 +172,10 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
) )
try: try:
ds_lengths = get_dataset_lengths(dataset, from_arrow=True) min_input_len = np.min(get_dataset_lengths(dataset))
min_input_len = np.min(ds_lengths) LOG.debug(f"min_input_len: {min_input_len}")
LOG.info(f"min_input_len: {min_input_len}") max_input_len = np.max(get_dataset_lengths(dataset))
max_input_len = np.max(ds_lengths) LOG.debug(f"max_input_len: {max_input_len}")
LOG.info(f"max_input_len: {max_input_len}")
except AttributeError: except AttributeError:
pass pass

View File

@@ -403,7 +403,7 @@ class ModelLoader:
if ( if (
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and (self.cfg.flash_attention or self.cfg.flex_attention) and self.cfg.flash_attention
and self.cfg.sample_packing and self.cfg.sample_packing
): ):
if "auto_map" in self.model_config: if "auto_map" in self.model_config:
@@ -707,13 +707,7 @@ class ModelLoader:
""" """
sample packing uses custom FA2 patch sample packing uses custom FA2 patch
""" """
if self.cfg.flash_attention:
if self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
)
elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention: if not self.cfg.sample_packing and self.cfg.s2_attention:
pass pass
self.model_kwargs["attn_implementation"] = "flash_attention_2" self.model_kwargs["attn_implementation"] = "flash_attention_2"
@@ -1119,7 +1113,7 @@ class ModelLoader:
should_convert = ( should_convert = (
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility. # convert them back to fp16/bf16 for flash-attn compatibility.
((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp) ((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
) )
@@ -1327,8 +1321,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
if loftq_bits: if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits) lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq" lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_init_lora_weights:
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
if cfg.peft_use_dora: if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora lora_config_kwargs["use_dora"] = cfg.peft_use_dora
LOG.info("Initializing LoRA weights using dora. This might take longer.") LOG.info("Initializing LoRA weights using dora. This might take longer.")

View File

@@ -4,17 +4,13 @@ helper util to calculate dataset lengths
import numpy as np import numpy as np
def get_dataset_lengths(dataset, from_arrow=False): def get_dataset_lengths(dataset):
if "length" in dataset.column_names: if "length" in dataset.column_names:
lengths = np.array(dataset["length"]) lengths = np.array(dataset["length"])
elif "position_ids" in dataset.column_names: elif "position_ids" in dataset.column_names:
position_ids = dataset["position_ids"] position_ids = dataset["position_ids"]
lengths = np.array([x[-1] + 1 for x in position_ids]) lengths = np.array([x[-1] + 1 for x in position_ids])
else: else:
if from_arrow: input_ids = dataset["input_ids"]
input_ids = dataset.data.column("input_ids") lengths = np.array([len(seq) for seq in input_ids])
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
else:
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
return lengths return lengths