misc fixes 202507 (#2937) [skip ci]
* misc fixes 202507 * manually handle attn class for llama4
This commit is contained in:
@@ -22,6 +22,7 @@ coverage:
|
|||||||
only_pulls: true
|
only_pulls: true
|
||||||
flags: null
|
flags: null
|
||||||
paths: null
|
paths: null
|
||||||
|
informational: true
|
||||||
patch:
|
patch:
|
||||||
default:
|
default:
|
||||||
# basic
|
# basic
|
||||||
|
|||||||
@@ -151,6 +151,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
|||||||
|
|
||||||
return MllamaTextSelfAttention
|
return MllamaTextSelfAttention
|
||||||
|
|
||||||
|
if model_type == "llama4":
|
||||||
|
from transformers.models.llama4.modeling_llama4 import Llama4TextAttention
|
||||||
|
|
||||||
|
return Llama4TextAttention
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Dynamically import the module and attention class
|
# Dynamically import the module and attention class
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
|
|||||||
@@ -460,13 +460,13 @@ def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset |
|
|||||||
):
|
):
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Loading prepared dataset from disk at {prepared_ds_path}...",
|
f"Loading prepared dataset from disk at {prepared_ds_path}...",
|
||||||
main_process_only=False,
|
main_process_only=True,
|
||||||
)
|
)
|
||||||
return load_from_disk(str(prepared_ds_path))
|
return load_from_disk(str(prepared_ds_path))
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Unable to find prepared dataset in {prepared_ds_path}",
|
f"Unable to find prepared dataset in {prepared_ds_path}",
|
||||||
main_process_only=False,
|
main_process_only=True,
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user