Compare commits
3 Commits
lora-quant
...
fix/dpo-la
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fc1900761b | ||
|
|
bcb59c70e2 | ||
|
|
6a3e6f8c53 |
6
.github/workflows/preview-docs.yml
vendored
6
.github/workflows/preview-docs.yml
vendored
@@ -4,6 +4,12 @@ on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
|
||||
# Run the workflow only when one of these files changes
|
||||
paths:
|
||||
- '**/*.md' # any Markdown file
|
||||
- '**/*.qmd' # any Quarto file
|
||||
- '_quarto.yaml'
|
||||
|
||||
permissions:
|
||||
checks: write
|
||||
contents: write
|
||||
|
||||
@@ -177,12 +177,8 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@@ -55,16 +55,13 @@ def dequantize(
|
||||
target_device = W.device
|
||||
|
||||
# Extract quantization state
|
||||
nested = False
|
||||
if not isinstance(quant_state, list):
|
||||
# New style quant_state class
|
||||
absmax = quant_state.absmax.to(target_device)
|
||||
shape = quant_state.shape
|
||||
dtype = quant_state.dtype
|
||||
blocksize = quant_state.blocksize
|
||||
if quant_state.nested:
|
||||
nested = True
|
||||
offset = quant_state.offset.to(target_device)
|
||||
offset = quant_state.offset.to(target_device)
|
||||
state2 = quant_state.state2
|
||||
absmax2 = state2.absmax.to(target_device)
|
||||
code2 = state2.code.to(target_device)
|
||||
@@ -118,8 +115,7 @@ def dequantize(
|
||||
ctypes.c_int(n_elements_absmax),
|
||||
)
|
||||
|
||||
if nested:
|
||||
out_absmax += offset
|
||||
out_absmax += offset
|
||||
|
||||
# Choose appropriate dequantization function
|
||||
fx = (
|
||||
|
||||
@@ -512,10 +512,17 @@ class AxolotlInputConfig(
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def hint_sample_packing_padding(cls, data):
|
||||
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
|
||||
LOG.warning(
|
||||
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||
)
|
||||
if data.get("sample_packing"):
|
||||
pad_to_sequence_len = data.get("pad_to_sequence_len")
|
||||
if pad_to_sequence_len is False:
|
||||
LOG.warning(
|
||||
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||
)
|
||||
elif pad_to_sequence_len is None:
|
||||
LOG.info(
|
||||
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
|
||||
)
|
||||
data["pad_to_sequence_len"] = True
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
@@ -648,7 +648,7 @@ class TestValidation(BaseValidation):
|
||||
DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": None,
|
||||
"pad_to_sequence_len": False,
|
||||
"flash_attention": True,
|
||||
}
|
||||
)
|
||||
@@ -662,6 +662,26 @@ class TestValidation(BaseValidation):
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
def test_packing_autoset(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": None,
|
||||
"flash_attention": True,
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
with self._caplog.at_level(logging.INFO):
|
||||
cfg = validate_config(cfg)
|
||||
assert any(
|
||||
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
|
||||
in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
assert cfg.pad_to_sequence_len is True
|
||||
|
||||
def test_merge_lora_no_bf16_fail(self, minimal_cfg):
|
||||
"""
|
||||
This is assumed to be run on a CPU machine, so bf16 is not supported.
|
||||
|
||||
Reference in New Issue
Block a user