Compare commits
2 Commits
34de5b3bd5
...
sdpa-multi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a538be9c2 | ||
|
|
74c72ca5eb |
21
.github/workflows/base.yml
vendored
21
.github/workflows/base.yml
vendored
@@ -1,10 +1,7 @@
|
|||||||
name: ci-cd-base
|
name: ci-cd-base
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
workflow_dispatch:
|
||||||
branches:
|
|
||||||
- "main-base"
|
|
||||||
- "dev-base"
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-base:
|
build-base:
|
||||||
@@ -15,11 +12,6 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: "118"
|
|
||||||
cuda_version: 11.8.0
|
|
||||||
python_version: "3.9"
|
|
||||||
pytorch: 2.0.1
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
|
||||||
- cuda: "118"
|
- cuda: "118"
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -28,12 +20,17 @@ jobs:
|
|||||||
- cuda: "118"
|
- cuda: "118"
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.1
|
pytorch: 2.1.2
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
- cuda: "121"
|
- cuda: "121"
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.1
|
pytorch: 2.1.2
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
|
- cuda: "121"
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.1.2
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -56,7 +53,7 @@ jobs:
|
|||||||
context: .
|
context: .
|
||||||
file: ./docker/Dockerfile-base
|
file: ./docker/Dockerfile-base
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
build-args: |
|
build-args: |
|
||||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||||
|
|||||||
33
.github/workflows/main.yml
vendored
33
.github/workflows/main.yml
vendored
@@ -4,6 +4,7 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- "main"
|
- "main"
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-axolotl:
|
build-axolotl:
|
||||||
@@ -15,24 +16,24 @@ jobs:
|
|||||||
include:
|
include:
|
||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.9"
|
python_version: "3.10"
|
||||||
pytorch: 2.0.1
|
pytorch: 2.0.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.0.1
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
- cuda: 118
|
|
||||||
cuda_version: 11.8.0
|
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.1
|
pytorch: 2.1.2
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: [self-hosted, gpu, docker]
|
||||||
steps:
|
steps:
|
||||||
@@ -86,24 +87,24 @@ jobs:
|
|||||||
include:
|
include:
|
||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.9"
|
python_version: "3.10"
|
||||||
pytorch: 2.0.1
|
pytorch: 2.0.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.0.1
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
- cuda: 118
|
|
||||||
cuda_version: 11.8.0
|
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.1
|
pytorch: 2.1.2
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: [self-hosted, gpu, docker]
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
@@ -39,6 +39,32 @@ class TestExpandMask(unittest.TestCase):
|
|||||||
# Check that the output matches the expected output
|
# Check that the output matches the expected output
|
||||||
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
||||||
|
|
||||||
|
def test_output_multipack(self):
|
||||||
|
mask = torch.tensor([[1, 1, 1, 0], [2, 2, 3, 3]])
|
||||||
|
dtype = torch.float32
|
||||||
|
expected_output = torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||||
|
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||||
|
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
|
||||||
|
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||||
|
]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||||
|
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||||
|
[-3.4028e38, -3.4028e38, 0.0000e00, -3.4028e38],
|
||||||
|
[-3.4028e38, -3.4028e38, 0.0000e00, 0.0000e00],
|
||||||
|
]
|
||||||
|
],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Check that the output matches the expected output
|
||||||
|
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user