From d0d26d5064e05be1bc1696c6df45b9a65d2307b6 Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Thu, 22 Jan 2026 03:52:45 +0530 Subject: [PATCH] feat: Add GDPO Support (#3353) * gdpo support - test left * lint * fixxes for vllm serv * test advantages * docss * lint * lint = * gdpo simple + lint * lint nit * example * lint * trl 0.27.0 * blocklist * test assert rmv * add validation check for GDPO + sum_then_normalize --------- Co-authored-by: Wing Lian --- docs/rlhf.qmd | 97 ++++ examples/llama-3/qlora-1b-gdpo.yaml | 68 +++ requirements.txt | 2 +- src/axolotl/core/builders/rl.py | 11 +- src/axolotl/core/trainers/grpo/__init__.py | 5 + src/axolotl/utils/data/rl.py | 2 +- src/axolotl/utils/schemas/enums.py | 1 + src/axolotl/utils/schemas/trl.py | 10 + src/axolotl/utils/schemas/validation.py | 13 + tests/core/test_builders.py | 1 - tests/e2e/multigpu/solo/test_gdpo.py | 538 +++++++++++++++++++++ 11 files changed, 742 insertions(+), 6 deletions(-) create mode 100644 examples/llama-3/qlora-1b-gdpo.yaml create mode 100644 tests/e2e/multigpu/solo/test_gdpo.py diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd index 1eea42036..135b3038c 100644 --- a/docs/rlhf.qmd +++ b/docs/rlhf.qmd @@ -17,6 +17,7 @@ feedback. Various methods include, but not limited to: - [Kahneman-Tversky Optimization (KTO)](#kto) - [Odds Ratio Preference Optimization (ORPO)](#orpo) - [Group Relative Policy Optimization (GRPO)](#grpo) +- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo) ## RLHF using Axolotl @@ -720,6 +721,102 @@ trl: For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types). +### GDPO + +GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them. + +::: {.callout-tip} +Use GDPO when training with multiple reward functions. For single reward, GRPO and GDPO produce equivalent results. +::: + +Paper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242) + +GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, axolotl automatically configures TRL to use `normalize_then_sum` aggregation. + +```yaml +base_model: Qwen/Qwen2.5-1.5B-Instruct + +vllm: + host: 0.0.0.0 + port: 8000 + tensor_parallel_size: 2 + gpu_memory_utilization: 0.85 + +rl: gdpo + +trl: + beta: 0.001 + max_completion_length: 256 + use_vllm: true + num_generations: 4 + reward_funcs: + - rewards.format_reward + - rewards.correctness_reward + reward_weights: [1.0, 2.0] + +datasets: + - path: openai/gsm8k + name: main + type: rewards.oai_gsm8k_transform +``` + +You can also use GRPO with explicit aggregation control: + +```yaml +rl: grpo +trl: + multi_objective_aggregation: normalize_then_sum # GDPO behavior + # or: sum_then_normalize # Default GRPO behavior +``` + +#### GDPO vs GRPO + +| Aspect | GRPO | GDPO | +|--------|------|------| +| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` | +| **Multi-reward** | May collapse advantages | Preserves reward signals | +| **Single reward** | Standard behavior | Equivalent to GRPO | + +#### Why GDPO? + +When using multiple rewards with GRPO, different reward combinations can produce identical advantages: + +``` +# Example: format + correctness rewards +[format=0, correct=3] → sum=3 +[format=1, correct=2] → sum=3 ← GRPO sees these as equal! 
+[format=2, correct=1] → sum=3 +[format=3, correct=0] → sum=3 +``` + +GDPO normalizes each reward independently, preserving their relative differences. + +#### Reward Functions + +GDPO uses the same reward function format as GRPO: + +```python +# rewards.py +def format_reward(completions, **kwargs) -> list[float]: + return [1.0 if len(c) > 10 else 0.0 for c in completions] + +def correctness_reward(completions, answers, **kwargs) -> list[float]: + rewards = [] + for completion, answer in zip(completions, answers): + # Your scoring logic here + rewards.append(score) + return rewards +``` + +#### Sequence Parallelism + +GDPO supports sequence parallelism for long-context training: + +```yaml +rl: gdpo +context_parallel_size: 2 +``` + ### SimPO SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function. diff --git a/examples/llama-3/qlora-1b-gdpo.yaml b/examples/llama-3/qlora-1b-gdpo.yaml new file mode 100644 index 000000000..d806fcf26 --- /dev/null +++ b/examples/llama-3/qlora-1b-gdpo.yaml @@ -0,0 +1,68 @@ +base_model: meta-llama/Llama-3.2-1B-Instruct + +chat_template: llama3 + +rl: gdpo + +trl: + beta: 0.001 + max_completion_length: 128 + num_generations: 2 + temperature: 0.7 + top_p: 0.95 + + use_vllm: false + + + multi_objective_aggregation: normalize_then_sum + + reward_funcs: + - rwd.format_reward + - rwd.correctness_reward + reward_weights: [1.0, 2.0] + + log_completions: true + num_completions_to_print: 3 + scale_rewards: true + +datasets: + - path: openai/gsm8k + name: main + split: train[:1000] + type: rwd.gsm8k_transform + +val_set_size: 0.0 +output_dir: ./outputs/llama3-gdpo-out + +sequence_len: 512 +sample_packing: false +pad_to_sequence_len: false + +gradient_accumulation_steps: 8 +micro_batch_size: 1 +num_epochs: 1 +max_steps: 100 + +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 5e-5 +weight_decay: 0.01 +warmup_steps: 10 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false + +flash_attention: true +logging_steps: 1 +save_steps: 50 +save_safetensors: true + +special_tokens: + pad_token: "<|end_of_text|>" + + +seed: 42 diff --git a/requirements.txt b/requirements.txt index 64fe1b240..2b5ec0c38 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,7 @@ transformers==4.57.6 accelerate==1.12.0 datasets==4.5.0 deepspeed>=0.18.3 -trl==0.25.1 +trl==0.27.0 hf_xet==1.2.0 kernels==0.11.5 trackio>=0.13.0 diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index 0ceb80008..0bd2eedfc 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -52,12 +52,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase): trainer_cls = None trainer_cls_args = [self.model] - if self.cfg.rl is RLType.GRPO: + if self.cfg.rl in {RLType.GRPO, RLType.GDPO}: trainer_cls = GRPOStrategy.get_trainer_class( sequence_parallel=self.cfg.context_parallel_size > 1 ) trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) - trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg)) elif self.cfg.rl in [RLType.DPO, RLType.IPO]: @@ -147,6 +146,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): elif self.cfg.rl is RLType.KTO: training_args_cls = AxolotlKTOConfig + # KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length + blocklist_args_kwargs = ["max_prompt_length"] training_args_kwargs["desirable_weight"] = ( self.cfg.kto_desirable_weight or 1.0 @@ -155,10 +156,14 @@ class 
HFRLTrainerBuilder(TrainerBuilderBase): self.cfg.kto_undesirable_weight or 1.0 ) - elif self.cfg.rl is RLType.GRPO: + elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}: training_args_cls = GRPOStrategy.get_training_args_class() training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() + if self.cfg.rl is RLType.GDPO: + training_args_kwargs.setdefault( + "multi_objective_aggregation", "normalize_then_sum" + ) elif self.cfg.rl in [RLType.DPO, RLType.IPO]: training_args_cls = AxolotlDPOConfig diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 7f28cb8d4..e611b96ea 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -129,6 +129,11 @@ class GRPOStrategy: if trl.rollout_func: grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func) + if trl.multi_objective_aggregation is not None: + grpo_args_kwargs["multi_objective_aggregation"] = ( + trl.multi_objective_aggregation + ) + return grpo_args_kwargs @classmethod diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index f7a5ec04c..5ea9e55e0 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -173,7 +173,7 @@ def _drop_long_sequences( return (len_prompt + len_completion) <= sequence_len - if rl is RLType.GRPO: + if rl in {RLType.GRPO, RLType.GDPO}: return True raise ValueError("Unknown RL type") diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index f86d1a191..b67888e0f 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -26,6 +26,7 @@ class RLType(str, Enum): """RL trainer type configuration subset""" DPO = "dpo" + GDPO = "gdpo" GRPO = "grpo" IPO = "ipo" ORPO = "orpo" diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py index d24d6f477..ff96f44ce 100644 --- a/src/axolotl/utils/schemas/trl.py +++ b/src/axolotl/utils/schemas/trl.py @@ -179,3 +179,13 @@ class TRLConfig(BaseModel): "description": "Path to custom rollout function. Must be importable from current dir." }, ) + multi_objective_aggregation: ( + Literal["sum_then_normalize", "normalize_then_sum"] | None + ) = Field( + default=None, + json_schema_extra={ + "description": "Multi-objective reward aggregation strategy. " + "'sum_then_normalize' (GRPO default): weights and sums rewards first, then normalizes. " + "'normalize_then_sum' (GDPO): normalizes each reward independently, then sums." 
+ }, + ) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index bf054d353..bb9c3c673 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -746,6 +746,19 @@ class RLValidationMixin: ) return data + @model_validator(mode="before") + @classmethod + def check_gdpo(cls, data): + if ( + data.get("rl") == "gdpo" + and data.get("trl", {}).get("multi_objective_aggregation") + == "sum_then_normalize" + ): + raise ValueError( + "`multi_objective_aggregation` value set as `sum_then_normalize` => GRPO, but GDPO was selected" + ) + return data + class OptimizationValidationMixin: """Validation methods related to optimization and performance.""" diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py index f9db4d013..c2d81cbcb 100644 --- a/tests/core/test_builders.py +++ b/tests/core/test_builders.py @@ -311,7 +311,6 @@ class TestHFRLTrainerBuilder: # KTO specific assert training_arguments.desirable_weight == 1.0 assert training_arguments.undesirable_weight == 1.0 - assert training_arguments.max_prompt_length == 512 def _write_rewards_file(self, rewards_dir: Path): """ diff --git a/tests/e2e/multigpu/solo/test_gdpo.py b/tests/e2e/multigpu/solo/test_gdpo.py new file mode 100644 index 000000000..2014f7f5e --- /dev/null +++ b/tests/e2e/multigpu/solo/test_gdpo.py @@ -0,0 +1,538 @@ +""" +GDPO test suite + +GDPO uses TRL's multi_objective_aggregation="normalize_then_sum" for +per-reward normalization in multi-reward RL training. +""" + +import os +import random +from pathlib import Path + +import pytest +import yaml +from accelerate.test_utils import execute_subprocess_async +from transformers.testing_utils import get_torch_dist_unique_port + +from axolotl.utils.dict import DictDefault + +from tests.e2e.multigpu.solo.test_grpo import recursive_kill, start_vllm +from tests.e2e.utils import require_vllm + + +@pytest.mark.skip(reason="flaky vllm tests in modal") +class TestGDPO: + """Test case for GDPO training using TRL's native multi-objective aggregation.""" + + def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""): + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + with open(f"rewards_gdpo_{suffix}.py", "w", encoding="utf-8") as fout: + fout.write( + """import random + +def format_reward(prompts, completions, **kwargs) -> list[float]: + return [1.0 if len(c) > 10 else 0.0 for c in completions] + +def correctness_reward(prompts, completions, **kwargs) -> list[float]: + return [random.uniform(-1, 3) for _ in completions] + +def safety_reward(prompts, completions, **kwargs) -> list[float]: + return [1.0 if 'error' not in c.lower() else 0.0 for c in completions] + +def single_reward(prompts, completions, **kwargs) -> list[float]: + return [random.uniform(0, 1) for _ in completions] + +def oai_gsm8k_transform(cfg, *args, **kwargs): + def transform_fn(example, tokenizer=None): + label = example["answer"].split("####")[-1].strip().replace(",", "") + return { + "prompt": [{"role": "user", "content": example["question"]}], + "answer": label, + } + return transform_fn, {"remove_columns": ["question"]} +""" + ) + + @pytest.mark.parametrize("num_gpus", [1, 2]) + @require_vllm + def test_gdpo_multi_reward_lora(self, temp_dir, num_gpus): + """Test GDPO with multiple reward functions using LoRA.""" + rnd_suffix = str(random.randint(1000, 9999)) + cfg = DictDefault( + { 
+ "base_model": "HuggingFaceTB/SmolLM2-135M", + "chat_template": "llama3", + "rl": "gdpo", + "trl": { + "beta": 0.001, + "max_completion_length": 256, + "use_vllm": True, + "num_generations": 4, + "reward_funcs": [ + f"rewards_gdpo_{rnd_suffix}.format_reward", + f"rewards_gdpo_{rnd_suffix}.correctness_reward", + ], + "reward_weights": [1.0, 2.0], + "scale_rewards": True, + }, + "vllm": { + "max_model_len": 800, + "enable_prefix_caching": True, + }, + "datasets": [ + { + "path": "openai/gsm8k", + "name": "main", + "type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform", + }, + ], + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "flash_attention": True, + "sequence_len": 1024, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "max_steps": 3, + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 2, + "warmup_steps": 10, + "val_set_size": 0.0, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + "use_tensorboard": True, + "save_first_step": False, + } + ) + + self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix) + + current_env = os.environ.copy() + env = { + "NCCL_P2P_LEVEL": "LOC", + **current_env, + "CUDA_VISIBLE_DEVICES": "1", + } + vllm_process = start_vllm( + cfg.base_model, + env=env, + quiet=True, + wait=300, + gpu_memory_utilization=0.15, + max_model_len=cfg.vllm.max_model_len, + enable_prefix_caching=cfg.vllm.enable_prefix_caching, + host="0.0.0.0", + port=8000, + ) + + try: + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + str(num_gpus), + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ], + env={ + "NCCL_P2P_LEVEL": "LOC", + "NCCL_DEBUG": "INFO", + **current_env, + }, + ) + finally: + recursive_kill(vllm_process) + + @require_vllm + def test_gdpo_three_rewards(self, temp_dir): + """Test GDPO with three reward functions (format, correctness, safety).""" + rnd_suffix = str(random.randint(1000, 9999)) + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "chat_template": "llama3", + "rl": "gdpo", + "trl": { + "beta": 0.001, + "max_completion_length": 256, + "use_vllm": True, + "num_generations": 4, + "reward_funcs": [ + f"rewards_gdpo_{rnd_suffix}.format_reward", + f"rewards_gdpo_{rnd_suffix}.correctness_reward", + f"rewards_gdpo_{rnd_suffix}.safety_reward", + ], + "reward_weights": [1.0, 2.0, 1.5], + }, + "vllm": { + "max_model_len": 800, + "enable_prefix_caching": True, + }, + "datasets": [ + { + "path": "openai/gsm8k", + "name": "main", + "type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform", + }, + ], + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "flash_attention": True, + "sequence_len": 1024, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "max_steps": 3, + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 2, + "warmup_steps": 10, + "val_set_size": 0.0, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + } + ) + + self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix) + + current_env = os.environ.copy() + env = { + "NCCL_P2P_LEVEL": "LOC", + **current_env, + "CUDA_VISIBLE_DEVICES": "1", + } + vllm_process = start_vllm( + 
cfg.base_model, + env=env, + quiet=True, + wait=300, + gpu_memory_utilization=0.15, + max_model_len=cfg.vllm.max_model_len, + enable_prefix_caching=cfg.vllm.enable_prefix_caching, + host="0.0.0.0", + port=8000, + ) + + try: + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "1", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ], + env={ + "NCCL_P2P_LEVEL": "LOC", + "NCCL_DEBUG": "INFO", + **current_env, + }, + ) + finally: + recursive_kill(vllm_process) + + @require_vllm + def test_gdpo_single_reward_fallback(self, temp_dir): + """Test GDPO with single reward.""" + rnd_suffix = str(random.randint(1000, 9999)) + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "chat_template": "llama3", + "rl": "gdpo", + "trl": { + "beta": 0.001, + "max_completion_length": 256, + "use_vllm": True, + "num_generations": 4, + "reward_funcs": [ + f"rewards_gdpo_{rnd_suffix}.single_reward", + ], + "reward_weights": [1.0], + }, + "vllm": { + "max_model_len": 800, + "enable_prefix_caching": True, + }, + "datasets": [ + { + "path": "openai/gsm8k", + "name": "main", + "type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform", + }, + ], + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "flash_attention": True, + "sequence_len": 1024, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "max_steps": 3, + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 2, + "warmup_steps": 10, + "val_set_size": 0.0, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + } + ) + + self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix) + + current_env = os.environ.copy() + env = { + "NCCL_P2P_LEVEL": "LOC", + **current_env, + "CUDA_VISIBLE_DEVICES": "1", + } + vllm_process = start_vllm( + cfg.base_model, + env=env, + quiet=True, + wait=300, + gpu_memory_utilization=0.15, + max_model_len=cfg.vllm.max_model_len, + enable_prefix_caching=cfg.vllm.enable_prefix_caching, + host="0.0.0.0", + port=8000, + ) + + try: + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "1", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ], + env={ + "NCCL_P2P_LEVEL": "LOC", + "NCCL_DEBUG": "INFO", + **current_env, + }, + ) + finally: + recursive_kill(vllm_process) + + @require_vllm + def test_gdpo_fft(self, temp_dir): + """Test GDPO with full fine-tuning (no adapter).""" + rnd_suffix = str(random.randint(1000, 9999)) + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "chat_template": "llama3", + "rl": "gdpo", + "trl": { + "beta": 0.001, + "max_completion_length": 256, + "use_vllm": True, + "num_generations": 4, + "reward_funcs": [ + f"rewards_gdpo_{rnd_suffix}.format_reward", + f"rewards_gdpo_{rnd_suffix}.correctness_reward", + ], + "reward_weights": [1.0, 2.0], + }, + "vllm": { + "max_model_len": 800, + "enable_prefix_caching": True, + }, + "datasets": [ + { + "path": "openai/gsm8k", + "name": "main", + "type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform", + }, + ], + # No adapter - full fine-tuning + "flash_attention": True, + "sequence_len": 1024, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "max_steps": 3, + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 2, + "warmup_steps": 10, + 
"val_set_size": 0.0, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + } + ) + + self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix) + + current_env = os.environ.copy() + env = { + "NCCL_P2P_LEVEL": "LOC", + **current_env, + "CUDA_VISIBLE_DEVICES": "1", + } + vllm_process = start_vllm( + cfg.base_model, + env=env, + quiet=True, + wait=300, + gpu_memory_utilization=0.15, + max_model_len=cfg.vllm.max_model_len, + enable_prefix_caching=cfg.vllm.enable_prefix_caching, + host="0.0.0.0", + port=8000, + ) + + try: + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "1", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ], + env={ + "NCCL_P2P_LEVEL": "LOC", + "NCCL_DEBUG": "INFO", + **current_env, + }, + ) + finally: + recursive_kill(vllm_process) + + @require_vllm + def test_gdpo_sequence_parallel(self, temp_dir): + """Test GDPO with sequence parallelism.""" + rnd_suffix = str(random.randint(1000, 9999)) + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "chat_template": "llama3", + "rl": "gdpo", + "context_parallel_size": 2, + "trl": { + "beta": 0.001, + "max_completion_length": 256, + "use_vllm": True, + "num_generations": 4, + "reward_funcs": [ + f"rewards_gdpo_{rnd_suffix}.format_reward", + f"rewards_gdpo_{rnd_suffix}.correctness_reward", + ], + "reward_weights": [1.0, 2.0], + }, + "vllm": { + "max_model_len": 800, + "enable_prefix_caching": True, + }, + "datasets": [ + { + "path": "openai/gsm8k", + "name": "main", + "type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform", + }, + ], + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "flash_attention": True, + "sequence_len": 1024, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "max_steps": 3, + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 2, + "warmup_steps": 10, + "val_set_size": 0.0, + "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", + "learning_rate": 0.0001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + } + ) + + self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix) + + current_env = os.environ.copy() + env = { + "NCCL_P2P_LEVEL": "LOC", + **current_env, + "CUDA_VISIBLE_DEVICES": "1", + } + vllm_process = start_vllm( + cfg.base_model, + env=env, + quiet=True, + wait=300, + gpu_memory_utilization=0.15, + max_model_len=cfg.vllm.max_model_len, + enable_prefix_caching=cfg.vllm.enable_prefix_caching, + host="0.0.0.0", + port=8000, + ) + + try: + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ], + env={ + "NCCL_P2P_LEVEL": "LOC", + "NCCL_DEBUG": "INFO", + **current_env, + }, + ) + finally: + recursive_kill(vllm_process)