Compare commits

...

81 Commits

Author SHA1 Message Date
Sung Ching Liu
328bb0466b Merge branch 'main' into flx_attn_support 2025-02-21 11:27:25 -05:00
Sunny Liu
e792b54bab remove unnecessary components 2025-02-21 11:23:21 -05:00
NanoCode012
bf842730a5 fix(doc): add missing auto_find_batch_size (#2339) [skip ci] 2025-02-21 11:56:38 +07:00
Wing Lian
1db6ad60a7 support for passing init_lora_weights to lora_config (#2352) 2025-02-20 22:56:34 -05:00
salman
29b366b2e1 Bumping 0.15.1 TRL version for GRPO+PEFT fix (#2344)
* bumping TRL version

* apply upstream fixes to our custom fix

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-02-20 22:56:04 -05:00
NanoCode012
b53a41372f feat: update transformers version to 4.49.0 (#2340) 2025-02-20 21:12:06 -05:00
Wing Lian
02f45e94be calculate sample length fixes and SFT splitting fixes (#2351)
* fix chat template splitting long samples across multiple rows

* make the preprocessing faster
2025-02-20 14:29:58 -05:00
bursteratom
82d04ea060 test v2batch w/ flex attn 2025-02-13 00:11:45 -05:00
Sung Ching Liu
0ef1f011fe Merge branch 'main' into flx_attn_support 2025-02-11 23:31:56 -05:00
Sunny Liu
c0a1d205c7 packed doc mask starts at 1, 0 means masked out 2025-02-07 14:44:52 -05:00
Sunny Liu
d0e739da24 attempt at getting around bf16 error 2025-02-04 21:57:21 -05:00
Sunny Liu
3f6be519d5 stack 2025-02-04 21:25:13 -05:00
Sunny Liu
adcbc7459b misc 2025-02-04 21:17:50 -05:00
Sunny Liu
470ba65c44 make doc mask instead of the whole block mask in collator 2025-02-04 20:27:39 -05:00
Sunny Liu
8e1adc154d stuff 2025-02-02 20:36:14 -05:00
Sunny Liu
e5b36900e4 misc 2025-02-02 20:32:03 -05:00
Sunny Liu
9f6c89b12b undo my stupidity 2025-02-02 20:25:53 -05:00
Sunny Liu
b0871c8d3b attempt - mask padding 2025-02-02 20:18:49 -05:00
bursteratom
d3ea379a23 figure out slight diff from flash result 2025-02-02 01:45:54 -05:00
bursteratom
0ebab63309 test 2025-02-02 01:27:15 -05:00
bursteratom
e98581f6f5 BLOCK SIZE 2025-02-02 01:22:23 -05:00
bursteratom
b832b11c8f stuff 2025-02-02 00:51:43 -05:00
bursteratom
b692d394b1 more test 2025-02-02 00:48:57 -05:00
bursteratom
2319e5276d more test 2025-02-02 00:48:15 -05:00
bursteratom
9a43a0925d more test 2025-02-02 00:45:30 -05:00
bursteratom
10de67e8ea more test 2025-02-02 00:43:41 -05:00
bursteratom
fa7355404c test 2025-02-02 00:38:35 -05:00
bursteratom
907424a2e8 stuff 2025-02-02 00:29:09 -05:00
Sunny Liu
3f4fd3c1eb remove padding self attention 2025-02-01 22:47:10 -05:00
Sunny Liu
48c3c47071 vanills mask 2025-02-01 14:23:37 -05:00
Sunny Liu
3ed9c117fb try vanilla mask 2025-02-01 14:09:13 -05:00
Sunny Liu
84960003ed reset llama_patch_multipack.py 2025-01-30 14:40:18 -05:00
Sunny Liu
93a268e43d --no-verify
fixes silly mistake
2025-01-30 14:08:26 -05:00
Sunny Liu
065f6d477e flex batching WIP 2025-01-30 14:04:59 -05:00
Sunny Liu
96ad741cd5 flex batching WIP 2025-01-30 12:35:25 -05:00
bursteratom
ba88bc7840 wip flex block mask creation 2025-01-29 00:25:25 -05:00
Sung Ching Liu
b31796a681 Merge branch 'main' into flx_attn_support 2025-01-28 14:20:43 -05:00
Sunny Liu
5ca57cb55a undo bool conversion 2025-01-23 17:56:13 -05:00
Sunny Liu
0149de7fb0 mask to bool 2025-01-23 15:30:08 -05:00
Sunny Liu
8c34c65181 dummy 2025-01-23 14:56:26 -05:00
Sunny Liu
555aa5772a skip mask conversion if already 4d 2025-01-23 14:01:53 -05:00
Sunny Liu
e8b2789086 revert mask expand 2025-01-23 11:20:38 -05:00
Sunny Liu
85752cdfc9 mask expansion 2025-01-22 21:33:38 -05:00
Sunny Liu
f2f23c8041 mask expansion 2025-01-22 21:31:42 -05:00
Sunny Liu
8b3eec7f6e mask expansion 2025-01-22 21:29:52 -05:00
Sunny Liu
bb9bea3110 mask expansion 2025-01-22 21:27:25 -05:00
Sunny Liu
0dd18a3681 llama sdpa patching WIP - static class function import 2025-01-22 21:10:05 -05:00
Sunny Liu
152e988d3c llama sdpa patching WIP - static class function import 2025-01-22 21:02:26 -05:00
Sunny Liu
27532825a9 llama sdpa patching WIP - static class function import 2025-01-22 21:00:34 -05:00
Sunny Liu
06f83a54a5 llama sdpa patching WIP - static class function import 2025-01-22 20:45:44 -05:00
Sunny Liu
d7b133dc1f llama sdpa patching WIP - static class function import 2025-01-22 20:33:13 -05:00
Sunny Liu
f3bec17917 llama sdpa patching WIP - static class function import 2025-01-22 20:25:26 -05:00
Sunny Liu
b7deb5241c llama sdpa patching WIP 2025-01-22 20:16:27 -05:00
Sunny Liu
cee310dcfa llama sdpa patching WIP 2025-01-22 20:15:23 -05:00
Sunny Liu
d1be6e228d llama sdpa patching WIP 2025-01-22 20:14:20 -05:00
Sunny Liu
5f9f77f384 llama patch 2025-01-22 11:29:28 -05:00
bursteratom
b2a34380b3 sample packing doc mask creation WIP 2025-01-21 09:18:38 -05:00
Sunny Liu
80bfc50d1f get seqlens from position ids for foc masking 2025-01-17 17:22:04 -05:00
Sunny Liu
a5360c172c llama hijacking 2025-01-17 15:54:03 -05:00
Sunny Liu
013a9b73fc fix transformers version for testing 2025-01-16 15:32:57 -05:00
Sunny
aad62428e0 not sure if this is necessary actually 2025-01-16 15:08:34 -05:00
Sunny
a6f2c5d583 flex sample packing WIP 2025-01-15 21:12:33 -05:00
Sunny
dbcd11e533 revert seq len in multipack sampler 2025-01-14 11:45:35 -05:00
Sunny
c06a6be915 flex_attn sample packing WIP 2025-01-14 00:22:05 -05:00
bursteratom
d3a0cb5edb transformers version 2025-01-13 10:33:00 -05:00
bursteratom
8b47e456b0 revert to transformers 4.47.1 2025-01-13 10:29:27 -05:00
Sunny Liu
2319ac729c Merge branch 'main' into flx_attn_support 2025-01-13 09:42:58 -05:00
Sunny
f99cae0e7b llama test 2025-01-12 17:30:19 -05:00
Wing Lian
888cd9407f use 2.5.1 docker images as latest tag as it seems stable (#2198) 2025-01-12 13:34:17 -05:00
Wing Lian
bd62d6e10a rename liger test so it properly runs in ci (#2246) 2025-01-12 13:34:17 -05:00
NanoCode012
5eae134110 feat: add support for data_files in pretraining (#2238) 2025-01-12 13:34:17 -05:00
Wing Lian
b7d27bdfa4 update upstream HF deps (#2239)
* bump axolotl contribs for upstream main conflicts:

* bump datasets, tokenizer, trl

* remove log workarounds in trl

* bump lm-eval

* remove unsloth_ import from critical path

* remove llama fa2 from conftest

* unsloth breaks with latest upstream
2025-01-12 13:34:17 -05:00
Vincenzo di Cicco
da97a21bdc Use SequentialSampler if curriculum_sampling is enabled with sample_packing (#2235) 2025-01-12 13:34:17 -05:00
Wing Lian
e0d4b88598 update modal version for ci (#2242) 2025-01-12 13:34:17 -05:00
NanoCode012
fac059a209 fix: mistral nemo does not recognize token_type_ids in forward (#2233) 2025-01-12 13:34:17 -05:00
Wing Lian
9c9ac1cf0b add hf cache caching for GHA (#2247)
* add hf cache caching for GHA

* use modal volume to cache hf data

* make sure to update the cache as we add new fixtures in conftest
2025-01-12 13:34:17 -05:00
Wing Lian
2346f21b2b Merge group queue (#2248)
* add support for merge groups

* also lint merge groups
2025-01-12 13:34:17 -05:00
salman
0b47281f51 Fixing OSX installation (#2231)
* bumping version, removing non-osx compatible deps

* updating pylintrc

* fixing linters

* reverting changes
2025-01-12 13:34:17 -05:00
Sunny
543daaf46f llama test 2025-01-09 16:08:24 -05:00
Sunny
bcd9ad44e0 flex attention support 2025-01-06 19:54:11 -05:00
bursteratom
61ad375bf4 config validation for flex attention 2025-01-05 23:27:49 -05:00
11 changed files with 165 additions and 23 deletions

View File

@@ -407,7 +407,10 @@ save_total_limit: # Checkpoints saved at a time
max_steps:
# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second:
include_tokens_per_second: # Optional[bool]
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
auto_find_batch_size: # Optional[bool]
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128

View File

@@ -13,12 +13,12 @@ liger-kernel==0.5.2
packaging==23.2
peft==0.14.0
transformers==4.48.3
transformers==4.49.0
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.15.0
trl==0.15.1
optimum==1.16.2
hf_transfer

View File

@@ -831,7 +831,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
if self.cfg.flex_attention is True:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"]

View File

@@ -78,7 +78,6 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
unwrapped_model.unmerge_adapter()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix("base_model.model.")
@@ -100,8 +99,10 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
)
llm_model.load_weights(state_dict.items())
if self.accelerator.is_main_process:
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
)
llm_model.load_weights(state_dict.items())
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()

View File

@@ -127,6 +127,8 @@ class ReLoRACallback(TrainerCallback):
optimizer: torch.optim.Optimizer,
**_kwargs,
):
if not optimizer:
optimizer = state.optimizer
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
checkpoint_folder = os.path.join(
args.output_dir,

View File

@@ -95,6 +95,103 @@ def get_cu_seqlens(attn_mask):
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def get_packed_mask_from_pos_ids(position_ids):
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
device = position_ids.device
results = []
for i, row in enumerate(position_ids):
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()
# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
torch.nonzero(seq_starts).unbind(dim=1)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths
doc_mask = torch.ones(len(row), dtype=torch.int32, device=device)
for i, seq_len in enumerate(seq_lengths):
start_id = start_indices[i]
doc_mask[start_id : start_id + seq_len] = (
(i+1) * doc_mask[start_id : start_id + seq_len]
)
if padding_length:
doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]
results.append(doc_mask)
return torch.stack(results)
def get_seqlens_from_pos_ids(position_ids):
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
max_seq_len = position_ids.shape[1]
device = position_ids.device
results = []
totalseqlens = []
for row in position_ids:
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()
# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
torch.nonzero(seq_starts).unbind(dim=1)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths
if padding_length:
seq_lengths = torch.cat(
[
seq_lengths,
torch.tensor(
[len(row) - torch.sum(seq_lengths)],
dtype=torch.int32,
device=device,
),
]
)
results.append(seq_lengths)
totalseqlens.append(len(adjusted_row))
return results, torch.tensor(totalseqlens, dtype=torch.int32, device=device)
def get_cu_seqlens_from_pos_ids(position_ids):
"""generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1:
@@ -176,7 +273,10 @@ def mask_2d_to_4d(
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
if len(mask.size()) == 4:
return mask
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)

View File

@@ -272,8 +272,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
dict(zip(feature_names, row))
)
for key, val in tokenized_prompt.items():
for i in range(0, len(val), self.sequence_len):
res[key].append(val[i : i + self.sequence_len])
res[key].append(val)
# If there are no examples left, return an empty dictionary
if not res:

View File

@@ -342,6 +342,7 @@ class LoraConfig(BaseModel):
peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
peft_init_lora_weights: Optional[Union[bool, str]] = None
qlora_sharded_model_loading: Optional[bool] = Field(
default=False,
@@ -822,6 +823,7 @@ class AxolotlInputConfig(
xformers_attention: Optional[bool] = None
sdp_attention: Optional[bool] = None
s2_attention: Optional[bool] = None
flex_attention: Optional[bool] = None
flash_attention: Optional[bool] = None
flash_attn_cross_entropy: Optional[bool] = None
flash_attn_rms_norm: Optional[bool] = None
@@ -1788,6 +1790,26 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_flex_torch_version(cls, data):
if (data.get("flex_attention") is not None) and (
data.get("flex_attention") is True
):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if version.parse(torch_version) < version.parse("2.5.1"):
raise ValueError(
"Flex attention is not supported on torch version < 2.5.1"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_auto(cls, data):

View File

@@ -172,10 +172,11 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
)
try:
min_input_len = np.min(get_dataset_lengths(dataset))
LOG.debug(f"min_input_len: {min_input_len}")
max_input_len = np.max(get_dataset_lengths(dataset))
LOG.debug(f"max_input_len: {max_input_len}")
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
min_input_len = np.min(ds_lengths)
LOG.info(f"min_input_len: {min_input_len}")
max_input_len = np.max(ds_lengths)
LOG.info(f"max_input_len: {max_input_len}")
except AttributeError:
pass

View File

@@ -403,7 +403,7 @@ class ModelLoader:
if (
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and self.cfg.flash_attention
and (self.cfg.flash_attention or self.cfg.flex_attention)
and self.cfg.sample_packing
):
if "auto_map" in self.model_config:
@@ -707,7 +707,13 @@ class ModelLoader:
"""
sample packing uses custom FA2 patch
"""
if self.cfg.flash_attention:
if self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
)
elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"
@@ -1113,7 +1119,7 @@ class ModelLoader:
should_convert = (
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
)
@@ -1321,6 +1327,8 @@ def load_lora(model, cfg, inference=False, config_only=False):
if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_init_lora_weights:
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
LOG.info("Initializing LoRA weights using dora. This might take longer.")

View File

@@ -4,13 +4,17 @@ helper util to calculate dataset lengths
import numpy as np
def get_dataset_lengths(dataset):
def get_dataset_lengths(dataset, from_arrow=False):
if "length" in dataset.column_names:
lengths = np.array(dataset["length"])
elif "position_ids" in dataset.column_names:
position_ids = dataset["position_ids"]
lengths = np.array([x[-1] + 1 for x in position_ids])
else:
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
if from_arrow:
input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
else:
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
return lengths