more fixes to flex for fsdp2

This commit is contained in:
Wing Lian
2025-04-06 14:24:50 -04:00
parent b5a51c378b
commit 7e410ab480
4 changed files with 153 additions and 35 deletions

View File

@@ -486,7 +486,7 @@ class TestMultiGPULlama:
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
-                "optimizer": "adamw_8bit",
+                "optimizer": "adamw_torch_8bit",
"lr_scheduler": "cosine",
"fsdp": [
"auto_wrap",
@@ -529,7 +529,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
-            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+            temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
)
def test_fsdp_qlora_prequant_packed(self, temp_dir):