diff --git a/requirements.txt b/requirements.txt index 60b07a824..ee808de76 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,6 +31,7 @@ art fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe gradio==3.50.2 tensorboard +python-dotenv==1.0.1 mamba-ssm==1.2.0.post1 diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index 86ad8409f..adc991456 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -5,6 +5,7 @@ from pathlib import Path import fire import transformers +from dotenv import load_dotenv from axolotl.cli import ( do_inference, @@ -33,4 +34,5 @@ def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs): if __name__ == "__main__": + load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 8db3fa989..6588b5ee4 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -5,6 +5,7 @@ from pathlib import Path import fire import transformers +from dotenv import load_dotenv from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art from axolotl.common.cli import TrainerCliArgs @@ -48,4 +49,5 @@ def do_cli(config: Path = Path("examples/"), **kwargs): if __name__ == "__main__": + load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index f43277e49..5ec279d4b 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -9,6 +9,7 @@ import fire import transformers from accelerate import init_empty_weights from colorama import Fore +from dotenv import load_dotenv from transformers import AutoModelForCausalLM from axolotl.cli import ( @@ -86,4 +87,5 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): if __name__ == "__main__": + load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/shard.py b/src/axolotl/cli/shard.py index 48f22790a..196c0e99a 100644 --- a/src/axolotl/cli/shard.py +++ b/src/axolotl/cli/shard.py @@ -7,6 +7,7 @@ from typing import Union import fire import transformers +from dotenv import load_dotenv from axolotl.cli import load_cfg, print_axolotl_text_art from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer @@ -40,4 +41,5 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): if __name__ == "__main__": + load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 7bb4a5184..050f18a05 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Tuple, Union import fire +from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser from transformers.modeling_utils import PreTrainedModel from transformers.tokenization_utils import PreTrainedTokenizer @@ -67,4 +68,5 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: if __name__ == "__main__": + load_dotenv() fire.Fire(do_cli)