Merge branch 'main' into diffusion

This commit is contained in:
Dan Saunders
2025-08-18 10:07:55 -04:00
committed by GitHub
30 changed files with 687 additions and 183 deletions

View File

@@ -147,7 +147,11 @@ def require_hopper(test_case):
def check_tensorboard(
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
temp_run_dir: str,
tag: str,
lt_val: float,
assertion_err: str,
rtol: float = 0.02,
) -> None:
"""
helper function to parse and check tensorboard logs
@@ -157,6 +161,7 @@ def check_tensorboard(
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name
lt_val = (1 + rtol) * lt_val
if "%s" in assertion_err:
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
else:

View File

@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.data.utils import handle_long_seq_in_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
)
train_dataset = concatenate_datasets([dataset_wrapper])
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
train_dataset = handle_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
lengths = get_dataset_lengths(train_dataset)
batch_sampler = MultipackBatchSampler(

24
tests/utils/test_train.py Normal file
View File

@@ -0,0 +1,24 @@
"""test for train checkpoint utils"""
import os
from axolotl.utils.dict import DictDefault
from axolotl.utils.train import determine_last_checkpoint
def test_determine_last_checkpoint(temp_dir):
    """Checkpoint discovery returns the highest-numbered checkpoint dir.

    Steps 9 and 10 are presumably included so a lexicographic sort (which
    would pick ``checkpoint-9``) is caught — TODO confirm against the
    implementation of ``determine_last_checkpoint``.
    """
    cfg = DictDefault(
        output_dir=temp_dir,
    )
    expected = os.path.join(cfg.output_dir, "checkpoint-20")

    for step in (1, 9, 10, 20):
        checkpoint_dir = os.path.join(cfg.output_dir, f"checkpoint-{step}")
        os.makedirs(checkpoint_dir, exist_ok=True)

    # update=False: the path is returned but cfg is left untouched.
    assert determine_last_checkpoint(cfg, update=False) == expected

    # update=True: the discovered path is written back onto cfg.
    cfg.resume_from_checkpoint = None
    cfg.auto_resume_from_checkpoints = True
    determine_last_checkpoint(cfg, update=True)
    assert cfg.resume_from_checkpoint == expected