install flash-linear-attention (#3466)

* install flash-linear-attention

* handle prequant weights for fsdp2 and ensure loss is not zero

* fix type for cu_seqlen, uninstall causal_conv1d

* chore: lint

* uv pip uninstall doesn't need confirmation
This commit is contained in:
Wing Lian
2026-03-06 12:40:57 -05:00
committed by GitHub
parent d65e1b960c
commit 876941ffd0
8 changed files with 24 additions and 4 deletions

View File

@@ -33,6 +33,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
RUN uv pip install packaging==26.0 setuptools==75.8.0
RUN uv pip install torchvision
RUN uv pip uninstall causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -33,6 +33,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi
RUN pip install packaging==26.0 setuptools==75.8.0 psutil
RUN pip uninstall -y causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -22,6 +22,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN pip uninstall -y causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \

View File

@@ -22,6 +22,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN uv pip uninstall causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \

View File

@@ -20,6 +20,9 @@ trl==0.29.0
hf_xet==1.3.2
kernels==0.12.2
fla-core==0.4.1
flash-linear-attention==0.4.1
trackio>=0.16.1
typing-extensions>=4.15.0

View File

@@ -27,9 +27,16 @@ def parse_requirements(extras_require_map):
xformers_version = [req for req in _install_requires if "xformers" in req][0]
install_xformers = platform.machine() != "aarch64"
if platform.machine() == "aarch64":
# skip torchao on ARM64
# skip on ARM64
skip_packages = [
"torchao",
"fla-core",
"flash-linear-attention",
]
_install_requires = [
req for req in _install_requires if "torchao" not in req
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
if "Darwin" in platform.system():
# skip packages not compatible with OSX

View File

@@ -506,8 +506,11 @@ def patch_initialize_missing_keys_for_fsdp():
def _patched_initialize_missing_keys(self, is_quantized: bool) -> None:
if is_fsdp_enabled() and not is_local_dist_rank_0():
for key in self.state_dict():
param_or_buffer = self.get_parameter_or_buffer(key)
param_or_buffer._is_hf_initialized = True
try:
param_or_buffer = self.get_parameter_or_buffer(key)
param_or_buffer._is_hf_initialized = True
except AttributeError:
pass # may happen when handling pre-quantized weights
self._is_hf_initialized = True
_original_initialize_missing_keys(self, is_quantized)

View File

@@ -180,6 +180,7 @@ def check_tensorboard(
lt_val: float,
assertion_err: str,
rtol: float = 0.02,
gt_zero: bool = True,
) -> None:
"""
helper function to parse and check tensorboard logs
@@ -194,6 +195,8 @@ def check_tensorboard(
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
else:
assert df.value.values[-1] < lt_val, assertion_err
if gt_zero:
assert df.value.values[-1] > 1e-5, "Expected loss to be greater than zero"
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None: