diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 571faf771..5a3f90992 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -1,4 +1,4 @@ -name: ci-cd +name: ci-cd-base on: push: diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index 4f4431dbe..510d038ec 100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -62,6 +62,7 @@ RUN git clone https://github.com/microsoft/DeepSpeed.git && \ FROM base-builder AS bnb-builder WORKDIR /workspace +ENV CUDA_VERSION_BNB=$CUDA_VERSION_BNB RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \ cd bitsandbytes && \ @@ -70,6 +71,8 @@ RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \ FROM base-builder +ENV CUDA_VERSION_BNB=$CUDA_VERSION_BNB + # recompile apex RUN python3 -m pip uninstall -y apex RUN git clone https://github.com/NVIDIA/apex diff --git a/scripts/finetune.py b/scripts/finetune.py index 1d1eb9f95..58f1c0957 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -178,6 +178,15 @@ def train( tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH ) + if cfg.debug or "debug" in kwargs: + logging.info("check_dataset_labels...") + check_dataset_labels( + train_dataset.select( + [random.randrange(0, len(train_dataset) - 1) for i in range(5)] + ), + tokenizer, + ) + if prepare_ds_only: logging.info("Finished preparing dataset. Exiting...") return @@ -213,15 +222,6 @@ def train( model.save_pretrained(cfg.output_dir) return - if cfg.debug: - logging.info("check_dataset_labels...") - check_dataset_labels( - train_dataset.select( - [random.randrange(0, len(train_dataset) - 1) for i in range(5)] - ), - tokenizer, - ) - trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer) model.config.use_cache = False diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index bfe6fc877..a91a4e2d3 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -268,6 +268,9 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy): class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): + def get_conversation_thread(self, prompt): + return prompt["conversations"] + def tokenize_prompt(self, prompt): result = { "input_ids": [], @@ -279,7 +282,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): assistant_token = self._get_assistant_token() try: for i, part in enumerate( - self.prompter.build_prompt(prompt["conversations"]) + self.prompter.build_prompt(self.get_conversation_thread(prompt)) ): if isinstance(part, tuple): if part[0] == "USER:":