jupyter lab fixes (#1139) [skip ci]

* add a basic notebook for lab users in the root

* update notebook and fix cors for jupyter

* cell is code

* fix eval batch size check

* remove intro notebook
This commit is contained in:
Wing Lian
2024-01-22 18:42:40 -05:00
committed by GitHub
parent f5a828aa20
commit eaaeefce55
6 changed files with 27 additions and 15 deletions

View File

@@ -12,7 +12,7 @@ EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
RUN pip install jupyterlab notebook && \ RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux && \ RUN apt install --yes --no-install-recommends openssh-server tmux && \
mkdir -p ~/.ssh && \ mkdir -p ~/.ssh && \

View File

@@ -33,7 +33,7 @@ fi
if [ "$JUPYTER_DISABLE" != "1" ]; then if [ "$JUPYTER_DISABLE" != "1" ]; then
# Run Jupyter Lab in the background # Run Jupyter Lab in the background
jupyter lab --allow-root --ip 0.0.0.0 & jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
fi fi
# Execute the passed arguments (CMD) # Execute the passed arguments (CMD)

View File

@@ -3,9 +3,11 @@ CLI to run training on a model
""" """
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Tuple
import fire import fire
import transformers import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer
from axolotl.cli import ( from axolotl.cli import (
check_accelerate_default_config, check_accelerate_default_config,
@@ -24,19 +26,23 @@ LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Path = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs)) parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True
) )
return do_train(parsed_cfg, parsed_cli_args)
if parsed_cfg.rl:
dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if cfg.rl:
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
else: else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -746,9 +746,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[ training_arguments_kwargs[
"per_device_train_batch_size" "per_device_train_batch_size"
] = self.cfg.micro_batch_size ] = self.cfg.micro_batch_size
training_arguments_kwargs[ if self.cfg.eval_batch_size:
"per_device_eval_batch_size" training_arguments_kwargs[
] = self.cfg.eval_batch_size "per_device_eval_batch_size"
] = self.cfg.eval_batch_size
training_arguments_kwargs[ training_arguments_kwargs[
"gradient_accumulation_steps" "gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps ] = self.cfg.gradient_accumulation_steps

View File

@@ -20,7 +20,8 @@ def check_cuda_device(default_value):
device = kwargs.get("device", args[0] if args else None) device = kwargs.get("device", args[0] if args else None)
if ( if (
not torch.cuda.is_available() device is None
or not torch.cuda.is_available()
or device == "auto" or device == "auto"
or torch.device(device).type == "cpu" or torch.device(device).type == "cpu"
): ):

View File

@@ -2,7 +2,7 @@
import logging import logging
import math import math
import os import os
from typing import Any, Optional, Tuple, Union # noqa: F401 from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
import addict import addict
import bitsandbytes as bnb import bitsandbytes as bnb
@@ -348,7 +348,11 @@ def load_model(
LOG.info("patching _expand_mask") LOG.info("patching _expand_mask")
hijack_expand_mask() hijack_expand_mask()
model_kwargs = {} model_kwargs: Dict[str, Any] = {}
# Seed the loader kwargs with any user-supplied overrides from the config.
# The source must be cfg.model_kwargs: the local model_kwargs dict was just
# created empty, so iterating over it would copy nothing.
if cfg.model_kwargs:
    for key, val in cfg.model_kwargs.items():
        model_kwargs[key] = val
max_memory = cfg.max_memory max_memory = cfg.max_memory
device_map = cfg.device_map device_map = cfg.device_map