diff --git a/cicd/Dockerfile-uv.jinja b/cicd/Dockerfile-uv.jinja index 9a49cfca5..103f1eb99 100644 --- a/cicd/Dockerfile-uv.jinja +++ b/cicd/Dockerfile-uv.jinja @@ -33,6 +33,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \ RUN uv pip install packaging==26.0 setuptools==75.8.0 RUN uv pip install torchvision +RUN uv pip uninstall causal_conv1d RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index 1c397b011..13d2f4e69 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -33,6 +33,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \ fi RUN pip install packaging==26.0 setuptools==75.8.0 psutil +RUN pip uninstall -y causal_conv1d RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ diff --git a/docker/Dockerfile b/docker/Dockerfile index d80cede55..5840c1f61 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -22,6 +22,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git WORKDIR /workspace/axolotl # If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64 +RUN pip uninstall -y causal_conv1d RUN if [ "$TARGETARCH" = "arm64" ]; then \ BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \ else \ diff --git a/docker/Dockerfile-uv b/docker/Dockerfile-uv index aa03f22a8..0142c0d2d 100644 --- a/docker/Dockerfile-uv +++ b/docker/Dockerfile-uv @@ -22,6 +22,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git WORKDIR /workspace/axolotl # If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64 +RUN uv pip uninstall causal_conv1d RUN if [ "$TARGETARCH" = "arm64" ]; then \ BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \ else \ diff --git a/requirements.txt b/requirements.txt index 472a98bc8..c918d30aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,9 @@ trl==0.29.0 hf_xet==1.3.2 kernels==0.12.2 +fla-core==0.4.1 +flash-linear-attention==0.4.1 + trackio>=0.16.1 typing-extensions>=4.15.0 diff --git a/setup.py b/setup.py index fe00b9e14..5b7b50f29 100644 --- a/setup.py +++ b/setup.py @@ -27,9 +27,16 @@ def parse_requirements(extras_require_map): xformers_version = [req for req in _install_requires if "xformers" in req][0] install_xformers = platform.machine() != "aarch64" if platform.machine() == "aarch64": - # skip torchao on ARM64 + # skip on ARM64 + skip_packages = [ + "torchao", + "fla-core", + "flash-linear-attention", + ] _install_requires = [ - req for req in _install_requires if "torchao" not in req + req + for req in _install_requires + if re.split(r"[>=<]", req)[0].strip() not in skip_packages ] if "Darwin" in platform.system(): # skip packages not compatible with OSX diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index 87b537655..cf8056ca0 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -506,8 +506,11 @@ def patch_initialize_missing_keys_for_fsdp(): def _patched_initialize_missing_keys(self, is_quantized: bool) -> None: if is_fsdp_enabled() and not is_local_dist_rank_0(): for key in self.state_dict(): - param_or_buffer = self.get_parameter_or_buffer(key) - param_or_buffer._is_hf_initialized = True + try: + param_or_buffer = self.get_parameter_or_buffer(key) + param_or_buffer._is_hf_initialized = True + except AttributeError: + pass # may happen when handling pre-quantized weights self._is_hf_initialized = True _original_initialize_missing_keys(self, is_quantized) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 842cbf118..268311295 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -180,6 +180,7 @@ def check_tensorboard( lt_val: float, assertion_err: str, rtol: float = 0.02, + gt_zero: bool = True, ) -> None: """ helper function to parse and check tensorboard logs @@ -194,6 +195,8 @@ def check_tensorboard( assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1] else: assert df.value.values[-1] < lt_val, assertion_err + if gt_zero: + assert df.value.values[-1] > 1e-5, "Expected loss to be greater than zero" def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None: