pass additional info for fix untrained tokens when using distributed + offloading (#2388)
* pass additional info for fix untrained tokens when using distributed + offloading * use latest version of vendored lib * use v0.0.5 of contribs lgpl * fix for no bad tokens and add tests * use release * add multigpu test too * make sure the multigpu zero3 test actually uses zero3
This commit is contained in:
@@ -7,7 +7,7 @@ import signal
|
||||
import sys
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
@@ -20,7 +20,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
from axolotl.common.datasets import TrainDatasetMeta
|
||||
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||
fix_untrained_tokens,
|
||||
)
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
@@ -382,21 +382,23 @@ def handle_untrained_tokens_fix(
|
||||
if not cfg.fix_untrained_tokens:
|
||||
return
|
||||
|
||||
is_ds_zero3: bool = False
|
||||
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
|
||||
is_ds_zero3 = True
|
||||
|
||||
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
||||
sig = inspect.signature(fix_untrained_tokens)
|
||||
|
||||
fix_kwargs: Dict[str, Any] = {}
|
||||
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
||||
if "token_ids_to_fix" in sig.parameters and isinstance(
|
||||
cfg.fix_untrained_tokens, list
|
||||
):
|
||||
fix_untrained_tokens(
|
||||
model,
|
||||
tokenizer,
|
||||
train_dataset,
|
||||
token_ids_to_fix=cfg.fix_untrained_tokens,
|
||||
)
|
||||
else:
|
||||
fix_untrained_tokens(model, tokenizer, train_dataset)
|
||||
fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
|
||||
if "is_ds_zero3" in sig.parameters:
|
||||
fix_kwargs["is_ds_zero3"] = is_ds_zero3
|
||||
|
||||
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
model.save_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user