update config.qmd and rename option
This commit is contained in:
@@ -620,6 +620,11 @@ ddp_timeout:
|
|||||||
ddp_bucket_cap_mb:
|
ddp_bucket_cap_mb:
|
||||||
ddp_broadcast_buffers:
|
ddp_broadcast_buffers:
|
||||||
|
|
||||||
|
# Sequence parallelism
|
||||||
|
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
|
||||||
|
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
|
||||||
|
sequence_parallel_degree:
|
||||||
|
|
||||||
# Path to torch distx for optim 'adamw_anyprecision'
|
# Path to torch distx for optim 'adamw_anyprecision'
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
|
|
||||||
|
|||||||
@@ -763,8 +763,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
] = self.cfg.kd_top_k_before_softmax
|
] = self.cfg.kd_top_k_before_softmax
|
||||||
|
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"sequence_parallel_size"
|
"sequence_parallel_degree"
|
||||||
] = self.cfg.sequence_parallel_size
|
] = self.cfg.sequence_parallel_degree
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
@@ -911,7 +911,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
collator = DataCollatorForSeq2Seq
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|
||||||
kwargs["return_tensors"] = "pt"
|
kwargs["return_tensors"] = "pt"
|
||||||
kwargs["sequence_parallel_size"] = training_args.sequence_parallel_size
|
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
||||||
|
|
||||||
return collator(
|
return collator(
|
||||||
*collator_args,
|
*collator_args,
|
||||||
|
|||||||
@@ -423,9 +423,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
# Handle sequence parallelism
|
# Handle sequence parallelism
|
||||||
if self.args.sequence_parallel_size > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
|
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
|
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||||
|
|
||||||
# Create base sampler for SP groups
|
# Create base sampler for SP groups
|
||||||
base_sampler = torch.utils.data.distributed.DistributedSampler(
|
base_sampler = torch.utils.data.distributed.DistributedSampler(
|
||||||
@@ -472,10 +472,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
|
||||||
# Handle sequence parallelism
|
# Handle sequence parallelism
|
||||||
if self.args.sequence_parallel_size > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
# Create sampler for SP groups
|
# Create sampler for SP groups
|
||||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
|
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
|
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||||
|
|
||||||
# Create distributed sampler for the SP group
|
# Create distributed sampler for the SP group
|
||||||
base_sampler = torch.utils.data.distributed.DistributedSampler(
|
base_sampler = torch.utils.data.distributed.DistributedSampler(
|
||||||
@@ -570,7 +570,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
|
|
||||||
# Don't prepare dataloader for sequence parallelism
|
# Don't prepare dataloader for sequence parallelism
|
||||||
# We use a distributed sampler in this case
|
# We use a distributed sampler in this case
|
||||||
if self.args.sequence_parallel_size > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
return self.accelerator.prepare_data_loader(dataloader)
|
return self.accelerator.prepare_data_loader(dataloader)
|
||||||
@@ -625,11 +625,11 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
|
|
||||||
# Don't prepare dataloader for sequence parallelism
|
# Don't prepare dataloader for sequence parallelism
|
||||||
# We use a distributed sampler in this case
|
# We use a distributed sampler in this case
|
||||||
if self.args.sequence_parallel_size > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
return self.accelerator.prepare_data_loader(dataloader)
|
return self.accelerator.prepare_data_loader(dataloader)
|
||||||
if self.args.sequence_parallel_size > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
eval_dataset = (
|
eval_dataset = (
|
||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
)
|
)
|
||||||
@@ -949,14 +949,14 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
"""
|
"""
|
||||||
Perform a training step on a batch of inputs.
|
Perform a training step on a batch of inputs.
|
||||||
"""
|
"""
|
||||||
if self.args.sequence_parallel_size > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
# At this point, inputs should already be partitioned by the sequence
|
# At this point, inputs should already be partitioned by the sequence
|
||||||
# parallel data collator
|
# parallel data collator
|
||||||
batch_size = inputs["input_ids"].shape[0]
|
batch_size = inputs["input_ids"].shape[0]
|
||||||
seq_len = inputs["input_ids"].shape[1]
|
seq_len = inputs["input_ids"].shape[1]
|
||||||
|
|
||||||
# Calculate the full sequence length across all GPUs in this SP group
|
# Calculate the full sequence length across all GPUs in this SP group
|
||||||
total_seq_len = seq_len * self.args.sequence_parallel_size
|
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||||
|
|
||||||
# Pass the partitioned sequence information to ring flash attention
|
# Pass the partitioned sequence information to ring flash attention
|
||||||
self._update_ring_flash_attn_params(
|
self._update_ring_flash_attn_params(
|
||||||
|
|||||||
@@ -207,7 +207,7 @@ class AxolotlTrainingMixins:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
sequence_parallel_size: Optional[int] = field(
|
sequence_parallel_degree: Optional[int] = field(
|
||||||
default=1,
|
default=1,
|
||||||
metadata={"help": "The number of workers to use in sequence parallelism"},
|
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,21 +23,21 @@ def set_ring_attn_group(ring_attn_group: Any):
|
|||||||
RING_ATTN_GROUP = ring_attn_group
|
RING_ATTN_GROUP = ring_attn_group
|
||||||
|
|
||||||
|
|
||||||
def register_ring_attn(sequence_parallel_size: int):
|
def register_ring_attn(sequence_parallel_degree: int):
|
||||||
"""
|
"""
|
||||||
Create ring attention group and substitute flash attn with ring flash attn.
|
Create ring attention group and substitute flash attn with ring flash attn.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sequence_parallel_size: Sequence parallelism factor.
|
sequence_parallel_degree: Sequence parallelism factor.
|
||||||
"""
|
"""
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Enabling ring attention sequence parallelism: "
|
"Enabling ring attention sequence parallelism: "
|
||||||
f"each sequence will be processed across {sequence_parallel_size} GPUs"
|
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||||
)
|
)
|
||||||
|
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
assert world_size % sequence_parallel_size == 0, (
|
assert world_size % sequence_parallel_degree == 0, (
|
||||||
f"sequence_parallel_size ({sequence_parallel_size}) "
|
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||||
f"must evenly divide world_size ({world_size})"
|
f"must evenly divide world_size ({world_size})"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -45,11 +45,11 @@ def register_ring_attn(sequence_parallel_size: int):
|
|||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
group_assignments = {}
|
group_assignments = {}
|
||||||
|
|
||||||
for i in range(world_size // sequence_parallel_size):
|
for i in range(world_size // sequence_parallel_degree):
|
||||||
ring_attn_ranks = list(
|
ring_attn_ranks = list(
|
||||||
range(
|
range(
|
||||||
i * sequence_parallel_size,
|
i * sequence_parallel_degree,
|
||||||
(i + 1) * sequence_parallel_size,
|
(i + 1) * sequence_parallel_degree,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||||
@@ -65,4 +65,4 @@ def register_ring_attn(sequence_parallel_size: int):
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||||
|
|
||||||
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_size)
|
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ class DataCollatorForSeq2Seq:
|
|||||||
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||||
return_tensors (`str`):
|
return_tensors (`str`):
|
||||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||||
sequence_parallel_size (`int`):
|
sequence_parallel_degree (`int`):
|
||||||
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
|
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -107,10 +107,10 @@ class DataCollatorForSeq2Seq:
|
|||||||
label_pad_token_id: int = -100
|
label_pad_token_id: int = -100
|
||||||
position_pad_token_id: int = 0
|
position_pad_token_id: int = 0
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
sequence_parallel_size: int = 1
|
sequence_parallel_degree: int = 1
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.sequence_parallel_size > 1:
|
if self.sequence_parallel_degree > 1:
|
||||||
# Get information about our position in the SP group
|
# Get information about our position in the SP group
|
||||||
sp_group = get_ring_attn_group()
|
sp_group = get_ring_attn_group()
|
||||||
self.rank = dist.get_rank()
|
self.rank = dist.get_rank()
|
||||||
@@ -183,7 +183,7 @@ class DataCollatorForSeq2Seq:
|
|||||||
)
|
)
|
||||||
features["decoder_input_ids"] = decoder_input_ids
|
features["decoder_input_ids"] = decoder_input_ids
|
||||||
|
|
||||||
if self.sequence_parallel_size > 1:
|
if self.sequence_parallel_degree > 1:
|
||||||
features = self.apply_sequence_parallelism(features)
|
features = self.apply_sequence_parallelism(features)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|||||||
@@ -552,14 +552,13 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_self_attn_lora(self.cfg)
|
patch_self_attn_lora(self.cfg)
|
||||||
|
|
||||||
if self.cfg.sequence_parallel_size > 1:
|
if self.cfg.sequence_parallel_degree > 1:
|
||||||
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||||
|
|
||||||
# Initialize ring attention for sequence parallelism if enabled.
|
# Initialize ring attn for sequence parallelism. This must be done after
|
||||||
# This must be done after model initialization but before the first forward pass,
|
# model init but before the first forward pass, since it modifies flash
|
||||||
# as it modifies the flash attention implementation to use ring communication
|
# attn to use ring comm for SP training across multiple GPUs.
|
||||||
# patterns for efficient sequence-parallel training across multiple GPUs.
|
register_ring_attn(self.cfg.sequence_parallel_degree)
|
||||||
register_ring_attn(self.cfg.sequence_parallel_size)
|
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
if hasattr(self.model_config, "model_type"):
|
if hasattr(self.model_config, "model_type"):
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
val_set_size: float | None = Field(default=0.0)
|
val_set_size: float | None = Field(default=0.0)
|
||||||
|
|
||||||
sequence_parallel_size: int | None = 1
|
sequence_parallel_degree: int | None = 1
|
||||||
|
|
||||||
special_tokens: SpecialTokensConfig | None = None
|
special_tokens: SpecialTokensConfig | None = None
|
||||||
tokens: list[str] | None = None
|
tokens: list[str] | None = None
|
||||||
@@ -1107,10 +1107,10 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_sequence_parallel_config(cls, data):
|
def check_sequence_parallel_config(cls, data):
|
||||||
if data.get("sequence_parallel_size") > 1:
|
if data.get("sequence_parallel_degree") > 1:
|
||||||
if not data.get("flash_attention"):
|
if not data.get("flash_attention"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"flash_attention: true must be set with sequence_parallel_size > 1"
|
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||||
)
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|||||||
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Add position_id column (PoSE)",
|
desc="Add position_id column (PoSE)",
|
||||||
)
|
)
|
||||||
elif cfg.sample_packing or cfg.sequence_parallel_size > 1:
|
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||||
@@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
**filter_map_kwargs,
|
**filter_map_kwargs,
|
||||||
**drop_long_kwargs,
|
**drop_long_kwargs,
|
||||||
)
|
)
|
||||||
if cfg.eval_sample_packing or cfg.sequence_parallel_size > 1:
|
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
@@ -443,7 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.sequence_parallel_size
|
* cfg.sequence_parallel_degree
|
||||||
)
|
)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
||||||
@@ -476,7 +476,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
# on the agreed on value for sample_packing_eff_est
|
# on the agreed on value for sample_packing_eff_est
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.floor(
|
math.floor(
|
||||||
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_size
|
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -502,7 +502,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
math.ceil(
|
math.ceil(
|
||||||
len(train_dataset)
|
len(train_dataset)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.sequence_parallel_size
|
* cfg.sequence_parallel_degree
|
||||||
/ cfg.batch_size
|
/ cfg.batch_size
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ def test_integration_with_config():
|
|||||||
],
|
],
|
||||||
"load_in_8bit": False,
|
"load_in_8bit": False,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"sequence_parallel_size": 2,
|
"sequence_parallel_degree": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
@@ -58,17 +58,17 @@ def test_integration_with_config():
|
|||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|
||||||
# Verify sequence parallelism settings were properly processed
|
# Verify sequence parallelism settings were properly processed
|
||||||
assert cfg.sequence_parallel_size == 2
|
assert cfg.sequence_parallel_degree == 2
|
||||||
assert cfg.flash_attention is True
|
assert cfg.flash_attention is True
|
||||||
|
|
||||||
# Check if the sequence_parallel_size was propagated to the training args
|
# Check if the sequence_parallel_degree was propagated to the training args
|
||||||
from axolotl.core.training_args import AxolotlTrainingArguments
|
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||||
|
|
||||||
# pylint: disable=unexpected-keyword-arg
|
# pylint: disable=unexpected-keyword-arg
|
||||||
training_args = AxolotlTrainingArguments(
|
training_args = AxolotlTrainingArguments(
|
||||||
output_dir=temp_dir, sequence_parallel_size=cfg.sequence_parallel_size
|
output_dir=temp_dir, sequence_parallel_degree=cfg.sequence_parallel_degree
|
||||||
)
|
)
|
||||||
assert training_args.sequence_parallel_size == 2
|
assert training_args.sequence_parallel_degree == 2
|
||||||
|
|
||||||
|
|
||||||
def test_ring_attn_group_creation():
|
def test_ring_attn_group_creation():
|
||||||
@@ -90,7 +90,7 @@ def test_ring_attn_group_creation():
|
|||||||
pytest.skip(f"Need an even number of GPUs, but got {world_size}")
|
pytest.skip(f"Need an even number of GPUs, but got {world_size}")
|
||||||
|
|
||||||
# Register with sequence parallel size of 2
|
# Register with sequence parallel size of 2
|
||||||
register_ring_attn(sequence_parallel_size=2)
|
register_ring_attn(sequence_parallel_degree=2)
|
||||||
|
|
||||||
# Get the ring attention group
|
# Get the ring attention group
|
||||||
group = get_ring_attn_group()
|
group = get_ring_attn_group()
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ class TestRingAttention:
|
|||||||
mock_new_group.return_value = mock_group
|
mock_new_group.return_value = mock_group
|
||||||
|
|
||||||
# Call register_ring_attn with size 4
|
# Call register_ring_attn with size 4
|
||||||
register_ring_attn(sequence_parallel_size=4)
|
register_ring_attn(sequence_parallel_degree=4)
|
||||||
|
|
||||||
# Verify the number of calls without examining the arguments
|
# Verify the number of calls without examining the arguments
|
||||||
assert mock_new_group.call_count == 2
|
assert mock_new_group.call_count == 2
|
||||||
@@ -175,15 +175,15 @@ def test_config_validation_with_valid_inputs(cfg):
|
|||||||
# Import the actual model class with appropriate mocks
|
# Import the actual model class with appropriate mocks
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
|
|
||||||
# Valid configuration: sequence_parallel_size > 1 and flash_attention is True
|
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
|
||||||
cfg = cfg | {
|
cfg = cfg | {
|
||||||
"sequence_parallel_size": 2,
|
"sequence_parallel_degree": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Should validate without errors
|
# Should validate without errors
|
||||||
config = AxolotlInputConfig(**cfg)
|
config = AxolotlInputConfig(**cfg)
|
||||||
assert config.sequence_parallel_size == 2
|
assert config.sequence_parallel_degree == 2
|
||||||
assert config.flash_attention is True
|
assert config.flash_attention is True
|
||||||
|
|
||||||
|
|
||||||
@@ -191,9 +191,9 @@ def test_config_validation_with_invalid_inputs(cfg):
|
|||||||
"""Test that invalid sequence parallelism configurations fail validation."""
|
"""Test that invalid sequence parallelism configurations fail validation."""
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
|
|
||||||
# Invalid configuration: sequence_parallel_size > 1 but flash_attention is False
|
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
|
||||||
cfg = cfg | {
|
cfg = cfg | {
|
||||||
"sequence_parallel_size": 2,
|
"sequence_parallel_degree": 2,
|
||||||
"flash_attention": False,
|
"flash_attention": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user