check for the existence of the default accelerate config that can create headaches (#561)
This commit is contained in:
@@ -14,6 +14,7 @@ import transformers
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
# add src to the pythonpath so we don't need to pip install this
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
|
from accelerate.commands.config import config_args
|
||||||
from art import text2art
|
from art import text2art
|
||||||
from transformers import GenerationConfig, TextStreamer
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
@@ -254,9 +255,17 @@ def load_datasets(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_accelerate_default_config():
|
||||||
|
if Path(config_args.default_yaml_config_file).exists():
|
||||||
|
LOG.warning(
|
||||||
|
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
check_accelerate_default_config()
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
|
|||||||
Reference in New Issue
Block a user