diff --git a/docker/Dockerfile-cloud b/docker/Dockerfile-cloud
index a7ad00ff6..36439ff61 100644
--- a/docker/Dockerfile-cloud
+++ b/docker/Dockerfile-cloud
@@ -12,7 +12,7 @@ EXPOSE 22
 
 COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
 
-RUN pip install jupyterlab notebook && \
+RUN pip install jupyterlab notebook ipywidgets && \
     jupyter lab clean
 RUN apt install --yes --no-install-recommends openssh-server tmux && \
     mkdir -p ~/.ssh && \
diff --git a/scripts/cloud-entrypoint.sh b/scripts/cloud-entrypoint.sh
index 399eecb81..c7b9ca3e0 100755
--- a/scripts/cloud-entrypoint.sh
+++ b/scripts/cloud-entrypoint.sh
@@ -33,7 +33,7 @@ fi
 
 if [ "$JUPYTER_DISABLE" != "1" ]; then
     # Run Jupyter Lab in the background
-    jupyter lab --allow-root --ip 0.0.0.0 &
+    jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
fi
 
 # Execute the passed arguments (CMD)
diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index e099a1a6d..8dc786c17 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -3,9 +3,11 @@ CLI to run training on a model
 """
 import logging
 from pathlib import Path
+from typing import Tuple
 
 import fire
 import transformers
+from transformers import PreTrainedModel, PreTrainedTokenizer
 
 from axolotl.cli import (
     check_accelerate_default_config,
@@ -24,19 +26,23 @@ LOG = logging.getLogger("axolotl.cli.train")
 def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     parsed_cfg = load_cfg(config, **kwargs)
-    print_axolotl_text_art()
-    check_accelerate_default_config()
-    check_user_token()
     parser = transformers.HfArgumentParser((TrainerCliArgs))
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
+    return do_train(parsed_cfg, parsed_cli_args)
 
-    if parsed_cfg.rl:
-        dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
+def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+    print_axolotl_text_art()
+    check_accelerate_default_config()
+    check_user_token()
+    if cfg.rl:
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
-        dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
 
 
 if __name__ == "__main__":
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index b8309c363..bb027acf2 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -746,9 +746,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs[
             "per_device_train_batch_size"
         ] = self.cfg.micro_batch_size
-        training_arguments_kwargs[
-            "per_device_eval_batch_size"
-        ] = self.cfg.eval_batch_size
+        if self.cfg.eval_batch_size:
+            training_arguments_kwargs[
+                "per_device_eval_batch_size"
+            ] = self.cfg.eval_batch_size
         training_arguments_kwargs[
             "gradient_accumulation_steps"
         ] = self.cfg.gradient_accumulation_steps
diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index 40be0d9ac..8f33665c6 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -20,7 +20,8 @@ def check_cuda_device(default_value):
         device = kwargs.get("device", args[0] if args else None)
 
         if (
-            not torch.cuda.is_available()
+            device is None
+            or not torch.cuda.is_available()
             or device == "auto"
             or torch.device(device).type == "cpu"
         ):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index c883edb37..fb4caa6d8 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -2,7 +2,7 @@
 import logging
 import math
 import os
-from typing import Any, Optional, Tuple, Union  # noqa: F401
+from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401
 
 import addict
 import bitsandbytes as bnb
@@ -348,7 +348,11 @@ def load_model(
         LOG.info("patching _expand_mask")
         hijack_expand_mask()
 
-    model_kwargs = {}
+    model_kwargs: Dict[str, Any] = {}
+
+    if cfg.model_kwargs:
+        for key, val in cfg.model_kwargs.items():
+            model_kwargs[key] = val
 
     max_memory = cfg.max_memory
     device_map = cfg.device_map
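Note on src/axolotl/cli/train.py: the refactor separates CLI parsing (do_cli) from the training entry point (do_train), which now returns the trained (model, tokenizer) pair. A minimal sketch of programmatic use, assuming load_cfg is exported from axolotl.cli and TrainerCliArgs from axolotl.common.cli as elsewhere in the codebase (the YAML path is a placeholder):

    from pathlib import Path

    from axolotl.cli import load_cfg
    from axolotl.cli.train import do_train
    from axolotl.common.cli import TrainerCliArgs

    # Load any axolotl YAML config; this path is illustrative only.
    cfg = load_cfg(Path("examples/openllama-3b/lora.yml"))

    # Default CLI args; no argv parsing is needed when calling do_train directly,
    # and the returned model/tokenizer can be inspected or reused (e.g. in tests).
    model, tokenizer = do_train(cfg, TrainerCliArgs())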
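Note on src/axolotl/utils/models.py: the new cfg.model_kwargs block copies the config's dict key by key into the kwargs that load_model later forwards to the model's from_pretrained call; the loop is equivalent to model_kwargs.update(cfg.model_kwargs). A small sketch of the intended behavior, with an illustrative key (attn_implementation is an assumption, not something this diff configures):

    from typing import Any, Dict

    # Stand-in for cfg.model_kwargs; real values come from the YAML config.
    cfg_model_kwargs = {"attn_implementation": "flash_attention_2"}

    model_kwargs: Dict[str, Any] = {}
    if cfg_model_kwargs:
        # Iterate the config's dict, not the freshly created empty one.
        for key, val in cfg_model_kwargs.items():
            model_kwargs[key] = val

    # model_kwargs is then splatted into from_pretrained(**model_kwargs).
    assert model_kwargs == {"attn_implementation": "flash_attention_2"}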