address PR code review
This commit is contained in:
@@ -3,8 +3,11 @@ Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-t
|
||||
Checks that aux-free training loss does not degrade beyond a small tolerance.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
@@ -61,6 +64,11 @@ class TestMoeAuxParity(unittest.TestCase):
|
||||
loss0 = _last_logged_loss(trainer0)
|
||||
assert loss0 is not None
|
||||
|
||||
# Release baseline resources before starting aux-free run
|
||||
del model0, trainer0, dataset_meta0
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Aux-free: plugin + noaux_tc
|
||||
cfg1 = DictDefault(dict(base_cfg))
|
||||
cfg1.output_dir = f"{temp_dir}/auxfree"
|
||||
|
||||
Reference in New Issue
Block a user