Merge branch 'main' into diffusion
This commit is contained in:
@@ -147,7 +147,11 @@ def require_hopper(test_case):
|
||||
|
||||
|
||||
def check_tensorboard(
|
||||
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
|
||||
temp_run_dir: str,
|
||||
tag: str,
|
||||
lt_val: float,
|
||||
assertion_err: str,
|
||||
rtol: float = 0.02,
|
||||
) -> None:
|
||||
"""
|
||||
helper function to parse and check tensorboard logs
|
||||
@@ -157,6 +161,7 @@ def check_tensorboard(
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == tag)] # pylint: disable=invalid-name
|
||||
lt_val = (1 + rtol) * lt_val
|
||||
if "%s" in assertion_err:
|
||||
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
|
||||
else:
|
||||
|
||||
@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.completion import load
|
||||
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.data.utils import drop_long_seq_in_dataset
|
||||
from axolotl.utils.data.utils import handle_long_seq_in_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
|
||||
)
|
||||
train_dataset = concatenate_datasets([dataset_wrapper])
|
||||
|
||||
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
|
||||
train_dataset = handle_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
|
||||
|
||||
lengths = get_dataset_lengths(train_dataset)
|
||||
batch_sampler = MultipackBatchSampler(
|
||||
|
||||
24
tests/utils/test_train.py
Normal file
24
tests/utils/test_train.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""test for train checkpoint utils"""
|
||||
|
||||
import os
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.train import determine_last_checkpoint
|
||||
|
||||
|
||||
def test_determine_last_checkpoint(temp_dir):
|
||||
cfg = DictDefault(
|
||||
output_dir=temp_dir,
|
||||
)
|
||||
for cpt_idx in [1, 9, 10, 20]:
|
||||
os.makedirs(
|
||||
os.path.join(cfg.output_dir, f"checkpoint-{cpt_idx}"), exist_ok=True
|
||||
)
|
||||
|
||||
last_checkpoint = determine_last_checkpoint(cfg, update=False)
|
||||
assert last_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20")
|
||||
|
||||
cfg.resume_from_checkpoint = None
|
||||
cfg.auto_resume_from_checkpoints = True
|
||||
determine_last_checkpoint(cfg, update=True)
|
||||
assert cfg.resume_from_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20")
|
||||
Reference in New Issue
Block a user