Compare commits

..

3 Commits

Author SHA1 Message Date
NanoCode012
fc1900761b fix(trl): remove access to invalid property 2025-05-02 15:41:53 +07:00
Wing Lian
bcb59c70e2 automatically set pad_to_sequence_len when use packing (#2607)
* automatically set pad_to_sequence_len when use packing

* update tests
2025-05-01 13:24:38 -04:00
NanoCode012
6a3e6f8c53 fix: run preview-docs only when md/qmd changes (#2606)
* fix: run preview-docs only when md/qmd changes

* feat: add quarto yaml based on PR feedback
2025-05-01 13:21:28 -04:00
5 changed files with 40 additions and 15 deletions

View File

@@ -4,6 +4,12 @@ on:
pull_request:
types: [opened, synchronize, reopened]
# Run the workflow only when one of these files changes
paths:
- '**/*.md' # any Markdown file
- '**/*.qmd' # any Quarto file
- '_quarto.yaml'
permissions:
checks: write
contents: write

View File

@@ -177,12 +177,8 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res

View File

@@ -55,16 +55,13 @@ def dequantize(
target_device = W.device
# Extract quantization state
nested = False
if not isinstance(quant_state, list):
# New style quant_state class
absmax = quant_state.absmax.to(target_device)
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
if quant_state.nested:
offset = quant_state.offset.to(target_device)
nested = True
offset = quant_state.offset.to(target_device)
state2 = quant_state.state2
absmax2 = state2.absmax.to(target_device)
code2 = state2.code.to(target_device)
@@ -118,8 +115,7 @@ def dequantize(
ctypes.c_int(n_elements_absmax),
)
if nested:
out_absmax += offset
out_absmax += offset
# Choose appropriate dequantization function
fx = (

View File

@@ -512,10 +512,17 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def hint_sample_packing_padding(cls, data):
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
if data.get("sample_packing"):
pad_to_sequence_len = data.get("pad_to_sequence_len")
if pad_to_sequence_len is False:
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
elif pad_to_sequence_len is None:
LOG.info(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
)
data["pad_to_sequence_len"] = True
return data
@model_validator(mode="before")

View File

@@ -648,7 +648,7 @@ class TestValidation(BaseValidation):
DictDefault(
{
"sample_packing": True,
"pad_to_sequence_len": None,
"pad_to_sequence_len": False,
"flash_attention": True,
}
)
@@ -662,6 +662,26 @@ class TestValidation(BaseValidation):
for record in self._caplog.records
)
def test_packing_autoset(self, minimal_cfg):
cfg = (
DictDefault(
{
"sample_packing": True,
"pad_to_sequence_len": None,
"flash_attention": True,
}
)
| minimal_cfg
)
with self._caplog.at_level(logging.INFO):
cfg = validate_config(cfg)
assert any(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
in record.message
for record in self._caplog.records
)
assert cfg.pad_to_sequence_len is True
def test_merge_lora_no_bf16_fail(self, minimal_cfg):
"""
This is assumed to be run on a CPU machine, so bf16 is not supported.