diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja
index 6988e092b..9cad43f40 100644
--- a/cicd/Dockerfile.jinja
+++ b/cicd/Dockerfile.jinja
@@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
 RUN pip install packaging==23.2 setuptools==75.8.0
 
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,llmcompressor,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,llmcompressor] $AXOLOTL_ARGS; \
     fi
 
 RUN python scripts/unsloth_install.py | sh
diff --git a/docker/Dockerfile b/docker/Dockerfile
index e23a729d4..bac02c057 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
 
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,llmcompressor,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,llmcompressor] $AXOLOTL_ARGS; \
     fi
 
 RUN python scripts/unsloth_install.py | sh
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 4c6d77144..808d3af64 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -288,7 +288,19 @@ def save_trained_model(
             os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
         except FileNotFoundError:
             pass
-    elif hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
+    elif cfg.local_rank == 0:
+        if cfg.flash_optimum and BetterTransformer:
+            model = BetterTransformer.reverse(model)
+
+        if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
+            trainer.model.save_pretrained(
+                cfg.output_dir, safe_serialization=safe_serialization
+            )
+
+        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+
+    if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
+        # TODO: add integration support so this can be implemented completely within the plugin
         from axolotl.integrations.llm_compressor.utils import (
             save_compressed_model,
         )
@@ -301,17 +313,6 @@ def save_trained_model(
             save_compressed=cfg.llmcompressor.save_compressed,
         )
 
-    elif cfg.local_rank == 0:
-        if cfg.flash_optimum and BetterTransformer:
-            model = BetterTransformer.reverse(model)
-
-        if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
-            trainer.model.save_pretrained(
-                cfg.output_dir, safe_serialization=safe_serialization
-            )
-
-        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
-
 
 def create_model_card(cfg: DictDefault, trainer: Trainer):
     """