Address PR code review feedback

This commit is contained in:
Wing Lian
2026-03-22 17:23:12 +00:00
parent 0a566d7a15
commit 6636e5de7e
9 changed files with 51 additions and 17 deletions

View File

@@ -21,7 +21,7 @@ class TestLlama4MoeAuxFree(unittest.TestCase):
{
"base_model": "yujiepan/llama-4-tiny-random",
"tokenizer_config": "yujiepan/llama-4-tiny-random",
"trust_remote_code": True,
"trust_remote_code": False,
"flash_attention": False,
"sequence_len": 512,
"bf16": False,

View File

@@ -3,8 +3,11 @@ Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-t
Checks that aux-free training loss does not degrade beyond a small tolerance.
"""
import gc
import unittest
import torch
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -61,6 +64,11 @@ class TestMoeAuxParity(unittest.TestCase):
loss0 = _last_logged_loss(trainer0)
assert loss0 is not None
# Release baseline resources before starting aux-free run
del model0, trainer0, dataset_meta0
gc.collect()
torch.cuda.empty_cache()
# Aux-free: plugin + noaux_tc
cfg1 = DictDefault(dict(base_cfg))
cfg1.output_dir = f"{temp_dir}/auxfree"