jupyter lab fixes (#1139) [skip ci]

* add a basic notebook for lab users in the root

* update notebook and fix cors for jupyter

* cell is code

* fix eval batch size check

* remove intro notebook
This commit is contained in:
Wing Lian
2024-01-22 18:42:40 -05:00
committed by GitHub
parent f5a828aa20
commit eaaeefce55
6 changed files with 27 additions and 15 deletions

View File

@@ -12,7 +12,7 @@ EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
RUN pip install jupyterlab notebook && \
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux && \
mkdir -p ~/.ssh && \

View File

@@ -33,7 +33,7 @@ fi
if [ "$JUPYTER_DISABLE" != "1" ]; then
# Run Jupyter Lab in the background
jupyter lab --allow-root --ip 0.0.0.0 &
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
fi
# Execute the passed arguments (CMD)

View File

@@ -3,9 +3,11 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Tuple
import fire
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer
from axolotl.cli import (
check_accelerate_default_config,
@@ -24,19 +26,23 @@ LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
return do_train(parsed_cfg, parsed_cli_args)
if parsed_cfg.rl:
dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if cfg.rl:
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
if __name__ == "__main__":

View File

@@ -746,9 +746,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"per_device_train_batch_size"
] = self.cfg.micro_batch_size
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
if self.cfg.eval_batch_size:
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
training_arguments_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps

View File

@@ -20,7 +20,8 @@ def check_cuda_device(default_value):
device = kwargs.get("device", args[0] if args else None)
if (
not torch.cuda.is_available()
device is None
or not torch.cuda.is_available()
or device == "auto"
or torch.device(device).type == "cpu"
):

View File

@@ -2,7 +2,7 @@
import logging
import math
import os
from typing import Any, Optional, Tuple, Union # noqa: F401
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
import addict
import bitsandbytes as bnb
@@ -348,7 +348,11 @@ def load_model(
LOG.info("patching _expand_mask")
hijack_expand_mask()
model_kwargs = {}
model_kwargs: Dict[str, Any] = {}
if cfg.model_kwargs:
for key, val in model_kwargs.items():
model_kwargs[key] = val
max_memory = cfg.max_memory
device_map = cfg.device_map