Compare commits
1 commit
seq-parall ... grpo-ref-m

| Author | SHA1 | Date |
|---|---|---|
|  | a9ebff087c |  |

.github/workflows/multi-gpu-e2e.yml (vendored, 4 changes)
@@ -4,10 +4,6 @@ on:
   pull_request:
     paths:
       - 'tests/e2e/multigpu/*.py'
-      - 'requirements.txt'
-      - 'setup.py'
-      - 'pyproject.toml'
-      - '.github/workflows/multi-gpu-e2e.yml'
   workflow_dispatch:
   schedule:
     - cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
@@ -37,11 +37,15 @@ temp_dir = tempfile.mkdtemp()
 with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
     f.write(dockerfile_contents)
 
-cicd_image = Image.from_dockerfile(
-    pathlib.Path(temp_dir) / "Dockerfile",
-    force_build=True,
-    gpu="A10G",
-).env(df_args)
+cicd_image = (
+    Image.from_dockerfile(
+        pathlib.Path(temp_dir) / "Dockerfile",
+        force_build=True,
+        gpu="A10G",
+    )
+    .env(df_args)
+    .pip_install("fastapi==0.110.0", "pydantic==2.6.3")
+)
 
 app = App("Axolotl CI/CD", secrets=[])
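Note on the hunk above: Modal's Image methods chain, and pip_install layers packages on top of the Dockerfile build, which is how this CI image ends up with the same fastapi/pydantic pins as the ModalCloud image further down. A minimal sketch of that chaining, assuming Modal's Image/App API as used in this diff (the Dockerfile path and env dict here are placeholders, not values from this PR):

    from modal import App, Image

    # Minimal sketch of Modal image chaining; the Dockerfile path and the env
    # dict below are placeholders, not values taken from this PR.
    image = (
        Image.from_dockerfile("Dockerfile")
        .env({"EXAMPLE_FLAG": "1"})
        .pip_install("fastapi==0.110.0", "pydantic==2.6.3")
    )
    app = App("example-ci", image=image)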
@@ -407,10 +407,7 @@ save_total_limit: # Checkpoints saved at a time
 max_steps:
 
-# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
-include_tokens_per_second: # Optional[bool]
 
-# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
-auto_find_batch_size: # Optional[bool]
+include_tokens_per_second:
 
 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
requirements.txt

@@ -18,7 +18,7 @@ tokenizers>=0.21.0
 accelerate==1.3.0
 datasets==3.2.0
 deepspeed==0.16.1
-trl==0.15.1
+trl==0.15.0
 
 optimum==1.16.2
 hf_transfer
@@ -123,6 +123,8 @@ class ModalCloud(Cloud):
         if env := self.get_env():
             image = image.env(env)
 
+        image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
+
         return image
 
     def get_secrets(self):
@@ -59,7 +59,6 @@ from axolotl.core.training_args import (
     AxolotlTrainingArguments,
 )
 from axolotl.integrations.base import PluginManager
-from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType, get_extract_fn
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback
 from axolotl.utils import is_comet_available, is_mlflow_available
@@ -747,11 +746,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
             # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
             data_collator_kwargs["pad_to_multiple_of"] = 64
-        if self.cfg.sp_ulysses_degree:
-            data_collator_kwargs["sp_extract_fn"] = get_extract_fn(
-                USPRingAttnType.ZIGZAG,
-                sp_ulysses_degree=self.cfg.sp_ulysses_degree
-            )
 
         if self.cfg.reward_model:
             data_collator_kwargs["max_length"] = self.cfg.sequence_len
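For context on what the removed branch above wired up: get_extract_fn(USPRingAttnType.ZIGZAG, ...) handed the collator a function that slices each rank's shard out of the packed sequence. Roughly, and this is an illustrative sketch rather than yunchang's implementation, a zigzag extract pairs chunk i with chunk 2N-1-i so causal-attention work balances across ring ranks:

    import torch

    # Illustrative zigzag extract: rank r keeps chunks r and 2N-1-r of the
    # sequence dim, pairing a cheap early chunk with an expensive late one
    # under causal masking. Not yunchang's actual implementation.
    def zigzag_extract(batch: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        chunks = batch.chunk(2 * world_size, dim=1)
        return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=1)

    x = torch.arange(16).reshape(1, 16)
    print(zigzag_extract(x, rank=0, world_size=2))  # chunks 0 and 3: [[0..3, 12..15]]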
@@ -39,6 +39,15 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
         self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
         # pylint: enable=access-member-before-definition
 
+        # cleanup the ref_model if we have a peft model passed in
+        # TODO remove this after next major trl release
+        if (
+            self.ref_model  # pylint: disable=access-member-before-definition
+            and is_peft_model(self.model)
+        ):
+            del self.ref_model
+            self.ref_model = None
+
     def _enable_gradient_checkpointing(
         self, model: PreTrainedModel, args: GRPOConfig
     ) -> PreTrainedModel:
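Dropping ref_model is sound for LoRA runs because the frozen base weights already serve as the reference policy: disabling the adapter recovers reference logits from the same model, so no second copy is needed (recent trl releases handle this internally, hence the TODO above). A hedged sketch of the idea, assuming model is a peft.PeftModel:

    import torch

    # Sketch only: reference log-probs without a separate ref_model. Assumes
    # `model` is a peft.PeftModel; disable_adapter() is peft's context manager
    # that temporarily routes the forward pass through the frozen base weights.
    def ref_logprobs(model, input_ids, attention_mask):
        with torch.no_grad(), model.disable_adapter():
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        return torch.log_softmax(logits, dim=-1)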
@@ -78,6 +87,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
         if is_peft_model(unwrapped_model):
             unwrapped_model.merge_adapter()
             state_dict = unwrapped_model.state_dict()
+            unwrapped_model.unmerge_adapter()
             # Remove base_model and base_layer prefixes
             state_dict = {
                 k.removeprefix("base_model.model.")
@@ -99,10 +109,8 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
             }
         else:
             state_dict = unwrapped_model.state_dict()
-        if self.accelerator.is_main_process:
-            llm_model = (
-                self.llm.llm_engine.model_executor.driver_worker.model_runner.model
-            )
-            llm_model.load_weights(state_dict.items())
-        if is_peft_model(unwrapped_model):
-            unwrapped_model.unmerge_adapter()
+        if self.accelerator.is_main_process:
+            llm_model = (
+                self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+            )
+            llm_model.load_weights(state_dict.items())
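The two hunks above move unmerge_adapter() to right after the merged state dict is captured, so the weight push to vLLM happens once, outside the PEFT branch. For reference, a toy run of the key cleanup from the PEFT branch (the sample key follows PEFT's usual "base_model.model." prefixing; that layout is an assumption here, not something this diff shows):

    # Toy illustration of the removeprefix cleanup in the PEFT branch above.
    # The sample key follows PEFT's usual "base_model.model." prefix; that
    # naming is an assumption, not shown in this diff.
    merged = {"base_model.model.model.layers.0.self_attn.q_proj.weight": 0}
    cleaned = {k.removeprefix("base_model.model."): v for k, v in merged.items()}
    print(list(cleaned))  # ['model.layers.0.self_attn.q_proj.weight']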
@@ -206,16 +206,6 @@ class AxolotlTrainingMixins:
         },
     )
 
-    sp_ulysses_degree: Optional[int] = field(
-        default=None,
-        metadata={"help": "Ulysses parallelism for hybrid sequence parallel long context attn"},
-    )
-
-    sp_ring_degree: Optional[int] = field(
-        default=None,
-        metadata={"help": "Ring attention parallelism for sequence parallel long context attn"},
-    )
-
 
 @dataclass
 class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
axolotl/monkeypatch/attention/sequence_parallel/__init__.py (deleted)

@@ -1,45 +0,0 @@
-from enum import Enum
-from functools import partial
-
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
-from yunchang import set_seq_parallel_pg, EXTRACT_FUNC_DICT
-
-from axolotl.utils.distributed import get_world_size, get_rank
-
-
-class USPRingAttnType(Enum):
-    BASIC = "basic"
-    ZIGZAG = "zigzag"
-    STRIPE = "stripe"
-
-def apply_usp_attn_patch(ring_impl_type: USPRingAttnType):
-    from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward
-
-    fa_forward = build_usp_fa_forward(ring_impl_type)
-    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = fa_forward
-
-def get_extract_fn(ring_impl_type: USPRingAttnType, sp_ulysses_degree: int):
-    fn = EXTRACT_FUNC_DICT["basic"]
-    if ring_impl_type.value in EXTRACT_FUNC_DICT.keys():
-        fn = EXTRACT_FUNC_DICT[ring_impl_type.value]
-
-    # map bad key upstream
-    elif ring_impl_type == USPRingAttnType.STRIPE:
-        fn = EXTRACT_FUNC_DICT["strip"]
-
-    world_size = get_world_size()
-    rd = world_size // sp_ulysses_degree
-
-    return partial(fn, rank=get_rank(), world_size=world_size, rd=rd, ud=sp_ulysses_degree)
-
-def set_usp_parallel_group(sp_ulysses_degree):
-    """
-    setup distributed parallel group for USP attention
-    make sure this gets called before building any USP attention modules
-    :param sp_ulysses_degree:
-    :return:
-    """
-    world_size = get_world_size()
-    rank = get_rank()
-    sp_ring_degree = world_size // sp_ulysses_degree
-    set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
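Per the deleted module's own docstring, the intended wiring was to create the process groups first and patch attention second, with the world size factoring as sp_ulysses_degree times the derived ring degree (world_size // sp_ulysses_degree). A sketch of that wiring against the deleted API, assuming an 8-GPU run:

    # Sketch of how the deleted helpers were meant to be used together
    # (imports refer to the module deleted above; 8 GPUs assumed, so the
    # derived ring degree is 8 // 2 = 4).
    from axolotl.monkeypatch.attention.sequence_parallel import (
        USPRingAttnType,
        apply_usp_attn_patch,
        set_usp_parallel_group,
    )

    set_usp_parallel_group(sp_ulysses_degree=2)   # docstring: call before building attention modules
    apply_usp_attn_patch(USPRingAttnType.ZIGZAG)  # reroutes "flash_attention_2" through USP attention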
axolotl/monkeypatch/attention/sequence_parallel/usp.py (deleted)

@@ -1,36 +0,0 @@
-from enum import Enum
-from typing import Optional, Tuple, Callable
-
-import torch
-from yunchang import LongContextAttention
-
-from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType
-
-
-def build_usp_fa_forward(ring_impl_type: USPRingAttnType) -> Callable:
-    usp_attn = LongContextAttention(ring_impl_type.value)
-
-    def flash_attention_forward(
-        module: torch.nn.Module,  # pylint: disable=unused-argument
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attention_mask: Optional[torch.Tensor],  # pylint: disable=unused-argument
-        dropout: float = 0.0,
-        scaling: Optional[float] = None,
-        sliding_window: Optional[int] = None,  # pylint: disable=unused-argument
-        softcap: Optional[float] = None,
-        **kwargs,  # pylint: disable=unused-argument
-    ) -> Tuple[torch.Tensor, None]:
-        attn_output = usp_attn(
-            query,
-            key,
-            value,
-            dropout_p=dropout,
-            softmax_scale=scaling,
-            causal=True,
-            softcap=softcap,
-        )
-        return attn_output, None
-
-    return flash_attention_forward
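Both deleted modules hinge on transformers' attention registry: assigning into ALL_ATTENTION_FUNCTIONS swaps the kernel used by any model loaded with attn_implementation="flash_attention_2", which is exactly what apply_usp_attn_patch did with the closure above. A minimal sketch of that hook with a placeholder kernel:

    # Minimal sketch of the registry hook the deleted patch relied on.
    # noop_attention is a placeholder obeying the (attn_output, attn_weights)
    # return contract; it is not a real or layout-correct kernel.
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    def noop_attention(module, query, key, value, attention_mask, **kwargs):
        return query, None  # placeholder output; a real kernel computes attention here

    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = noop_attention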
@@ -3,7 +3,7 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
 """
 
 from dataclasses import dataclass
-from typing import Any, Optional, Union, Callable
+from typing import Any, Optional, Union
 
 import numpy as np
 from transformers import PreTrainedTokenizerBase
@@ -53,7 +53,6 @@ class DataCollatorForSeq2Seq:
     label_pad_token_id: int = -100
     position_pad_token_id: int = 0
     return_tensors: str = "pt"
-    sp_extract_fn: Optional[Callable] = None
 
     def __call__(self, features, return_tensors=None):
         labels = None
@@ -122,10 +121,6 @@ class DataCollatorForSeq2Seq:
 
         return features
 
-    def seq_parallel_split(self, features):
-        if self.sp_extract_fn:
-            pass
-        return features
 
 @dataclass
 class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
@@ -342,7 +342,6 @@ class LoraConfig(BaseModel):
     peft_use_dora: Optional[bool] = None
     peft_use_rslora: Optional[bool] = None
     peft_layer_replication: Optional[List[Tuple[int, int]]] = None
-    peft_init_lora_weights: Optional[Union[bool, str]] = None
 
     qlora_sharded_model_loading: Optional[bool] = Field(
         default=False,
@@ -832,8 +831,6 @@ class AxolotlInputConfig(
 
     eager_attention: Optional[bool] = None
 
-    sp_ulysses_degree: Optional[int] = None
-
     unsloth_cross_entropy_loss: Optional[bool] = None
     unsloth_lora_mlp: Optional[bool] = None
     unsloth_lora_qkv: Optional[bool] = None
@@ -86,12 +86,6 @@ def get_world_size():
     return int(os.getenv("WORLD_SIZE", "1"))
 
 
-def get_rank():
-    if not is_distributed():
-        return 0
-    return dist.get_rank()
-
-
 @contextmanager
 def zero_only():
     """
@@ -1321,8 +1321,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
     if loftq_bits:
         lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
         lora_config_kwargs["init_lora_weights"] = "loftq"
-    if cfg.peft_init_lora_weights:
-        lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
     if cfg.peft_use_dora:
         lora_config_kwargs["use_dora"] = cfg.peft_use_dora
         LOG.info("Initializing LoRA weights using dora. This might take longer.")
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             load_from_cache_file=not cfg.is_preprocess,
             desc="Add position_id column (PoSE)",
         )
-    elif cfg.sample_packing or cfg.sp_ulysses_degree:
+    elif cfg.sample_packing:
         drop_long_kwargs = {}
         if filter_map_kwargs:
             drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"