From 2aa1f714641da62a19a1100299c07b6f41985620 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 13 Oct 2023 08:57:28 -0400 Subject: [PATCH] fix pytorch 2.1.0 build, add multipack docs (#722) --- .github/workflows/main.yml | 1 + docker/Dockerfile | 4 +++ docs/multipack.md | 51 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+) create mode 100644 docs/multipack.md diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 5fdd2d705..f84f7f7a9 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -51,6 +51,7 @@ jobs: build-args: | BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} CUDA=${{ matrix.cuda }} + PYTORCH_VERSION=${{ matrix.pytorch }} file: ./docker/Dockerfile push: ${{ github.event_name != 'pull_request' }} tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} diff --git a/docker/Dockerfile b/docker/Dockerfile index 7b121aaa7..ff47548bc 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,6 +5,9 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" ARG AXOLOTL_EXTRAS="" ARG CUDA="118" ENV BNB_CUDA_VERSION=$CUDA +ARG PYTORCH_VERSION="2.0.1" + +ENV PYTORCH_VERSION=$PYTORCH_VERSION RUN apt-get update && \ apt-get install -y vim curl @@ -16,6 +19,7 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git WORKDIR /workspace/axolotl # If AXOLOTL_EXTRAS is set, append it in brackets +RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \ else \ diff --git a/docs/multipack.md b/docs/multipack.md new file mode 100644 index 000000000..2a55148b2 --- /dev/null +++ b/docs/multipack.md @@ -0,0 +1,51 @@ +# Multipack + +4k context, bsz =4, +each character represents 256 tokens +X represents a padding token + +``` + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +[[ A A A A A A A A A A A ] + B B B B B B ] + C C C C C C C ] + D D D D ]] + +[[ E E E E E E E E ] + [ F F F F ] + [ G G G ] + [ H H H H ]] + +[[ I I I ] + [ J J J ] + [ K K K K K] + [ L L L ]] +``` + +after padding to longest input in each step +``` + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +[[ A A A A A A A A A A A ] + B B B B B B X X X X X X ] + C C C C C C C X X X X ] + D D D D X X X X X X X ]] + +[[ E E E E E E E E ] + [ F F F F X X X X ] + [ G G G X X X X X ] + [ H H H H X X X X ]] + +[[ I I I X X ] + [ J J J X X ] + [ K K K K K ] + [ L L L X X ]] +``` + +w packing ( note it's the same effective number of tokens per step, but a true bsz of 1) +``` + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +[[ A A A A A A A A A A A B B B B B + B C C C C C C C D D D D E E E E + E E E E F F F F F G G G H H H H + I I I J J J J K K K K K L L L X ]] +```