jupyter lab fixes (#1139) [skip ci]

* add a basic notebook for lab users in the root

* update notebook and fix cors for jupyter

* cell is code

* fix eval batch size check

* remove intro notebook
Wing Lian
2024-01-22 18:42:40 -05:00
committed by GitHub
parent f5a828aa20
commit eaaeefce55
6 changed files with 27 additions and 15 deletions

src/axolotl/cli/train.py

@@ -3,9 +3,11 @@ CLI to run training on a model
 """
 import logging
 from pathlib import Path
+from typing import Tuple

 import fire
 import transformers
+from transformers import PreTrainedModel, PreTrainedTokenizer

 from axolotl.cli import (
     check_accelerate_default_config,
@@ -24,19 +26,23 @@ LOG = logging.getLogger("axolotl.cli.train")

 def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     parsed_cfg = load_cfg(config, **kwargs)
-    print_axolotl_text_art()
-    check_accelerate_default_config()
-    check_user_token()
     parser = transformers.HfArgumentParser((TrainerCliArgs))
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
-    if parsed_cfg.rl:
-        dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    return do_train(parsed_cfg, parsed_cli_args)
+
+
+def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+    print_axolotl_text_art()
+    check_accelerate_default_config()
+    check_user_token()
+    if cfg.rl:
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
-        dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)


 if __name__ == "__main__":
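
With do_train split out of do_cli and returning the model and tokenizer, training can be driven from a Jupyter Lab notebook cell instead of the CLI, which is the point of this PR. A minimal sketch of that usage; the YAML path is a placeholder and the TrainerCliArgs import path is an assumption, not part of this commit:

# Hypothetical notebook cell; config path and TrainerCliArgs import
# path are assumptions, not shown in this diff.
from axolotl.cli import load_cfg
from axolotl.cli.train import do_train
from axolotl.common.cli import TrainerCliArgs

cfg = load_cfg("examples/openllama-3b/lora.yml")  # placeholder config
model, tokenizer = do_train(cfg, TrainerCliArgs())  # returns (model, tokenizer)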

src/axolotl/core/trainer_builder.py

@@ -746,9 +746,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs[
             "per_device_train_batch_size"
         ] = self.cfg.micro_batch_size
-        training_arguments_kwargs[
-            "per_device_eval_batch_size"
-        ] = self.cfg.eval_batch_size
+        if self.cfg.eval_batch_size:
+            training_arguments_kwargs[
+                "per_device_eval_batch_size"
+            ] = self.cfg.eval_batch_size
         training_arguments_kwargs[
             "gradient_accumulation_steps"
         ] = self.cfg.gradient_accumulation_steps
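
This guard is the "fix eval batch size check" from the commit message: when eval_batch_size is unset in the config, per_device_eval_batch_size is now omitted entirely, so transformers.TrainingArguments keeps its own default instead of receiving an explicit None. A standalone sketch of the behavior, assuming only that transformers is installed:

# Standalone sketch; variable names here are illustrative.
from transformers import TrainingArguments

kwargs = {}
eval_batch_size = None  # e.g. eval_batch_size not set in the YAML config
if eval_batch_size:
    kwargs["per_device_eval_batch_size"] = eval_batch_size

args = TrainingArguments(output_dir="out", **kwargs)
print(args.per_device_eval_batch_size)  # 8, the transformers default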

src/axolotl/utils/bench.py

@@ -20,7 +20,8 @@ def check_cuda_device(default_value):
             device = kwargs.get("device", args[0] if args else None)

             if (
-                not torch.cuda.is_available()
+                device is None
+                or not torch.cuda.is_available()
                 or device == "auto"
                 or torch.device(device).type == "cpu"
             ):
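
The new "device is None" clause has to come first: torch.device(None) raises a TypeError, so without it a call that never passed a device would crash inside the guard rather than fall back to the decorator's default value. A self-contained sketch of the short-circuit order; safe_device_type is a hypothetical stand-in for the decorated helpers in this module:

# Self-contained sketch; safe_device_type is hypothetical.
import torch

def safe_device_type(device=None, default="cpu"):
    # The None check must run before torch.device(device),
    # which raises TypeError when given None.
    if (
        device is None
        or not torch.cuda.is_available()
        or device == "auto"
        or torch.device(device).type == "cpu"
    ):
        return default
    return torch.device(device).type

print(safe_device_type())        # falls back to "cpu", no TypeError
print(safe_device_type("cuda"))  # "cuda" when a GPU is available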

src/axolotl/utils/models.py

@@ -2,7 +2,7 @@
 import logging
 import math
 import os
-from typing import Any, Optional, Tuple, Union  # noqa: F401
+from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401

 import addict
 import bitsandbytes as bnb
@@ -348,7 +348,11 @@ def load_model(
         LOG.info("patching _expand_mask")
         hijack_expand_mask()

-    model_kwargs = {}
+    model_kwargs: Dict[str, Any] = {}
+
+    if cfg.model_kwargs:
+        for key, val in cfg.model_kwargs.items():
+            model_kwargs[key] = val

     max_memory = cfg.max_memory
     device_map = cfg.device_map
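
The new block copies any model_kwargs mapping from the axolotl config into the kwargs later handed to the Hugging Face loader. A minimal sketch of that passthrough, assuming cfg is an addict.Dict as imported above; the use_cache entry is purely illustrative, not from this commit:

# Minimal sketch; the config entry is illustrative.
from typing import Any, Dict

import addict

cfg = addict.Dict()
cfg.model_kwargs = {"use_cache": False}  # hypothetical extra model kwarg

model_kwargs: Dict[str, Any] = {}
if cfg.model_kwargs:
    for key, val in cfg.model_kwargs.items():
        model_kwargs[key] = val

assert model_kwargs == {"use_cache": False}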