Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1f09f48d8f | ||
|
|
b44546df6f | ||
|
|
967fbf8152 | ||
|
|
c144a1ae65 | ||
|
|
68a3c7678a | ||
|
|
f18925fb4b |
@@ -205,7 +205,7 @@ ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
|
|||||||
hi there!. goodbye farewell</s>
|
hi there!. goodbye farewell</s>
|
||||||
```
|
```
|
||||||
|
|
||||||
We can check that the right tokens are ingored by comparing the labels
|
We can check that the right tokens are ignored by comparing the labels
|
||||||
to each token:
|
to each token:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|||||||
@@ -321,6 +321,8 @@ class ModelInputConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
trust_remote_code: Optional[bool] = None
|
trust_remote_code: Optional[bool] = None
|
||||||
|
|
||||||
|
model_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
@field_validator("trust_remote_code")
|
@field_validator("trust_remote_code")
|
||||||
@classmethod
|
@classmethod
|
||||||
def hint_trust_remote_code(cls, trust_remote_code):
|
def hint_trust_remote_code(cls, trust_remote_code):
|
||||||
@@ -614,6 +616,8 @@ class AxolotlInputConfig(
|
|||||||
flash_attn_fuse_mlp: Optional[bool] = None
|
flash_attn_fuse_mlp: Optional[bool] = None
|
||||||
flash_optimum: Optional[bool] = None
|
flash_optimum: Optional[bool] = None
|
||||||
|
|
||||||
|
eager_attention: Optional[bool] = None
|
||||||
|
|
||||||
unsloth_cross_entropy_loss: Optional[bool] = None
|
unsloth_cross_entropy_loss: Optional[bool] = None
|
||||||
unsloth_lora_mlp: Optional[bool] = None
|
unsloth_lora_mlp: Optional[bool] = None
|
||||||
unsloth_lora_qkv: Optional[bool] = None
|
unsloth_lora_qkv: Optional[bool] = None
|
||||||
|
|||||||
@@ -655,19 +655,11 @@ def load_model(
|
|||||||
if "device_map" in model_kwargs:
|
if "device_map" in model_kwargs:
|
||||||
del model_kwargs["device_map"]
|
del model_kwargs["device_map"]
|
||||||
|
|
||||||
if cfg.fsdp and not cfg.adapter and cfg.local_rank != 0:
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
with init_empty_weights():
|
base_model,
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
config=model_config,
|
||||||
base_model,
|
**model_kwargs,
|
||||||
config=model_config,
|
)
|
||||||
**model_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
base_model,
|
|
||||||
config=model_config,
|
|
||||||
**model_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.flash_attention and not inference:
|
if cfg.flash_attention and not inference:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
|
|||||||
Reference in New Issue
Block a user