jupyter lab fixes (#1139) [skip ci]
* add a basic notebook for lab users in the root
* update notebook and fix cors for jupyter
* cell is code
* fix eval batch size check
* remove intro notebook
@@ -12,7 +12,7 @@ EXPOSE 22
 COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
 
-RUN pip install jupyterlab notebook && \
+RUN pip install jupyterlab notebook ipywidgets && \
     jupyter lab clean
 RUN apt install --yes --no-install-recommends openssh-server tmux && \
     mkdir -p ~/.ssh && \
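For context on the new ipywidgets dependency: JupyterLab renders tqdm-style progress bars (as used by transformers and datasets) through the widgets frontend, and without ipywidgets they degrade to plain text or warn. A minimal sketch of code that relies on it, not taken from this commit:

import time
from tqdm.auto import tqdm  # resolves to the ipywidgets-based bar inside JupyterLab

for _ in tqdm(range(100), desc="demo"):
    time.sleep(0.01)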
@@ -33,7 +33,7 @@ fi
 
 if [ "$JUPYTER_DISABLE" != "1" ]; then
     # Run Jupyter Lab in the background
-    jupyter lab --allow-root --ip 0.0.0.0 &
+    jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
 fi
 
 # Execute the passed arguments (CMD)
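The added --ServerApp.allow_origin=* is the "fix cors for jupyter" item from the commit message. A hedged way to sanity-check it from outside the container, assuming the server is reachable on localhost:8888 (the /api endpoint serves version info without auth):

import requests

resp = requests.get(
    "http://localhost:8888/api",
    headers={"Origin": "http://example.com"},  # simulate a cross-origin caller
    timeout=5,
)
# Expect 200 and "*" once allow_origin is wildcarded.
print(resp.status_code, resp.headers.get("Access-Control-Allow-Origin"))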
@@ -3,9 +3,11 @@ CLI to run training on a model
 """
 import logging
 from pathlib import Path
+from typing import Tuple
 
 import fire
 import transformers
+from transformers import PreTrainedModel, PreTrainedTokenizer
 
 from axolotl.cli import (
     check_accelerate_default_config,
@@ -24,19 +26,23 @@ LOG = logging.getLogger("axolotl.cli.train")
 def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     parsed_cfg = load_cfg(config, **kwargs)
-    print_axolotl_text_art()
-    check_accelerate_default_config()
-    check_user_token()
     parser = transformers.HfArgumentParser((TrainerCliArgs))
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
-
-    if parsed_cfg.rl:
-        dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    return do_train(parsed_cfg, parsed_cli_args)
+
+
+def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+    print_axolotl_text_art()
+    check_accelerate_default_config()
+    check_user_token()
+    if cfg.rl:
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
-        dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
 
 
 if __name__ == "__main__":
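The point of splitting do_train() out of do_cli() is that notebook users can now invoke training programmatically and keep the returned (model, tokenizer) pair alive in the kernel, per the Tuple[PreTrainedModel, PreTrainedTokenizer] return type above. A minimal usage sketch; the config path is a placeholder, and TrainerCliArgs is assumed to be constructible with its defaults:

from axolotl.cli import load_cfg
from axolotl.cli.train import do_train
from axolotl.common.cli import TrainerCliArgs

cfg = load_cfg("examples/openllama-3b/lora.yml")  # placeholder config path
model, tokenizer = do_train(cfg, TrainerCliArgs())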
@@ -746,9 +746,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs[
             "per_device_train_batch_size"
         ] = self.cfg.micro_batch_size
-        training_arguments_kwargs[
-            "per_device_eval_batch_size"
-        ] = self.cfg.eval_batch_size
+        if self.cfg.eval_batch_size:
+            training_arguments_kwargs[
+                "per_device_eval_batch_size"
+            ] = self.cfg.eval_batch_size
         training_arguments_kwargs[
             "gradient_accumulation_steps"
         ] = self.cfg.gradient_accumulation_steps
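What the guard fixes: previously per_device_eval_batch_size was always forwarded, so an unset eval_batch_size could push None into TrainingArguments instead of letting it fall back to its own default. A minimal sketch of the pattern with a hypothetical config dict:

cfg = {"micro_batch_size": 2, "eval_batch_size": None}  # hypothetical config

training_arguments_kwargs = {
    "per_device_train_batch_size": cfg["micro_batch_size"],
}
if cfg["eval_batch_size"]:  # only forward a value the user actually set
    training_arguments_kwargs["per_device_eval_batch_size"] = cfg["eval_batch_size"]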
@@ -20,7 +20,8 @@ def check_cuda_device(default_value):
         device = kwargs.get("device", args[0] if args else None)
 
         if (
-            not torch.cuda.is_available()
+            device is None
+            or not torch.cuda.is_available()
             or device == "auto"
             or torch.device(device).type == "cpu"
        ):
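Why device is None must be the first disjunct: or short-circuits left to right, and torch.device(None) raises a TypeError, so the None case has to be answered before torch.device(device) is ever constructed. A small sketch:

import torch

device = None  # e.g. no "device" kwarg and no positional argument
should_fallback = (
    device is None
    or not torch.cuda.is_available()
    or device == "auto"
    or torch.device(device).type == "cpu"  # never reached when device is None
)
print(should_fallback)  # True, with no TypeError raised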
@@ -2,7 +2,7 @@
 import logging
 import math
 import os
-from typing import Any, Optional, Tuple, Union  # noqa: F401
+from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401
 
 import addict
 import bitsandbytes as bnb
@@ -348,7 +348,11 @@ def load_model(
         LOG.info("patching _expand_mask")
         hijack_expand_mask()
 
-    model_kwargs = {}
+    model_kwargs: Dict[str, Any] = {}
+
+    if cfg.model_kwargs:
+        for key, val in cfg.model_kwargs.items():
+            model_kwargs[key] = val
 
     max_memory = cfg.max_memory
     device_map = cfg.device_map
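The new block forwards user-supplied model_kwargs from the axolotl config into the kwargs that eventually reach the model's from_pretrained call. A standalone sketch of the copy; the config contents here are hypothetical:

from typing import Any, Dict

cfg_model_kwargs = {"rope_scaling": {"type": "linear", "factor": 2.0}}  # hypothetical

model_kwargs: Dict[str, Any] = {}
if cfg_model_kwargs:
    for key, val in cfg_model_kwargs.items():
        model_kwargs[key] = val

# model_kwargs would later be splatted into from_pretrained(..., **model_kwargs)
print(model_kwargs)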