Compare commits
3 Commits
fix/diffus
...
v0.12.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
160ba459ea | ||
|
|
7a09f76644 | ||
|
|
47304c7f8a |
@@ -4,4 +4,4 @@ import pkgutil
|
|||||||
|
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||||
|
|
||||||
__version__ = "0.12.0"
|
__version__ = "0.12.1"
|
||||||
|
|||||||
@@ -123,9 +123,10 @@ def train(
|
|||||||
_launcher = None if kwargs.get("use_ray") else launcher
|
_launcher = None if kwargs.get("use_ray") else launcher
|
||||||
|
|
||||||
# Process each configuration
|
# Process each configuration
|
||||||
for cfg_file in generate_config_files(config, sweep):
|
for cfg_file, is_group in generate_config_files(config, sweep):
|
||||||
try:
|
try:
|
||||||
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
|
use_exec = is_group is not True
|
||||||
|
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
|
||||||
except subprocess.CalledProcessError as exc:
|
except subprocess.CalledProcessError as exc:
|
||||||
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
||||||
if not sweep:
|
if not sweep:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import subprocess # nosec
|
import subprocess # nosec
|
||||||
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Any, Iterator, Literal
|
from typing import Any, Iterator, Literal
|
||||||
|
|
||||||
@@ -64,10 +65,20 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
|||||||
return cmd
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
|
||||||
"""Generate list of configuration files to process."""
|
"""
|
||||||
|
Generate list of configuration files to process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Base configuration file
|
||||||
|
sweep: Sweep configuration file
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Tuple of configuration file name and whether this is a group of configurations
|
||||||
|
"""
|
||||||
|
|
||||||
if not sweep:
|
if not sweep:
|
||||||
yield config
|
yield config, False
|
||||||
return
|
return
|
||||||
|
|
||||||
# Load sweep and base configurations
|
# Load sweep and base configurations
|
||||||
@@ -78,6 +89,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
|||||||
|
|
||||||
# Generate all possible configurations
|
# Generate all possible configurations
|
||||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||||
|
is_group = len(permutations) > 1
|
||||||
for permutation in permutations:
|
for permutation in permutations:
|
||||||
# pylint: disable=consider-using-with
|
# pylint: disable=consider-using-with
|
||||||
temp_file = tempfile.NamedTemporaryFile(
|
temp_file = tempfile.NamedTemporaryFile(
|
||||||
@@ -88,7 +100,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
|||||||
)
|
)
|
||||||
yaml.dump(permutation, temp_file)
|
yaml.dump(permutation, temp_file)
|
||||||
temp_file.close()
|
temp_file.close()
|
||||||
yield temp_file.name
|
yield temp_file.name, is_group
|
||||||
|
|
||||||
|
|
||||||
def launch_training(
|
def launch_training(
|
||||||
@@ -97,6 +109,7 @@ def launch_training(
|
|||||||
cloud: str | None,
|
cloud: str | None,
|
||||||
kwargs: dict,
|
kwargs: dict,
|
||||||
launcher_args: list[str] | None = None,
|
launcher_args: list[str] | None = None,
|
||||||
|
use_exec: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute training with the given configuration."""
|
"""Execute training with the given configuration."""
|
||||||
launcher_args = launcher_args or []
|
launcher_args = launcher_args or []
|
||||||
@@ -105,11 +118,14 @@ def launch_training(
|
|||||||
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
|
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
|
||||||
elif launcher:
|
elif launcher:
|
||||||
if launcher == "accelerate":
|
if launcher == "accelerate":
|
||||||
_launch_accelerate_training(cfg_file, kwargs, launcher_args)
|
_launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||||
elif launcher == "torchrun":
|
elif launcher == "torchrun":
|
||||||
_launch_torchrun_training(cfg_file, kwargs, launcher_args)
|
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||||
elif launcher == "python":
|
elif launcher == "python":
|
||||||
_launch_python_training(cfg_file, kwargs)
|
_launch_python_training(cfg_file, kwargs)
|
||||||
|
elif launcher is None:
|
||||||
|
# handle ray train launch
|
||||||
|
_launch_python_training(cfg_file, kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _launch_cloud_training(
|
def _launch_cloud_training(
|
||||||
@@ -136,7 +152,10 @@ def _launch_cloud_training(
|
|||||||
|
|
||||||
|
|
||||||
def _launch_accelerate_training(
|
def _launch_accelerate_training(
|
||||||
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
|
cfg_file: str,
|
||||||
|
kwargs: dict,
|
||||||
|
launcher_args: list[str] | None = None,
|
||||||
|
use_exec: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute training via accelerate launcher."""
|
"""Execute training via accelerate launcher."""
|
||||||
launcher_args = launcher_args or []
|
launcher_args = launcher_args or []
|
||||||
@@ -161,11 +180,20 @@ def _launch_accelerate_training(
|
|||||||
base_cmd.append(cfg_file)
|
base_cmd.append(cfg_file)
|
||||||
|
|
||||||
cmd = build_command(base_cmd, kwargs)
|
cmd = build_command(base_cmd, kwargs)
|
||||||
subprocess.run(cmd, check=True) # nosec B603
|
if use_exec:
|
||||||
|
# make sure to flush stdout and stderr before replacing the process
|
||||||
|
sys.stdout.flush()
|
||||||
|
sys.stderr.flush()
|
||||||
|
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
||||||
|
else:
|
||||||
|
subprocess.run(cmd, check=True) # nosec B603
|
||||||
|
|
||||||
|
|
||||||
def _launch_torchrun_training(
|
def _launch_torchrun_training(
|
||||||
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
|
cfg_file: str,
|
||||||
|
kwargs: dict,
|
||||||
|
launcher_args: list[str] | None = None,
|
||||||
|
use_exec: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute training via torchrun launcher."""
|
"""Execute training via torchrun launcher."""
|
||||||
launcher_args = launcher_args or []
|
launcher_args = launcher_args or []
|
||||||
@@ -178,7 +206,13 @@ def _launch_torchrun_training(
|
|||||||
base_cmd.append(cfg_file)
|
base_cmd.append(cfg_file)
|
||||||
|
|
||||||
cmd = build_command(base_cmd, kwargs)
|
cmd = build_command(base_cmd, kwargs)
|
||||||
subprocess.run(cmd, check=True) # nosec B603
|
if use_exec:
|
||||||
|
# make sure to flush stdout and stderr before replacing the process
|
||||||
|
sys.stdout.flush()
|
||||||
|
sys.stderr.flush()
|
||||||
|
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
||||||
|
else:
|
||||||
|
subprocess.run(cmd, check=True) # nosec B603
|
||||||
|
|
||||||
|
|
||||||
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
|
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
|
||||||
|
|||||||
@@ -47,7 +47,9 @@ class BaseCliTest:
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("subprocess.run") as mock:
|
mock_fn = "os.execvpe" if command == "train" else "subprocess.run"
|
||||||
|
|
||||||
|
with patch(mock_fn) as mock:
|
||||||
result = cli_runner.invoke(cli, [command, str(config_path)])
|
result = cli_runner.invoke(cli, [command, str(config_path)])
|
||||||
|
|
||||||
assert mock.called
|
assert mock.called
|
||||||
@@ -65,8 +67,12 @@ class BaseCliTest:
|
|||||||
if train:
|
if train:
|
||||||
expected.append("--shard=False")
|
expected.append("--shard=False")
|
||||||
|
|
||||||
assert mock.call_args.args[0] == expected
|
if command == "train":
|
||||||
assert mock.call_args.kwargs == {"check": True}
|
assert mock.call_args.args[0] == "accelerate"
|
||||||
|
assert mock.call_args.args[1] == expected
|
||||||
|
else:
|
||||||
|
assert mock.call_args.args[0] == expected
|
||||||
|
assert mock.call_args.kwargs == {"check": True}
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
|
|
||||||
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):
|
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("subprocess.run") as mock_subprocess:
|
with patch("os.execvpe") as mock_subprocess:
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@@ -104,7 +104,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
mock_subprocess.assert_called_once()
|
mock_subprocess.assert_called_once()
|
||||||
|
|
||||||
# Verify launcher args are passed to torchrun
|
# Verify launcher args are passed to torchrun
|
||||||
called_cmd = mock_subprocess.call_args.args[0]
|
called_cmd = mock_subprocess.call_args.args[1]
|
||||||
assert called_cmd[0] == "torchrun"
|
assert called_cmd[0] == "torchrun"
|
||||||
assert "--nproc_per_node=2" in called_cmd
|
assert "--nproc_per_node=2" in called_cmd
|
||||||
assert "--nnodes=1" in called_cmd
|
assert "--nnodes=1" in called_cmd
|
||||||
@@ -118,7 +118,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("subprocess.run") as mock_subprocess:
|
with patch("os.execvpe") as mock_subprocess:
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@@ -137,7 +137,8 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
mock_subprocess.assert_called_once()
|
mock_subprocess.assert_called_once()
|
||||||
|
|
||||||
# Verify launcher args are passed to accelerate
|
# Verify launcher args are passed to accelerate
|
||||||
called_cmd = mock_subprocess.call_args.args[0]
|
assert mock_subprocess.call_args.args[0] == "accelerate"
|
||||||
|
called_cmd = mock_subprocess.call_args.args[1]
|
||||||
assert called_cmd[0] == "accelerate"
|
assert called_cmd[0] == "accelerate"
|
||||||
assert called_cmd[1] == "launch"
|
assert called_cmd[1] == "launch"
|
||||||
assert "--config_file=accelerate_config.yml" in called_cmd
|
assert "--config_file=accelerate_config.yml" in called_cmd
|
||||||
@@ -152,7 +153,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("subprocess.run") as mock_subprocess:
|
with patch("os.execvpe") as mock_subprocess:
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@@ -170,7 +171,8 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
mock_subprocess.assert_called_once()
|
mock_subprocess.assert_called_once()
|
||||||
|
|
||||||
# Verify no launcher args contamination
|
# Verify no launcher args contamination
|
||||||
called_cmd = mock_subprocess.call_args.args[0]
|
assert mock_subprocess.call_args.args[0] == "accelerate"
|
||||||
|
called_cmd = mock_subprocess.call_args.args[1]
|
||||||
assert called_cmd[0] == "accelerate"
|
assert called_cmd[0] == "accelerate"
|
||||||
assert called_cmd[1] == "launch"
|
assert called_cmd[1] == "launch"
|
||||||
# Should not contain any extra launcher args
|
# Should not contain any extra launcher args
|
||||||
@@ -186,7 +188,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("subprocess.run") as mock_subprocess:
|
with patch("os.execvpe") as mock_subprocess:
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@@ -207,7 +209,8 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
mock_subprocess.assert_called_once()
|
mock_subprocess.assert_called_once()
|
||||||
|
|
||||||
called_cmd = mock_subprocess.call_args.args[0]
|
assert mock_subprocess.call_args.args[0] == "torchrun"
|
||||||
|
called_cmd = mock_subprocess.call_args.args[1]
|
||||||
# Verify launcher args
|
# Verify launcher args
|
||||||
assert "--nproc_per_node=8" in called_cmd
|
assert "--nproc_per_node=8" in called_cmd
|
||||||
# Verify axolotl args are also present
|
# Verify axolotl args are also present
|
||||||
|
|||||||
@@ -10,7 +10,11 @@ from accelerate.test_utils import execute_subprocess_async
|
|||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
|
from tests.e2e.utils import (
|
||||||
|
check_tensorboard,
|
||||||
|
require_torch_2_7_0,
|
||||||
|
require_torch_lt_2_6_0,
|
||||||
|
)
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
@@ -139,3 +143,71 @@ class TestMultiGPURay:
|
|||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"gradient_accumulation_steps",
|
||||||
|
[1, 2],
|
||||||
|
)
|
||||||
|
def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"sample_packing": True,
|
||||||
|
"pad_to_sequence_len": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 2,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"fsdp_version": 2,
|
||||||
|
"fsdp_config": {
|
||||||
|
"offload_params": False,
|
||||||
|
"cpu_ram_efficient_loading": False,
|
||||||
|
"transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||||
|
"state_dict_type": "FULL_STATE_DICT",
|
||||||
|
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
|
"reshard_after_forward": True,
|
||||||
|
},
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"save_first_step": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--use-ray",
|
||||||
|
"--ray-num-workers",
|
||||||
|
"2",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user