fix zero3 (#1994)
This commit is contained in:
@@ -14,15 +14,6 @@
|
|||||||
"bf16": {
|
"bf16": {
|
||||||
"enabled": true
|
"enabled": true
|
||||||
},
|
},
|
||||||
"fp16": {
|
|
||||||
"enabled": "auto",
|
|
||||||
"auto_cast": false,
|
|
||||||
"loss_scale": 0,
|
|
||||||
"initial_scale_power": 32,
|
|
||||||
"loss_scale_window": 1000,
|
|
||||||
"hysteresis": 2,
|
|
||||||
"min_loss_scale": 1
|
|
||||||
},
|
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"gradient_clipping": "auto",
|
"gradient_clipping": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
|
|||||||
@@ -24,15 +24,6 @@
|
|||||||
"bf16": {
|
"bf16": {
|
||||||
"enabled": true
|
"enabled": true
|
||||||
},
|
},
|
||||||
"fp16": {
|
|
||||||
"enabled": "auto",
|
|
||||||
"auto_cast": false,
|
|
||||||
"loss_scale": 0,
|
|
||||||
"initial_scale_power": 32,
|
|
||||||
"loss_scale_window": 1000,
|
|
||||||
"hysteresis": 2,
|
|
||||||
"min_loss_scale": 1
|
|
||||||
},
|
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"gradient_clipping": "auto",
|
"gradient_clipping": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
|
|||||||
@@ -20,15 +20,6 @@
|
|||||||
"bf16": {
|
"bf16": {
|
||||||
"enabled": true
|
"enabled": true
|
||||||
},
|
},
|
||||||
"fp16": {
|
|
||||||
"enabled": "auto",
|
|
||||||
"auto_cast": false,
|
|
||||||
"loss_scale": 0,
|
|
||||||
"initial_scale_power": 32,
|
|
||||||
"loss_scale_window": 1000,
|
|
||||||
"hysteresis": 2,
|
|
||||||
"min_loss_scale": 1
|
|
||||||
},
|
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"gradient_clipping": "auto",
|
"gradient_clipping": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ tokenizers>=0.20.1
|
|||||||
bitsandbytes==0.44.1
|
bitsandbytes==0.44.1
|
||||||
accelerate==1.0.1
|
accelerate==1.0.1
|
||||||
datasets==3.0.1
|
datasets==3.0.1
|
||||||
deepspeed==0.14.4
|
deepspeed==0.15.3
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
|
|||||||
@@ -40,7 +40,10 @@ from transformers import ( # noqa: F401
|
|||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
ProcessorMixin,
|
ProcessorMixin,
|
||||||
)
|
)
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import (
|
||||||
|
HfTrainerDeepSpeedConfig,
|
||||||
|
is_deepspeed_zero3_enabled,
|
||||||
|
)
|
||||||
|
|
||||||
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
@@ -705,6 +708,38 @@ class ModelLoader:
|
|||||||
self.model_kwargs["low_cpu_mem_usage"] = True
|
self.model_kwargs["low_cpu_mem_usage"] = True
|
||||||
|
|
||||||
def build_model(self, qlora_fsdp) -> bool:
|
def build_model(self, qlora_fsdp) -> bool:
|
||||||
|
def _configure_zero3_memory_efficient_loading():
|
||||||
|
"""
|
||||||
|
Set the deepspeed config to load the model into RAM first before moving to VRAM.
|
||||||
|
|
||||||
|
We need to return hf_ds_cfg as it needs to exist before model loading.
|
||||||
|
"""
|
||||||
|
hf_ds_cfg = None
|
||||||
|
|
||||||
|
if os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
|
||||||
|
hf_ds_cfg = HfTrainerDeepSpeedConfig(self.cfg.deepspeed)
|
||||||
|
hf_ds_cfg.fill_match(
|
||||||
|
"train_micro_batch_size_per_gpu", self.cfg.micro_batch_size
|
||||||
|
)
|
||||||
|
hf_ds_cfg.fill_match(
|
||||||
|
"gradient_accumulation_steps", self.cfg.gradient_accumulation_steps
|
||||||
|
)
|
||||||
|
hf_ds_cfg.fill_match(
|
||||||
|
"train_batch_size",
|
||||||
|
int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
* self.cfg.micro_batch_size
|
||||||
|
* self.cfg.gradient_accumulation_steps,
|
||||||
|
)
|
||||||
|
if "device_map" in self.model_kwargs:
|
||||||
|
del self.model_kwargs["device_map"]
|
||||||
|
|
||||||
|
transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True
|
||||||
|
transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = (
|
||||||
|
lambda: True
|
||||||
|
)
|
||||||
|
|
||||||
|
return hf_ds_cfg
|
||||||
|
|
||||||
skip_move_to_device = False
|
skip_move_to_device = False
|
||||||
if ( # pylint: disable=condition-evals-to-constant)
|
if ( # pylint: disable=condition-evals-to-constant)
|
||||||
(self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
|
(self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
|
||||||
@@ -753,6 +788,8 @@ class ModelLoader:
|
|||||||
if "device_map" in self.model_kwargs:
|
if "device_map" in self.model_kwargs:
|
||||||
del self.model_kwargs["device_map"]
|
del self.model_kwargs["device_map"]
|
||||||
|
|
||||||
|
_ = _configure_zero3_memory_efficient_loading()
|
||||||
|
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.AutoModelLoader.from_pretrained(
|
||||||
@@ -846,6 +883,8 @@ class ModelLoader:
|
|||||||
if "device_map" in self.model_kwargs:
|
if "device_map" in self.model_kwargs:
|
||||||
del self.model_kwargs["device_map"]
|
del self.model_kwargs["device_map"]
|
||||||
|
|
||||||
|
_ = _configure_zero3_memory_efficient_loading()
|
||||||
|
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.AutoModelLoader.from_pretrained(
|
||||||
|
|||||||
Reference in New Issue
Block a user