Compare commits
1 Commits
flx_attn_s
...
grpo-ref-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a9ebff087c |
@@ -407,10 +407,7 @@ save_total_limit: # Checkpoints saved at a time
|
|||||||
max_steps:
|
max_steps:
|
||||||
|
|
||||||
# bool of whether to include tokens trained per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
|
# bool of whether to include tokens trained per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
|
||||||
include_tokens_per_second: # Optional[bool]
|
include_tokens_per_second:
|
||||||
|
|
||||||
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
|
|
||||||
auto_find_batch_size: # Optional[bool]
|
|
||||||
|
|
||||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ tokenizers>=0.21.0
|
|||||||
accelerate==1.3.0
|
accelerate==1.3.0
|
||||||
datasets==3.2.0
|
datasets==3.2.0
|
||||||
deepspeed==0.16.1
|
deepspeed==0.16.1
|
||||||
trl==0.15.1
|
trl==0.15.0
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
|
|||||||
@@ -831,9 +831,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if "max_length" in kwargs:
|
if "max_length" in kwargs:
|
||||||
kwargs.pop("max_length")
|
kwargs.pop("max_length")
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
if self.cfg.flex_attention is True:
|
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
|
||||||
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
elif (
|
elif (
|
||||||
self.cfg.model_config_type in ["llama"]
|
self.cfg.model_config_type in ["llama"]
|
||||||
|
|||||||
@@ -39,6 +39,15 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
||||||
# pylint: enable=access-member-before-definition
|
# pylint: enable=access-member-before-definition
|
||||||
|
|
||||||
|
# cleanup the ref_model if we have a peft model passed in
|
||||||
|
# TODO remove this after next major trl release
|
||||||
|
if (
|
||||||
|
self.ref_model # pylint: disable=access-member-before-definition
|
||||||
|
and is_peft_model(self.model)
|
||||||
|
):
|
||||||
|
del self.ref_model
|
||||||
|
self.ref_model = None
|
||||||
|
|
||||||
def _enable_gradient_checkpointing(
|
def _enable_gradient_checkpointing(
|
||||||
self, model: PreTrainedModel, args: GRPOConfig
|
self, model: PreTrainedModel, args: GRPOConfig
|
||||||
) -> PreTrainedModel:
|
) -> PreTrainedModel:
|
||||||
@@ -78,6 +87,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
if is_peft_model(unwrapped_model):
|
if is_peft_model(unwrapped_model):
|
||||||
unwrapped_model.merge_adapter()
|
unwrapped_model.merge_adapter()
|
||||||
state_dict = unwrapped_model.state_dict()
|
state_dict = unwrapped_model.state_dict()
|
||||||
|
unwrapped_model.unmerge_adapter()
|
||||||
# Remove base_model and base_layer prefixes
|
# Remove base_model and base_layer prefixes
|
||||||
state_dict = {
|
state_dict = {
|
||||||
k.removeprefix("base_model.model.")
|
k.removeprefix("base_model.model.")
|
||||||
@@ -99,10 +109,8 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
state_dict = unwrapped_model.state_dict()
|
state_dict = unwrapped_model.state_dict()
|
||||||
if self.accelerator.is_main_process:
|
if self.accelerator.is_main_process:
|
||||||
llm_model = (
|
llm_model = (
|
||||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||||
)
|
)
|
||||||
llm_model.load_weights(state_dict.items())
|
llm_model.load_weights(state_dict.items())
|
||||||
if is_peft_model(unwrapped_model):
|
|
||||||
unwrapped_model.unmerge_adapter()
|
|
||||||
|
|||||||
@@ -95,103 +95,6 @@ def get_cu_seqlens(attn_mask):
|
|||||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||||
|
|
||||||
|
|
||||||
def get_packed_mask_from_pos_ids(position_ids):
|
|
||||||
if len(position_ids.shape) == 1:
|
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
|
|
||||||
device = position_ids.device
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for i, row in enumerate(position_ids):
|
|
||||||
# Count the number of consecutive zeros from the right side
|
|
||||||
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
|
|
||||||
|
|
||||||
# Adjust the row to exclude padding
|
|
||||||
adjusted_row = row[:-padding_length] if padding_length else row.clone()
|
|
||||||
|
|
||||||
# Find where the position resets to 0 (indicating a new sequence)
|
|
||||||
seq_starts = torch.cat(
|
|
||||||
[
|
|
||||||
torch.tensor([True], dtype=torch.bool, device=device),
|
|
||||||
adjusted_row[1:] == 0,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Get the indices where the sequence starts
|
|
||||||
start_indices = torch.cat(
|
|
||||||
[
|
|
||||||
torch.nonzero(seq_starts).unbind(dim=1)[0],
|
|
||||||
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Calculate the sequence lengths
|
|
||||||
seq_lengths = start_indices[1:] - start_indices[:-1]
|
|
||||||
# Append the padding length to the sequence lengths
|
|
||||||
doc_mask = torch.ones(len(row), dtype=torch.int32, device=device)
|
|
||||||
for i, seq_len in enumerate(seq_lengths):
|
|
||||||
start_id = start_indices[i]
|
|
||||||
doc_mask[start_id : start_id + seq_len] = (
|
|
||||||
(i+1) * doc_mask[start_id : start_id + seq_len]
|
|
||||||
)
|
|
||||||
if padding_length:
|
|
||||||
doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]
|
|
||||||
|
|
||||||
results.append(doc_mask)
|
|
||||||
|
|
||||||
return torch.stack(results)
|
|
||||||
|
|
||||||
|
|
||||||
def get_seqlens_from_pos_ids(position_ids):
|
|
||||||
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
|
|
||||||
if len(position_ids.shape) == 1:
|
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
max_seq_len = position_ids.shape[1]
|
|
||||||
|
|
||||||
device = position_ids.device
|
|
||||||
results = []
|
|
||||||
totalseqlens = []
|
|
||||||
|
|
||||||
for row in position_ids:
|
|
||||||
# Count the number of consecutive zeros from the right side
|
|
||||||
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
|
|
||||||
|
|
||||||
# Adjust the row to exclude padding
|
|
||||||
adjusted_row = row[:-padding_length] if padding_length else row.clone()
|
|
||||||
|
|
||||||
# Find where the position resets to 0 (indicating a new sequence)
|
|
||||||
seq_starts = torch.cat(
|
|
||||||
[
|
|
||||||
torch.tensor([True], dtype=torch.bool, device=device),
|
|
||||||
adjusted_row[1:] == 0,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Get the indices where the sequence starts
|
|
||||||
start_indices = torch.cat(
|
|
||||||
[
|
|
||||||
torch.nonzero(seq_starts).unbind(dim=1)[0],
|
|
||||||
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Calculate the sequence lengths
|
|
||||||
seq_lengths = start_indices[1:] - start_indices[:-1]
|
|
||||||
# Append the padding length to the sequence lengths
|
|
||||||
if padding_length:
|
|
||||||
seq_lengths = torch.cat(
|
|
||||||
[
|
|
||||||
seq_lengths,
|
|
||||||
torch.tensor(
|
|
||||||
[len(row) - torch.sum(seq_lengths)],
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=device,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
results.append(seq_lengths)
|
|
||||||
totalseqlens.append(len(adjusted_row))
|
|
||||||
|
|
||||||
return results, torch.tensor(totalseqlens, dtype=torch.int32, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||||
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
||||||
if len(position_ids.shape) == 1:
|
if len(position_ids.shape) == 1:
|
||||||
@@ -273,10 +176,7 @@ def mask_2d_to_4d(
|
|||||||
when they attend to each other within that sequence.
|
when they attend to each other within that sequence.
|
||||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||||
"""
|
"""
|
||||||
|
bsz, src_len = mask.size()
|
||||||
if len(mask.size()) == 4:
|
|
||||||
return mask
|
|
||||||
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
|
|
||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||||
|
|
||||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||||
|
|||||||
@@ -342,7 +342,6 @@ class LoraConfig(BaseModel):
|
|||||||
peft_use_dora: Optional[bool] = None
|
peft_use_dora: Optional[bool] = None
|
||||||
peft_use_rslora: Optional[bool] = None
|
peft_use_rslora: Optional[bool] = None
|
||||||
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
||||||
peft_init_lora_weights: Optional[Union[bool, str]] = None
|
|
||||||
|
|
||||||
qlora_sharded_model_loading: Optional[bool] = Field(
|
qlora_sharded_model_loading: Optional[bool] = Field(
|
||||||
default=False,
|
default=False,
|
||||||
@@ -823,7 +822,6 @@ class AxolotlInputConfig(
|
|||||||
xformers_attention: Optional[bool] = None
|
xformers_attention: Optional[bool] = None
|
||||||
sdp_attention: Optional[bool] = None
|
sdp_attention: Optional[bool] = None
|
||||||
s2_attention: Optional[bool] = None
|
s2_attention: Optional[bool] = None
|
||||||
flex_attention: Optional[bool] = None
|
|
||||||
flash_attention: Optional[bool] = None
|
flash_attention: Optional[bool] = None
|
||||||
flash_attn_cross_entropy: Optional[bool] = None
|
flash_attn_cross_entropy: Optional[bool] = None
|
||||||
flash_attn_rms_norm: Optional[bool] = None
|
flash_attn_rms_norm: Optional[bool] = None
|
||||||
@@ -1790,26 +1788,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_flex_torch_version(cls, data):
|
|
||||||
if (data.get("flex_attention") is not None) and (
|
|
||||||
data.get("flex_attention") is True
|
|
||||||
):
|
|
||||||
env_capabilities = data.get("env_capabilities", {})
|
|
||||||
torch_version = env_capabilities.get("torch_version")
|
|
||||||
|
|
||||||
if torch_version is None:
|
|
||||||
import torch
|
|
||||||
|
|
||||||
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
|
||||||
|
|
||||||
if version.parse(torch_version) < version.parse("2.5.1"):
|
|
||||||
raise ValueError(
|
|
||||||
"Flex attention is not supported on torch version < 2.5.1"
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_torch_compile_auto(cls, data):
|
def check_torch_compile_auto(cls, data):
|
||||||
|
|||||||
@@ -403,7 +403,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
and (self.cfg.flash_attention or self.cfg.flex_attention)
|
and self.cfg.flash_attention
|
||||||
and self.cfg.sample_packing
|
and self.cfg.sample_packing
|
||||||
):
|
):
|
||||||
if "auto_map" in self.model_config:
|
if "auto_map" in self.model_config:
|
||||||
@@ -707,13 +707,7 @@ class ModelLoader:
|
|||||||
"""
|
"""
|
||||||
sample packing uses custom FA2 patch
|
sample packing uses custom FA2 patch
|
||||||
"""
|
"""
|
||||||
|
if self.cfg.flash_attention:
|
||||||
if self.cfg.flex_attention:
|
|
||||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
|
||||||
"flex_attention"
|
|
||||||
)
|
|
||||||
elif self.cfg.flash_attention:
|
|
||||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
pass
|
pass
|
||||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
@@ -1119,7 +1113,7 @@ class ModelLoader:
|
|||||||
should_convert = (
|
should_convert = (
|
||||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||||
((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
|
((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
|
||||||
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
|
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1327,8 +1321,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|||||||
if loftq_bits:
|
if loftq_bits:
|
||||||
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
|
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
|
||||||
lora_config_kwargs["init_lora_weights"] = "loftq"
|
lora_config_kwargs["init_lora_weights"] = "loftq"
|
||||||
if cfg.peft_init_lora_weights:
|
|
||||||
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
|
|
||||||
if cfg.peft_use_dora:
|
if cfg.peft_use_dora:
|
||||||
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
|
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
|
||||||
LOG.info("Initializing LoRA weights using dora. This might take longer.")
|
LOG.info("Initializing LoRA weights using dora. This might take longer.")
|
||||||
|
|||||||
Reference in New Issue
Block a user