Address PR code review feedback

This commit is contained in:
Wing Lian
2026-03-22 17:23:12 +00:00
parent 0a566d7a15
commit 6636e5de7e
9 changed files with 51 additions and 17 deletions

View File

@@ -21,7 +21,7 @@ class TestLlama4MoeAuxFree(unittest.TestCase):
{
"base_model": "yujiepan/llama-4-tiny-random",
"tokenizer_config": "yujiepan/llama-4-tiny-random",
"trust_remote_code": True,
"trust_remote_code": False,
"flash_attention": False,
"sequence_len": 512,
"bf16": False,

View File

@@ -3,8 +3,11 @@ Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-t
Checks that aux-free training loss does not degrade beyond a small tolerance.
"""
import gc
import unittest
import torch
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -61,6 +64,11 @@ class TestMoeAuxParity(unittest.TestCase):
loss0 = _last_logged_loss(trainer0)
assert loss0 is not None
# Release the baseline model, trainer, and dataset metadata, then free cached CUDA memory, before starting the aux-free run
del model0, trainer0, dataset_meta0
gc.collect()
torch.cuda.empty_cache()
# Aux-free: plugin + noaux_tc
cfg1 = DictDefault(dict(base_cfg))
cfg1.output_dir = f"{temp_dir}/auxfree"

View File

@@ -390,7 +390,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
def test_ep_group_resolution_deferred_until_dist_ready(self):
if dist.is_available() and dist.is_initialized():
dist.destroy_process_group()
self.skipTest(
"Cannot safely test deferred EP group resolution when a process group is already initialized"
)
model, block = _build_bailing_model()
cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)