install flash-linear-attention (#3466)
* install flash-linear-attention
* handle pre-quantized weights for FSDP2 and ensure loss is not zero
* fix type for cu_seqlen, uninstall causal_conv1d
* chore: lint
* uv pip uninstall doesn't need confirmation
This commit is contained in:
@@ -33,6 +33,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
|
||||
RUN uv pip install packaging==26.0 setuptools==75.8.0
|
||||
RUN uv pip install torchvision
|
||||
RUN uv pip uninstall causal_conv1d
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
@@ -33,6 +33,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
fi
|
||||
|
||||
RUN pip install packaging==26.0 setuptools==75.8.0 psutil
|
||||
RUN pip uninstall -y causal_conv1d
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
@@ -22,6 +22,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||
RUN pip uninstall -y causal_conv1d
|
||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
else \
|
||||
|
||||
@@ -22,6 +22,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||
RUN uv pip uninstall causal_conv1d
|
||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
else \
|
||||
|
||||
@@ -20,6 +20,9 @@ trl==0.29.0
|
||||
hf_xet==1.3.2
|
||||
kernels==0.12.2
|
||||
|
||||
fla-core==0.4.1
|
||||
flash-linear-attention==0.4.1
|
||||
|
||||
trackio>=0.16.1
|
||||
typing-extensions>=4.15.0
|
||||
|
||||
|
||||
11
setup.py
11
setup.py
@@ -27,9 +27,16 @@ def parse_requirements(extras_require_map):
|
||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||
install_xformers = platform.machine() != "aarch64"
|
||||
if platform.machine() == "aarch64":
|
||||
# skip torchao on ARM64
|
||||
# skip on ARM64
|
||||
skip_packages = [
|
||||
"torchao",
|
||||
"fla-core",
|
||||
"flash-linear-attention",
|
||||
]
|
||||
_install_requires = [
|
||||
req for req in _install_requires if "torchao" not in req
|
||||
req
|
||||
for req in _install_requires
|
||||
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
|
||||
]
|
||||
if "Darwin" in platform.system():
|
||||
# skip packages not compatible with OSX
|
||||
|
||||
@@ -506,8 +506,11 @@ def patch_initialize_missing_keys_for_fsdp():
|
||||
def _patched_initialize_missing_keys(self, is_quantized: bool) -> None:
|
||||
if is_fsdp_enabled() and not is_local_dist_rank_0():
|
||||
for key in self.state_dict():
|
||||
param_or_buffer = self.get_parameter_or_buffer(key)
|
||||
param_or_buffer._is_hf_initialized = True
|
||||
try:
|
||||
param_or_buffer = self.get_parameter_or_buffer(key)
|
||||
param_or_buffer._is_hf_initialized = True
|
||||
except AttributeError:
|
||||
pass # may happen when handling pre-quantized weights
|
||||
self._is_hf_initialized = True
|
||||
|
||||
_original_initialize_missing_keys(self, is_quantized)
|
||||
|
||||
@@ -180,6 +180,7 @@ def check_tensorboard(
|
||||
lt_val: float,
|
||||
assertion_err: str,
|
||||
rtol: float = 0.02,
|
||||
gt_zero: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
helper function to parse and check tensorboard logs
|
||||
@@ -194,6 +195,8 @@ def check_tensorboard(
|
||||
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
|
||||
else:
|
||||
assert df.value.values[-1] < lt_val, assertion_err
|
||||
if gt_zero:
|
||||
assert df.value.values[-1] > 1e-5, "Expected loss to be greater than zero"
|
||||
|
||||
|
||||
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
|
||||
|
||||
Reference in New Issue
Block a user