Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
2b890ead05 fsdp fft loading on meta device 2024-08-11 22:18:04 -04:00
3 changed files with 14 additions and 10 deletions

View File

@@ -205,7 +205,7 @@ ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
hi there!. goodbye farewell</s> hi there!. goodbye farewell</s>
``` ```
We can check that the right tokens are ignored by comparing the labels We can check that the right tokens are ingored by comparing the labels
to each token: to each token:
```python ```python

View File

@@ -321,8 +321,6 @@ class ModelInputConfig(BaseModel):
) )
trust_remote_code: Optional[bool] = None trust_remote_code: Optional[bool] = None
model_kwargs: Optional[Dict[str, Any]] = None
@field_validator("trust_remote_code") @field_validator("trust_remote_code")
@classmethod @classmethod
def hint_trust_remote_code(cls, trust_remote_code): def hint_trust_remote_code(cls, trust_remote_code):
@@ -616,8 +614,6 @@ class AxolotlInputConfig(
flash_attn_fuse_mlp: Optional[bool] = None flash_attn_fuse_mlp: Optional[bool] = None
flash_optimum: Optional[bool] = None flash_optimum: Optional[bool] = None
eager_attention: Optional[bool] = None
unsloth_cross_entropy_loss: Optional[bool] = None unsloth_cross_entropy_loss: Optional[bool] = None
unsloth_lora_mlp: Optional[bool] = None unsloth_lora_mlp: Optional[bool] = None
unsloth_lora_qkv: Optional[bool] = None unsloth_lora_qkv: Optional[bool] = None

View File

@@ -655,11 +655,19 @@ def load_model(
if "device_map" in model_kwargs: if "device_map" in model_kwargs:
del model_kwargs["device_map"] del model_kwargs["device_map"]
model = AutoModelForCausalLM.from_pretrained( if cfg.fsdp and not cfg.adapter and cfg.local_rank != 0:
base_model, with init_empty_weights():
config=model_config, model = AutoModelForCausalLM.from_pretrained(
**model_kwargs, base_model,
) config=model_config,
**model_kwargs,
)
else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
**model_kwargs,
)
if cfg.flash_attention and not inference: if cfg.flash_attention and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import ( from axolotl.monkeypatch.llama_attn_hijack_flash import (