address PR code review
This commit is contained in:
@@ -21,7 +21,7 @@ class TestLlama4MoeAuxFree(unittest.TestCase):
|
||||
{
|
||||
"base_model": "yujiepan/llama-4-tiny-random",
|
||||
"tokenizer_config": "yujiepan/llama-4-tiny-random",
|
||||
"trust_remote_code": True,
|
||||
"trust_remote_code": False,
|
||||
"flash_attention": False,
|
||||
"sequence_len": 512,
|
||||
"bf16": False,
|
||||
|
||||
@@ -3,8 +3,11 @@ Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-t
|
||||
Checks that aux-free training loss does not degrade beyond a small tolerance.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
@@ -61,6 +64,11 @@ class TestMoeAuxParity(unittest.TestCase):
|
||||
loss0 = _last_logged_loss(trainer0)
|
||||
assert loss0 is not None
|
||||
|
||||
# Release baseline resources before starting aux-free run
|
||||
del model0, trainer0, dataset_meta0
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Aux-free: plugin + noaux_tc
|
||||
cfg1 = DictDefault(dict(base_cfg))
|
||||
cfg1.output_dir = f"{temp_dir}/auxfree"
|
||||
|
||||
@@ -390,7 +390,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
|
||||
def test_ep_group_resolution_deferred_until_dist_ready(self):
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
self.skipTest(
|
||||
"Cannot safely test deferred EP group resolution when a process group is already initialized"
|
||||
)
|
||||
|
||||
model, block = _build_bailing_model()
|
||||
cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
|
||||
|
||||
Reference in New Issue
Block a user