Compare commits


1 Commit

Author: Sung Ching Liu
SHA1: 90dfcd8c03
Message: Revert "Fix sample packing producing longer sequences than specified by `sequ…"
         This reverts commit 8dfadc2b3c.
Date: 2025-02-19 21:13:25 -05:00
12 changed files with 33 additions and 175 deletions

View File

@@ -407,10 +407,7 @@ save_total_limit: # Checkpoints saved at a time
max_steps:
# bool of whether to include tokens per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second: # Optional[bool]
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
auto_find_batch_size: # Optional[bool]
include_tokens_per_second:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
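
Both options shown here are pass-throughs to the underlying transformers Trainer, so they behave like the corresponding TrainingArguments fields. A minimal sketch, with illustrative values that are not taken from this commit:

# Sketch: the two flags above map onto plain transformers.TrainingArguments fields.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./out",
    include_tokens_per_second=True,  # adds a tokens/sec metric; iterates the dataset once up front
    auto_find_batch_size=True,  # lets the Trainer retry with a smaller batch size on CUDA OOM
)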

View File

@@ -13,12 +13,12 @@ liger-kernel==0.5.2
packaging==23.2
peft==0.14.0
transformers==4.49.0
transformers==4.48.3
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.15.1
trl==0.15.0
optimum==1.16.2
hf_transfer

View File

@@ -831,9 +831,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.flex_attention is True:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"]

View File

@@ -78,6 +78,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
            if is_peft_model(unwrapped_model):
                unwrapped_model.merge_adapter()
                state_dict = unwrapped_model.state_dict()
                unwrapped_model.unmerge_adapter()
                # Remove base_model and base_layer prefixes
                state_dict = {
                    k.removeprefix("base_model.model.")
@@ -99,10 +100,8 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
                }
            else:
                state_dict = unwrapped_model.state_dict()
            if self.accelerator.is_main_process:
                llm_model = (
                    self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                )
                llm_model.load_weights(state_dict.items())
            if is_peft_model(unwrapped_model):
                unwrapped_model.unmerge_adapter()
        if self.accelerator.is_main_process:
            llm_model = (
                self.llm.llm_engine.model_executor.driver_worker.model_runner.model
            )
            llm_model.load_weights(state_dict.items())
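
The restored ordering above follows a merge, snapshot, unmerge pattern: LoRA deltas are folded into the base weights just long enough to take a dense state_dict for the vLLM engine, then unmerged so training continues on the adapter. A rough sketch of that pattern; snapshot_merged_weights, model, and llm_model are placeholder names, and the key filtering is simplified relative to the trainer code above:

# Hedged sketch of the merge -> snapshot -> unmerge weight-sync pattern.
def snapshot_merged_weights(model):
    model.merge_adapter()  # fold LoRA deltas into the base weights
    state_dict = {
        k.removeprefix("base_model.model.").replace(".base_layer", ""): v
        for k, v in model.state_dict().items()
        if "lora_" not in k  # keep only dense weights, drop adapter tensors
    }
    model.unmerge_adapter()  # restore the trainable LoRA view right away
    return state_dict

# llm_model.load_weights(snapshot_merged_weights(model).items())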

View File

@@ -127,8 +127,6 @@ class ReLoRACallback(TrainerCallback):
        optimizer: torch.optim.Optimizer,
        **_kwargs,
    ):
        if not optimizer:
            optimizer = state.optimizer
        if state.global_step > 0 and state.global_step % self.relora_steps == 0:
            checkpoint_folder = os.path.join(
                args.output_dir,
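
The trigger for this callback is just a modulus on state.global_step. A bare-bones sketch of that pattern with the transformers callback API; EveryNStepsCallback is a made-up name and the print stands in for the real checkpoint/merge work:

# Minimal sketch of an every-N-steps TrainerCallback; not the actual ReLoRACallback.
from transformers import TrainerCallback

class EveryNStepsCallback(TrainerCallback):
    def __init__(self, every_n_steps: int):
        self.every_n_steps = every_n_steps

    def on_step_end(self, args, state, control, **kwargs):
        # state.global_step counts completed optimizer steps
        if state.global_step > 0 and state.global_step % self.every_n_steps == 0:
            print(f"step {state.global_step}: run the periodic LoRA merge/reset here")
        return control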

View File

@@ -95,103 +95,6 @@ def get_cu_seqlens(attn_mask):
    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)

def get_packed_mask_from_pos_ids(position_ids):
    if len(position_ids.shape) == 1:
        position_ids = position_ids.unsqueeze(0)
    device = position_ids.device
    results = []
    for i, row in enumerate(position_ids):
        # Count the number of consecutive zeros from the right side
        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
        # Adjust the row to exclude padding
        adjusted_row = row[:-padding_length] if padding_length else row.clone()
        # Find where the position resets to 0 (indicating a new sequence)
        seq_starts = torch.cat(
            [
                torch.tensor([True], dtype=torch.bool, device=device),
                adjusted_row[1:] == 0,
            ]
        )
        # Get the indices where the sequence starts
        start_indices = torch.cat(
            [
                torch.nonzero(seq_starts).unbind(dim=1)[0],
                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
            ]
        )
        # Calculate the sequence lengths
        seq_lengths = start_indices[1:] - start_indices[:-1]
        # Append the padding length to the sequence lengths
        doc_mask = torch.ones(len(row), dtype=torch.int32, device=device)
        for i, seq_len in enumerate(seq_lengths):
            start_id = start_indices[i]
            doc_mask[start_id : start_id + seq_len] = (
                (i + 1) * doc_mask[start_id : start_id + seq_len]
            )
        if padding_length:
            doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]
        results.append(doc_mask)
    return torch.stack(results)

def get_seqlens_from_pos_ids(position_ids):
    """generate a sequence length set using pos ids for doc mask creation in flex attention"""
    if len(position_ids.shape) == 1:
        position_ids = position_ids.unsqueeze(0)
    max_seq_len = position_ids.shape[1]
    device = position_ids.device
    results = []
    totalseqlens = []
    for row in position_ids:
        # Count the number of consecutive zeros from the right side
        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
        # Adjust the row to exclude padding
        adjusted_row = row[:-padding_length] if padding_length else row.clone()
        # Find where the position resets to 0 (indicating a new sequence)
        seq_starts = torch.cat(
            [
                torch.tensor([True], dtype=torch.bool, device=device),
                adjusted_row[1:] == 0,
            ]
        )
        # Get the indices where the sequence starts
        start_indices = torch.cat(
            [
                torch.nonzero(seq_starts).unbind(dim=1)[0],
                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
            ]
        )
        # Calculate the sequence lengths
        seq_lengths = start_indices[1:] - start_indices[:-1]
        # Append the padding length to the sequence lengths
        if padding_length:
            seq_lengths = torch.cat(
                [
                    seq_lengths,
                    torch.tensor(
                        [len(row) - torch.sum(seq_lengths)],
                        dtype=torch.int32,
                        device=device,
                    ),
                ]
            )
        results.append(seq_lengths)
        totalseqlens.append(len(adjusted_row))
    return results, torch.tensor(totalseqlens, dtype=torch.int32, device=device)

def get_cu_seqlens_from_pos_ids(position_ids):
    """generate a cumulative sequence length mask for flash attention using pos ids"""
    if len(position_ids.shape) == 1:
@@ -273,10 +176,7 @@ def mask_2d_to_4d(
    when they attend to each other within that sequence.
    This expansion transforms the mask to lower triangular form to prevent future peeking.
    """
    if len(mask.size()) == 4:
        return mask
    bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    mask = mask.unsqueeze(1).unsqueeze(2)
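
All of the removed helpers rely on the same convention: packed rows carry position_ids that restart at 0 at every sequence boundary, so per-sequence lengths (and the cumulative offsets flash or flex attention want) can be recovered from the ids alone. A self-contained toy version of that idea, ignoring padding and batching:

# Toy example: recover sequence lengths and cumulative offsets from packed position_ids.
import torch

position_ids = torch.tensor([0, 1, 2, 3, 0, 1, 0, 1, 2])  # three packed sequences
starts = (position_ids == 0).nonzero().flatten()  # indices where the ids reset to 0
ends = torch.cat([starts[1:], torch.tensor([position_ids.numel()])])
seq_lens = ends - starts  # tensor([4, 2, 3])
cu_seqlens = torch.cat([torch.tensor([0]), seq_lens.cumsum(0)]).to(torch.int32)
print(seq_lens.tolist(), cu_seqlens.tolist())  # [4, 2, 3] [0, 4, 6, 9]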

View File

@@ -272,7 +272,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                dict(zip(feature_names, row))
            )
            for key, val in tokenized_prompt.items():
                res[key].append(val)
                for i in range(0, len(val), self.sequence_len):
                    res[key].append(val[i : i + self.sequence_len])

        # If there are no examples left, return an empty dictionary
        if not res:
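
The two variants differ only in how an over-long tokenized example is stored: appended whole, or split into sequence_len-sized chunks so that no row exceeds the limit. A tiny illustration of the chunking branch on plain lists, with sequence_len=4 chosen arbitrarily:

# Toy illustration of splitting one tokenized field into sequence_len-sized chunks.
sequence_len = 4
val = list(range(10))  # stand-in for a tokenized input_ids list
chunks = [val[i : i + sequence_len] for i in range(0, len(val), sequence_len)]
print(chunks)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]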

View File

@@ -342,7 +342,6 @@ class LoraConfig(BaseModel):
    peft_use_dora: Optional[bool] = None
    peft_use_rslora: Optional[bool] = None
    peft_layer_replication: Optional[List[Tuple[int, int]]] = None
    peft_init_lora_weights: Optional[Union[bool, str]] = None
    qlora_sharded_model_loading: Optional[bool] = Field(
        default=False,
@@ -823,7 +822,6 @@ class AxolotlInputConfig(
    xformers_attention: Optional[bool] = None
    sdp_attention: Optional[bool] = None
    s2_attention: Optional[bool] = None
    flex_attention: Optional[bool] = None
    flash_attention: Optional[bool] = None
    flash_attn_cross_entropy: Optional[bool] = None
    flash_attn_rms_norm: Optional[bool] = None
@@ -1790,26 +1788,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_flex_torch_version(cls, data):
        if (data.get("flex_attention") is not None) and (
            data.get("flex_attention") is True
        ):
            env_capabilities = data.get("env_capabilities", {})
            torch_version = env_capabilities.get("torch_version")
            if torch_version is None:
                import torch
                torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
            if version.parse(torch_version) < version.parse("2.5.1"):
                raise ValueError(
                    "Flex attention is not supported on torch version < 2.5.1"
                )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_torch_compile_auto(cls, data):
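
The removed check is an instance of pydantic v2's mode="before" model validator, which inspects the raw input dict before field parsing. A generic sketch of the same pattern; ExampleConfig and use_new_kernel are illustrative names, not axolotl fields:

# Generic pydantic v2 "before" validator sketch; not the axolotl config model.
from packaging import version
from pydantic import BaseModel, model_validator

class ExampleConfig(BaseModel):
    use_new_kernel: bool = False

    @model_validator(mode="before")
    @classmethod
    def check_torch_version(cls, data):
        if isinstance(data, dict) and data.get("use_new_kernel"):
            import torch

            torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
            if version.parse(torch_version) < version.parse("2.5.1"):
                raise ValueError("use_new_kernel requires torch >= 2.5.1")
        return data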

View File

@@ -172,11 +172,10 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
    )
    try:
        ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
        min_input_len = np.min(ds_lengths)
        LOG.info(f"min_input_len: {min_input_len}")
        max_input_len = np.max(ds_lengths)
        LOG.info(f"max_input_len: {max_input_len}")
        min_input_len = np.min(get_dataset_lengths(dataset))
        LOG.debug(f"min_input_len: {min_input_len}")
        max_input_len = np.max(get_dataset_lengths(dataset))
        LOG.debug(f"max_input_len: {max_input_len}")
    except AttributeError:
        pass
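
For reference, the length statistics and the drop itself only need the tokenized column; a hedged sketch with plain datasets and numpy, where the 256-token limit and the toy column are illustrative:

# Hedged sketch of length stats plus dropping over-long rows; not the repo's helper.
import numpy as np
from datasets import Dataset

ds = Dataset.from_dict({"input_ids": [[1] * 5, [1] * 300, [1] * 42]})
lengths = np.array([len(x) for x in ds["input_ids"]])
print(int(lengths.min()), int(lengths.max()))  # 5 300
ds = ds.filter(lambda ex: len(ex["input_ids"]) <= 256)  # drop rows over the limit
print(len(ds))  # 2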

View File

@@ -403,7 +403,7 @@ class ModelLoader:
        if (
            self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
            and (self.cfg.flash_attention or self.cfg.flex_attention)
            and self.cfg.flash_attention
            and self.cfg.sample_packing
        ):
            if "auto_map" in self.model_config:
@@ -707,13 +707,7 @@ class ModelLoader:
"""
sample packing uses custom FA2 patch
"""
if self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
)
elif self.cfg.flash_attention:
if self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"
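
attn_implementation here is the standard transformers loading option that these model_kwargs are eventually passed through with. A minimal sketch of setting it directly at load time; the model id is a placeholder:

# Minimal sketch: choosing the attention backend via the standard transformers kwarg.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "some-org/causal-lm",  # placeholder model id
    attn_implementation="flash_attention_2",  # or "sdpa" / "eager"
    torch_dtype="auto",
)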
@@ -1119,7 +1113,7 @@ class ModelLoader:
        should_convert = (
            # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
            # convert them back to fp16/bf16 for flash-attn compatibility.
            ((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
            ((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
            or self.cfg.cut_cross_entropy  # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
        )
@@ -1327,8 +1321,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
    if loftq_bits:
        lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
        lora_config_kwargs["init_lora_weights"] = "loftq"
    if cfg.peft_init_lora_weights:
        lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
    if cfg.peft_use_dora:
        lora_config_kwargs["use_dora"] = cfg.peft_use_dora
        LOG.info("Initializing LoRA weights using dora. This might take longer.")

View File

@@ -4,17 +4,13 @@ helper util to calculate dataset lengths
import numpy as np

def get_dataset_lengths(dataset, from_arrow=False):
    if "length" in dataset.column_names:
        lengths = np.array(dataset["length"])
    elif "position_ids" in dataset.column_names:
        position_ids = dataset["position_ids"]
def get_dataset_lengths(dataset):
    if "length" in dataset.data.column_names:
        lengths = np.array(dataset.data.column("length"))
    elif "position_ids" in dataset.data.column_names:
        position_ids = dataset.data.column("position_ids")
        lengths = np.array([x[-1] + 1 for x in position_ids])
    else:
        if from_arrow:
            input_ids = dataset.data.column("input_ids")
            lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
        else:
            input_ids = dataset["input_ids"]
            lengths = np.array([len(seq) for seq in input_ids])
        input_ids = dataset.data.column("input_ids")
        lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
    return lengths

View File

@@ -7,7 +7,6 @@ from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -19,6 +18,11 @@ def fixture_tokenizer():
    return tokenizer


@pytest.fixture(name="max_seq_length")
def fixture_max_seq_length():
    return 4096


class TestBatchedSamplerPacking:
    """
    Test class for packing streaming dataset sequences
@@ -33,7 +37,6 @@ class TestBatchedSamplerPacking:
            (2, 2),
        ],
    )
    @pytest.mark.parametrize("max_seq_length", [4096, 512])
    def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
        import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import # noqa: F401
@@ -59,9 +62,6 @@ class TestBatchedSamplerPacking:
            dataset,
        )
        train_dataset = concatenate_datasets([dataset_wrapper])
        train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
        lengths = get_dataset_lengths(train_dataset)
        batch_sampler = MultipackBatchSampler(
            sampler=RandomSampler(train_dataset),
@@ -90,7 +90,7 @@ class TestBatchedSamplerPacking:
            batch_idxs.extend(pack)
        for batch in loader:
            assert batch["input_ids"].numel() <= batch_size * max_seq_length
            assert len(batch["input_ids"]) <= batch_size * max_seq_length
            assert batch["input_ids"].shape[1] == max_seq_length
        original_idxs = set(range(len(train_dataset)))