jupyter lab fixes (#1139) [skip ci]
* add a basic notebook for lab users in the root
* update notebook and fix cors for jupyter
* cell is code
* fix eval batch size check
* remove intro notebook
This commit is contained in:
@@ -12,7 +12,7 @@ EXPOSE 22
|
|||||||
|
|
||||||
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
|
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
|
||||||
|
|
||||||
RUN pip install jupyterlab notebook && \
|
RUN pip install jupyterlab notebook ipywidgets && \
|
||||||
jupyter lab clean
|
jupyter lab clean
|
||||||
RUN apt install --yes --no-install-recommends openssh-server tmux && \
|
RUN apt install --yes --no-install-recommends openssh-server tmux && \
|
||||||
mkdir -p ~/.ssh && \
|
mkdir -p ~/.ssh && \
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ fi
|
|||||||
|
|
||||||
if [ "$JUPYTER_DISABLE" != "1" ]; then
|
if [ "$JUPYTER_DISABLE" != "1" ]; then
|
||||||
# Run Jupyter Lab in the background
|
# Run Jupyter Lab in the background
|
||||||
jupyter lab --allow-root --ip 0.0.0.0 &
|
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Execute the passed arguments (CMD)
|
# Execute the passed arguments (CMD)
|
||||||
|
|||||||
@@ -3,9 +3,11 @@ CLI to run training on a model
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
check_accelerate_default_config,
|
check_accelerate_default_config,
|
||||||
@@ -24,19 +26,23 @@ LOG = logging.getLogger("axolotl.cli.train")
|
|||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
print_axolotl_text_art()
|
|
||||||
check_accelerate_default_config()
|
|
||||||
check_user_token()
|
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
|
return do_train(parsed_cfg, parsed_cli_args)
|
||||||
|
|
||||||
if parsed_cfg.rl:
|
|
||||||
dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||||
|
print_axolotl_text_art()
|
||||||
|
check_accelerate_default_config()
|
||||||
|
check_user_token()
|
||||||
|
if cfg.rl:
|
||||||
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
|
||||||
|
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -746,9 +746,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"per_device_train_batch_size"
|
"per_device_train_batch_size"
|
||||||
] = self.cfg.micro_batch_size
|
] = self.cfg.micro_batch_size
|
||||||
training_arguments_kwargs[
|
if self.cfg.eval_batch_size:
|
||||||
"per_device_eval_batch_size"
|
training_arguments_kwargs[
|
||||||
] = self.cfg.eval_batch_size
|
"per_device_eval_batch_size"
|
||||||
|
] = self.cfg.eval_batch_size
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"gradient_accumulation_steps"
|
"gradient_accumulation_steps"
|
||||||
] = self.cfg.gradient_accumulation_steps
|
] = self.cfg.gradient_accumulation_steps
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ def check_cuda_device(default_value):
|
|||||||
device = kwargs.get("device", args[0] if args else None)
|
device = kwargs.get("device", args[0] if args else None)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not torch.cuda.is_available()
|
device is None
|
||||||
|
or not torch.cuda.is_available()
|
||||||
or device == "auto"
|
or device == "auto"
|
||||||
or torch.device(device).type == "cpu"
|
or torch.device(device).type == "cpu"
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from typing import Any, Optional, Tuple, Union # noqa: F401
|
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
||||||
|
|
||||||
import addict
|
import addict
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
@@ -348,7 +348,11 @@ def load_model(
|
|||||||
LOG.info("patching _expand_mask")
|
LOG.info("patching _expand_mask")
|
||||||
hijack_expand_mask()
|
hijack_expand_mask()
|
||||||
|
|
||||||
model_kwargs = {}
|
model_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
if cfg.model_kwargs:
|
||||||
|
for key, val in model_kwargs.items():
|
||||||
|
model_kwargs[key] = val
|
||||||
|
|
||||||
max_memory = cfg.max_memory
|
max_memory = cfg.max_memory
|
||||||
device_map = cfg.device_map
|
device_map = cfg.device_map
|
||||||
|
|||||||
Reference in New Issue
Block a user