diff --git a/_quarto.yml b/_quarto.yml index acb487258..8eb79f651 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -38,6 +38,7 @@ website: - docs/multi-node.qmd - docs/unsloth.qmd - docs/amd_hpc.qmd + - docs/ray-integration.qmd - section: "Dataset Formats" contents: docs/dataset-formats/* - section: "Reference" diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index 641bd90b6..a90016ee4 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -32,9 +32,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \ fi RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ - pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ + pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ - pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \ + pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \ fi RUN python scripts/unsloth_install.py | sh diff --git a/docker/Dockerfile b/docker/Dockerfile index 6b6baf751..aaaff23ef 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl # If AXOLOTL_EXTRAS is set, append it in brackets RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ - pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ + pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ - pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \ + pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \ fi RUN python scripts/unsloth_install.py | sh diff --git a/docs/images/ray-cluster-dashboard.png b/docs/images/ray-cluster-dashboard.png new file mode 100644 index 000000000..f0b4beb56 Binary files /dev/null and b/docs/images/ray-cluster-dashboard.png differ diff --git a/docs/ray-integration.qmd b/docs/ray-integration.qmd new file mode 100644 index 000000000..0a2b45ef5 --- /dev/null +++ b/docs/ray-integration.qmd @@ -0,0 +1,93 @@ +--- +title: Ray Train integration +description: How to use Axolotl with Ray Train +--- + +Axolotl supports using Ray as an alternative to `accelerate` for orchestrating training. This is especially useful for multi-node training since you only have to setup code and dependencies in a single node and launch training as if you were using a single node. + +With the `--use-ray` CLI flag, Axolotl will use Ray Train's [`TorchTrainer`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchTrainer.html#ray.train.torch.TorchTrainer) to run training. + +## Ray cluster setup + +A prerequisite using the Ray Train integration is to setup a Ray cluster on your desired node(s). For a detailed guide on how you can get started with ray clusters, check the official Ray docs here: https://docs.ray.io/en/latest/cluster/getting-started.html + +Every Ray cluster has one _head_ node and a set of worker nodes. The head node is just like any other worker node, but it also runs certain special processes related to scheduling and orchestration. Ray-enabled scripts are run on the head node and depending on the resources (number of CPUs, GPUs, etc) they request, will be scheduled to run certain tasks on the worker nodes. For more on key concepts behind a Ray cluster, you can refer this [doc](https://docs.ray.io/en/latest/cluster/key-concepts.html#cluster-key-concepts). + +## Sanity check + +To run a sanity check on whether your ray cluster is setup properly, execute the following on the head node: + +```bash +ray status +``` + +The output should have a summary of your Ray cluster - list of all the nodes in your cluster, the number of CPUs and GPUs in your cluster, etc. For example, if you have a cluster with 1 CPU-only head node and 2 4xL40S worker nodes, the output can look like this: + + +``` +Node status +--------------------------------------------------------------- +Active: + 1 head +Idle: + 2 4xL40S:48CPU-384GB +Pending: + (no pending nodes) +Recent failures: + (no failures) + +Resources +--------------------------------------------------------------- +Usage: + 0.0/96.0 CPU + 0.0/8.0 GPU + 0B/800.00GiB memory + 0B/229.57GiB object_store_memory + +Demands: + (no resource demands) +``` + +You should also be able to see the same on the [Ray dashboard](https://docs.ray.io/en/latest/ray-observability/getting-started.html). + + +## Configuring training with Ray Train + +You can find an example configuration at `configs/llama-3/lora-1b-ray.yaml`. + +The key parameters to note here are: + +```yaml +... +use_ray: true +ray_num_workers: 4 +# optional +resources_per_worker: + GPU: 1 +... +``` + +- `use_ray`: This is the flag that enables the Ray Train integration. You can either use the corresponding `--use-ray` flag in the CLI or set `use_ray` in the config file. +- `ray_num_workers`: This is the number of workers/GPUs to use for training. +- `resources_per_worker`: This is the Ray [resource request](https://docs.ray.io/en/latest/ray-core/scheduling/resources.html) for each worker. This can be used to request a specific GPU type or a custom resource for each worker. For example, if your ray cluster has GPUs of different types, and you only want to use NVIDIA L40S GPUs, you can do + +```yaml +resources_per_worker: + accelerator_type:L40S: 0.001 +``` + +## Launching training + +You can simply run the following command on the head node: + +```bash +axolotl train examples/llama-3/lora-1b-ray.yml --use-ray +``` + +This will launch training on the head node and workers will be scheduled automatically by Ray Train to run on the appropriate head or worker nodes. + +You can also monitor training progress on the Ray dashboard. + +Coming back to the example on a Ray cluster with 1 head node and 2 4xL40S worker nodes, let's say you want to make use of all 8 GPUs. You would be able to just set `ray_num_workers: 8` and run the previous command. The Cluster tab will show the following: + +![Ray dashboard](./images/ray-cluster-dashboard.png) diff --git a/examples/llama-3/lora-1b-ray.yml b/examples/llama-3/lora-1b-ray.yml new file mode 100644 index 000000000..0e597a204 --- /dev/null +++ b/examples/llama-3/lora-1b-ray.yml @@ -0,0 +1,79 @@ +base_model: NousResearch/Llama-3.2-1B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +datasets: + - path: teknium/GPT4-LLM-Cleaned + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: lora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true +eval_sample_packing: true +pad_to_sequence_len: true + +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.05 +lora_fan_in_fan_out: +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +loss_watchdog_threshold: 5.0 +loss_watchdog_patience: 3 + +warmup_steps: 10 +evals_per_epoch: 4 +saves_per_epoch: 1 +debug: +deepspeed: deepspeed_configs/zero3.json +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + pad_token: "<|end_of_text|>" + +use_ray: true +ray_num_workers: 4 diff --git a/setup.py b/setup.py index ac0c96def..370eb7297 100644 --- a/setup.py +++ b/setup.py @@ -150,5 +150,8 @@ setup( "lomo-optim==0.1.1", "torch-optimi==0.2.1", ], + "ray": [ + "ray[train]", + ], }, ) diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py index 0618e07f1..a5865be1c 100644 --- a/src/axolotl/cli/args.py +++ b/src/axolotl/cli/args.py @@ -25,6 +25,8 @@ class TrainerCliArgs: merge_lora: bool = field(default=False) prompter: Optional[str] = field(default=None) shard: bool = field(default=False) + main_process_port: Optional[int] = field(default=None) + num_processes: Optional[int] = field(default=None) @dataclass diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 43e2de3db..801fbc80d 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -67,8 +67,23 @@ def train(config: str, accelerate: bool, **kwargs) -> None: # Enable expandable segments for cuda allocation to improve VRAM usage set_pytorch_cuda_alloc_conf() + if "use_ray" in kwargs and kwargs["use_ray"]: + accelerate = False + if accelerate: - base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"] + accelerate_args = [] + if "main_process_port" in kwargs: + main_process_port = kwargs.pop("main_process_port", None) + accelerate_args.append("--main_process_port") + accelerate_args.append(str(main_process_port)) + if "num_processes" in kwargs: + num_processes = kwargs.pop("num_processes", None) + accelerate_args.append("--num-processes") + accelerate_args.append(str(num_processes)) + + base_cmd = ["accelerate", "launch"] + base_cmd.extend(accelerate_args) + base_cmd.extend(["-m", "axolotl.cli.train"]) if config: base_cmd.append(config) cmd = build_command(base_cmd, kwargs) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 9e3ae1cc3..7ac15e04f 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Union import fire +from accelerate import Accelerator from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser @@ -15,6 +16,7 @@ from axolotl.cli.config import load_cfg from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.integrations.base import PluginManager from axolotl.train import train +from axolotl.utils.config import normalize_config, resolve_dtype from axolotl.utils.dict import DictDefault LOG = logging.getLogger(__name__) @@ -63,7 +65,47 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: return_remaining_strings=True ) - do_train(parsed_cfg, parsed_cli_args) + if parsed_cfg.use_ray: + from ray.train import RunConfig, ScalingConfig + from ray.train.torch import TorchTrainer + + train_loop_config = {"cfg": parsed_cfg.to_dict(), "cli_args": parsed_cli_args} + trainer = TorchTrainer( + ray_train_func, + train_loop_config=train_loop_config, + scaling_config=ScalingConfig( + num_workers=parsed_cfg.ray_num_workers, + resources_per_worker=parsed_cfg.resources_per_worker.to_dict(), + use_gpu=True, + ), + run_config=RunConfig( + name=parsed_cfg.ray_run_name, + storage_path=Path(parsed_cfg.output_dir).absolute().as_posix(), + ), + ) + return trainer.fit() + return do_train(parsed_cfg, parsed_cli_args) + + +def ray_train_func(kwargs: dict): + # cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict) + # also renormalize the config now that TorchTrainer has spawned distributed workers + cfg = DictDefault(kwargs["cfg"]) + normalize_config(cfg) + + # now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype + resolve_dtype(cfg) + + # ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict + if cfg.deepspeed: + cfg.deepspeed = cfg.deepspeed.to_dict() + + # initialize accelerator before model instantiation + Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps) + + kwargs["cfg"] = cfg + + do_train(**kwargs) if __name__ == "__main__": diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 0bd400f6b..8b5e0074c 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -141,7 +141,9 @@ def train( model.config.save_pretrained(str(Path(cfg.output_dir))) # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model - if cfg.local_rank == 0: + if ( + cfg.local_rank == 0 and not cfg.use_ray + ): # ray workers don't have access to this signal def terminate_handler(_, __, model_weakref): if model_weakref() is not None: diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index c23359f34..7ddff6219 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -1,4 +1,5 @@ """Module for working with config dicts""" +import json import logging import os from typing import Optional @@ -56,33 +57,10 @@ def choose_device(cfg): cfg.device_map = None -def normalize_config(cfg): - # setup some derived config / hyperparams - cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or ( - cfg.batch_size // cfg.micro_batch_size - ) - cfg.batch_size = ( - cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps - ) - if cfg.eval_batch_size is None: - cfg.eval_batch_size = cfg.micro_batch_size - cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) - cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) - cfg.eval_table_size = cfg.eval_table_size or 0 - cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128 - cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [ - "sacrebleu", - "comet", - "ter", - "chrf", - ] - choose_device(cfg) - cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 - if cfg.ddp: - cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))} - cfg.batch_size = cfg.batch_size * cfg.world_size - - if cfg.bf16 == "auto": +def resolve_dtype(cfg): + if ( + cfg.bf16 == "auto" and not cfg.use_ray + ): # if we use ray we want to defer this check to the worker node if is_torch_bf16_gpu_available(): LOG.debug("bf16 support detected, enabling for this configuration.") cfg.bf16 = True @@ -110,6 +88,43 @@ def normalize_config(cfg): else: cfg.torch_dtype = torch.float32 + +def normalize_config(cfg): + # setup some derived config / hyperparams + cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or ( + cfg.batch_size // cfg.micro_batch_size + ) + cfg.batch_size = ( + cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps + ) + if cfg.eval_batch_size is None: + cfg.eval_batch_size = cfg.micro_batch_size + cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) + cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + cfg.eval_table_size = cfg.eval_table_size or 0 + cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128 + cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [ + "sacrebleu", + "comet", + "ter", + "chrf", + ] + choose_device(cfg) + cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 + if cfg.ddp: + cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))} + cfg.batch_size = cfg.batch_size * cfg.world_size + + if not cfg.use_ray: + # delay resolving dtype until on worker node when launching with ray + resolve_dtype(cfg) + + if cfg.deepspeed: + if isinstance(cfg.deepspeed, str) and os.path.exists(cfg.deepspeed): + ds_config_path = cfg.deepspeed + with open(ds_config_path, encoding="utf-8") as f: + cfg.deepspeed = json.load(f) + if cfg.saves_per_epoch: save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs) if save_steps < 1.0: # prevent saves on every step diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 4f0fa4c29..dc8897863 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -607,6 +607,30 @@ class GradioConfig(BaseModel): gradio_temperature: Optional[float] = None +class RayConfig(BaseModel): + """Ray launcher configuration subset""" + + use_ray: bool = Field(default=False) + ray_run_name: Optional[str] = Field( + default=None, + metadata={ + "help": "The training results will be saved at `saves/ray_run_name`." + }, + ) + ray_num_workers: int = Field( + default=1, + metadata={ + "help": "The number of workers for Ray training. Default is 1 worker." + }, + ) + resources_per_worker: dict = Field( + default_factory=lambda: {"GPU": 1}, + metadata={ + "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker." + }, + ) + + # pylint: disable=too-many-public-methods,too-many-ancestors class AxolotlInputConfig( ModelInputConfig, @@ -619,6 +643,7 @@ class AxolotlInputConfig( CometConfig, LISAConfig, GradioConfig, + RayConfig, RemappedParameters, DeprecatedParameters, BaseModel, diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index bdbd99587..bb1874b0b 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -74,15 +74,13 @@ class TestMultiGPULlama: execute_subprocess_async( [ - "accelerate", - "launch", + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), "--num-processes", "2", - "--main_process_port", + "--main-process-port", f"{get_torch_dist_unique_port()}", - "-m", - "axolotl.cli.train", - str(Path(temp_dir) / "config.yaml"), ] ) @@ -139,15 +137,13 @@ class TestMultiGPULlama: execute_subprocess_async( [ - "accelerate", - "launch", + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), "--num-processes", "2", - "--main_process_port", + "--main-process-port", f"{get_torch_dist_unique_port()}", - "-m", - "axolotl.cli.train", - str(Path(temp_dir) / "config.yaml"), ] ) @@ -214,15 +210,13 @@ class TestMultiGPULlama: execute_subprocess_async( [ - "accelerate", - "launch", + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), "--num-processes", "2", - "--main_process_port", + "--main-process-port", f"{get_torch_dist_unique_port()}", - "-m", - "axolotl.cli.train", - str(Path(temp_dir) / "config.yaml"), ] ) @@ -293,15 +287,13 @@ class TestMultiGPULlama: execute_subprocess_async( [ - "accelerate", - "launch", + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), "--num-processes", "2", - "--main_process_port", + "--main-process-port", f"{get_torch_dist_unique_port()}", - "-m", - "axolotl.cli.train", - str(Path(temp_dir) / "config.yaml"), ] ) @@ -367,15 +359,13 @@ class TestMultiGPULlama: execute_subprocess_async( [ - "accelerate", - "launch", + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), "--num-processes", "2", - "--main_process_port", + "--main-process-port", f"{get_torch_dist_unique_port()}", - "-m", - "axolotl.cli.train", - str(Path(temp_dir) / "config.yaml"), ] ) @@ -439,15 +429,13 @@ class TestMultiGPULlama: execute_subprocess_async( [ - "accelerate", - "launch", + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), "--num-processes", "2", - "--main_process_port", + "--main-process-port", f"{get_torch_dist_unique_port()}", - "-m", - "axolotl.cli.train", - str(Path(temp_dir) / "config.yaml"), ] ) @@ -520,15 +508,13 @@ class TestMultiGPULlama: execute_subprocess_async( [ - "accelerate", - "launch", + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), "--num-processes", "2", - "--main_process_port", + "--main-process-port", f"{get_torch_dist_unique_port()}", - "-m", - "axolotl.cli.train", - str(Path(temp_dir) / "config.yaml"), ] ) @@ -605,15 +591,13 @@ class TestMultiGPULlama: execute_subprocess_async( [ - "accelerate", - "launch", + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), "--num-processes", "2", - "--main_process_port", + "--main-process-port", f"{get_torch_dist_unique_port()}", - "-m", - "axolotl.cli.train", - str(Path(temp_dir) / "config.yaml"), ] ) @@ -680,15 +664,13 @@ class TestMultiGPULlama: execute_subprocess_async( [ - "accelerate", - "launch", + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), "--num-processes", "2", - "--main_process_port", + "--main-process-port", f"{get_torch_dist_unique_port()}", - "-m", - "axolotl.cli.train", - str(Path(temp_dir) / "config.yaml"), ] ) @@ -755,15 +737,13 @@ class TestMultiGPULlama: execute_subprocess_async( [ - "accelerate", - "launch", + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), "--num-processes", "2", - "--main_process_port", + "--main-process-port", f"{get_torch_dist_unique_port()}", - "-m", - "axolotl.cli.train", - str(Path(temp_dir) / "config.yaml"), ] ) diff --git a/tests/e2e/multigpu/test_qwen2.py b/tests/e2e/multigpu/test_qwen2.py index 32bb6a3e1..2b9884848 100644 --- a/tests/e2e/multigpu/test_qwen2.py +++ b/tests/e2e/multigpu/test_qwen2.py @@ -86,14 +86,12 @@ class TestMultiGPUQwen2: execute_subprocess_async( [ - "accelerate", - "launch", + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), "--num-processes", "2", - "--main_process_port", + "--main-process-port", f"{get_torch_dist_unique_port()}", - "-m", - "axolotl.cli.train", - str(Path(temp_dir) / "config.yaml"), ] ) diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py new file mode 100644 index 000000000..d7e4ddfcf --- /dev/null +++ b/tests/e2e/multigpu/test_ray.py @@ -0,0 +1,137 @@ +""" +E2E tests for multigpu post-training use Ray Train +""" + +import logging +import os +from pathlib import Path + +import pytest +import yaml +from accelerate.test_utils import execute_subprocess_async +from e2e.utils import check_tensorboard + +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger(__name__) +os.environ["WANDB_DISABLED"] = "true" + +AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent + + +class TestMultiGPURay: + """ + Test cases for AnyScale Ray post training + """ + + def test_lora_ddp(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "sequence_len": 2048, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.05, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 4, + "gradient_accumulation_steps": 4, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_8bit", + "lr_scheduler": "cosine", + "flash_attention": True, + "use_tensorboard": True, + "use_ray": True, + "ray_num_workers": 2, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--use-ray", + "--ray-num-workers", + "2", + ] + ) + + check_tensorboard( + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + ) + + @pytest.mark.parametrize( + "gradient_accumulation_steps", + [1, 2], + ) + def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "sample_packing": True, + "pad_to_sequence_len": True, + "sequence_len": 2048, + "val_set_size": 0.05, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 1, + "gradient_accumulation_steps": gradient_accumulation_steps, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "flash_attention": True, + "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"), + "use_tensorboard": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--use-ray", + "--ray-num-workers", + "2", + ] + ) + + check_tensorboard( + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + )