diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 8f67908e8..fcfd96891 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -57,6 +57,13 @@ We welcome ideas for improvements and new features. To suggest an enhancement, o 5. Push your branch to your fork on GitHub. 6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues. +#### Skipping CI Checks + +You can skip certain CI checks by including specific bracketed keywords in your commit messages: + +- `[skip ci]` - Skips all CI checks for that commit +- `[skip-e2e]` - Skips only end-to-end tests while running other CI checks. You may also include this in the title or description of your PR to disable end-to-end tests for the entire PR. + ## Style Guidelines ### Code Style diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 912b3f1d6..fe63aa313 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -188,13 +188,44 @@ jobs: run: | find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; + gate-skip-e2e: + needs: [pre-commit, pytest, pytest-sdist] + runs-on: ubuntu-latest + outputs: + skip: ${{ steps.compute.outputs.skip }} + steps: + - uses: actions/github-script@v7 + id: compute + with: + script: | + const token = /\[skip-e2e\]/i; + let msg = ''; + if (context.eventName === 'push') { + msg = context.payload.head_commit?.message || ''; + } else if (context.eventName === 'pull_request') { + const { owner, repo } = context.repo; + const prNumber = context.payload.pull_request.number; + const commits = await github.paginate( + github.rest.pulls.listCommits, + { owner, repo, pull_number: prNumber, per_page: 100 } + ); + msg = commits.at(-1)?.commit?.message || ''; + } + const title = context.payload.pull_request?.title || ''; + const body = context.payload.pull_request?.body || ''; + const skip = token.test(msg) || token.test(title) || token.test(body); + core.setOutput('skip', String(skip)); + docker-e2e-tests-1st: # Run this job first as a gate for running the remainder of the test matrix - if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }} + if: > + github.repository_owner == 'axolotl-ai-cloud' && + (github.event_name != 'pull_request' || !github.event.pull_request.draft) && + needs.gate-skip-e2e.outputs.skip != 'true' # this job needs to be run on self-hosted GPU runners... runs-on: [self-hosted, modal] timeout-minutes: 120 - needs: [pre-commit, pytest, pytest-sdist] + needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e] strategy: fail-fast: false @@ -240,13 +271,16 @@ jobs: modal run cicd.e2e_tests docker-e2e-tests: - if: ${{ github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }} + if: > + github.repository_owner == 'axolotl-ai-cloud' && + (github.event_name != 'pull_request' || !github.event.pull_request.draft) && + needs.gate-skip-e2e.outputs.skip != 'true' # this job needs to be run on self-hosted GPU runners... 
runs-on: [self-hosted, modal] timeout-minutes: 120 # Only run the remainder of the matrix if the first e2e check passed; # this is to save on wasted compute costs for known failures that get caught in the first run - needs: [pre-commit, pytest, docker-e2e-tests-1st] + needs: [pre-commit, pytest, gate-skip-e2e, docker-e2e-tests-1st] strategy: fail-fast: false diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index d1151cedd..87918cc41 100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -37,7 +37,7 @@ WORKDIR /workspace RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \ python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \ - CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \ + CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \ python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \ python3 -m pip cache purge diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd index dbb365f73..d839ce211 100644 --- a/docs/multimodal.qmd +++ b/docs/multimodal.qmd @@ -13,10 +13,13 @@ format: - [Pixtral](#sec-pixtral) - [Llava-1.5](#sec-llava-15) - [Mistral-Small-3.1](#sec-mistral-small-31) +- [Voxtral](#sec-voxtral) - [Gemma-3](#sec-gemma-3) - [Gemma-3n](#sec-gemma-3n) - [Qwen2-VL](#sec-qwen2-vl) - [Qwen2.5-VL](#sec-qwen25-vl) +- [SmolVLM2](#sec-smolvlm2) +- [LFM2-VL](#sec-lfm2-vl) ## Usage @@ -31,7 +34,7 @@ skip_prepare_dataset: true remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training sample_packing: false # not yet supported with multimodal -chat_template: # see in next section +chat_template: # optional; see the model-specific sections below # example dataset datasets: @@ -97,6 +100,16 @@ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503 chat_template: mistral_v7_tekken ``` +### Voxtral {#sec-voxtral} + +::: {.callout-tip} +Please make sure to install the audio dependencies via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'` +::: + +```yaml +base_model: mistralai/Voxtral-Mini-3B-2507 +``` + ### Gemma-3 {#sec-gemma-3} ::: {.callout-tip} @@ -143,6 +156,26 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct chat_template: qwen2_vl # same as qwen2-vl ``` +### SmolVLM2 {#sec-smolvlm2} + +::: {.callout-tip} +Please make sure to install `num2words` via `pip3 install num2words==0.5.14` +::: + +```yaml +base_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct +``` + +### LFM2-VL {#sec-lfm2-vl} + +::: {.callout-warning} +Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d` +::: + +```yaml +base_model: LiquidAI/LFM2-VL-450M +``` + ## Dataset Format For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format. @@ -181,6 +214,20 @@ You may need to install `librosa` via `pip3 install librosa==0.11.0`. ::: +### Video + +::: {.callout-warning} + +This is not well tested at the moment. We welcome contributors! 
+ +::: + +For video loading, you can use the following keys within `content` alongside `"type": "video"`: + +- `"path": "/path/to/video.mp4"` +- `"url": "https://example.com/video.mp4"` +- `"video": np.ndarray | list[PIL.Image.Image] | torch.Tensor` (or a list of any of these) + ### Example Here is an example of a multi-modal dataset: diff --git a/examples/LiquidAI/README.md b/examples/LiquidAI/README.md new file mode 100644 index 000000000..96fc74a92 --- /dev/null +++ b/examples/LiquidAI/README.md @@ -0,0 +1,58 @@ +# Finetune Liquid Foundation Models 2 (LFM2) with Axolotl + +[Liquid Foundation Models 2 (LFM2)](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) are a family of small, open-weight models from [Liquid AI](https://www.liquid.ai/) focused on quality, speed, and memory efficiency. Liquid AI released text-only [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) and text+vision [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa) models. + +LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference. + +This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl. + +## Getting Started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + + Here is an example of how to install with pip: + ```bash + # Ensure you have a compatible version of PyTorch installed + pip3 install packaging setuptools wheel ninja + pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' + ``` + +2. Run one of the finetuning examples below. + + **LFM2** + ```bash + # FFT SFT (1x48GB @ 25GiB) + axolotl train examples/LiquidAI/lfm2-350m-fft.yaml + ``` + + **LFM2-VL** + ```bash + # LoRA SFT (1x48GB @ 2.7GiB) + axolotl train examples/LiquidAI/lfm2-vl-lora.yaml + ``` + +### TIPS + +- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the installed `causal-conv1d` package is incompatible with LFM2 models. Uninstall it: + ```bash + pip uninstall -y causal-conv1d + ``` + +- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html). +- **Dataset Formats**: + - For LFM2 models, the dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + - For LFM2-VL models, Axolotl follows the multi-content Messages format; a minimal example row is sketched below. See our [Multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) for details. 
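+
+    A minimal sketch of one training row in this format (the file path is hypothetical; see the Multimodal docs above for all supported content keys):
+
+    ```json
+    {"messages": [
+      {"role": "user", "content": [
+        {"type": "image", "path": "/data/images/example.png"},
+        {"type": "text", "text": "What is shown in this image?"}
+      ]},
+      {"role": "assistant", "content": [
+        {"type": "text", "text": "A short description of the image."}
+      ]}
+    ]}
+    ```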
+ +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) + +## Related Resources + +- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models) +- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/lfm2/lfm2-350m-fft.yaml b/examples/LiquidAI/lfm2-350m-fft.yaml similarity index 96% rename from examples/lfm2/lfm2-350m-fft.yaml rename to examples/LiquidAI/lfm2-350m-fft.yaml index 16a0a028e..d19815008 100644 --- a/examples/lfm2/lfm2-350m-fft.yaml +++ b/examples/LiquidAI/lfm2-350m-fft.yaml @@ -2,7 +2,6 @@ base_model: LiquidAI/LFM2-350M chunked_cross_entropy: true -chat_template: tokenizer_default eot_tokens: - "<|im_end|>" datasets: diff --git a/examples/LiquidAI/lfm2-vl-lora.yaml b/examples/LiquidAI/lfm2-vl-lora.yaml new file mode 100644 index 000000000..7fee17f92 --- /dev/null +++ b/examples/LiquidAI/lfm2-vl-lora.yaml @@ -0,0 +1,58 @@ +base_model: LiquidAI/LFM2-VL-450M +trust_remote_code: true +model_type: AutoModelForImageTextToText +processor_type: AutoProcessor + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 8192 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +logging_steps: 1 +flash_attention: true +eager_attention: + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/gpt-oss/README.md b/examples/gpt-oss/README.md index 6dadb8230..9db5e9887 100644 --- a/examples/gpt-oss/README.md +++ b/examples/gpt-oss/README.md @@ -33,13 +33,44 @@ Note: Memory usage taken from `device_mem_reserved(gib)` from logs. ### Training 120B -On 8xH100s +On 8xH100s, make sure you have ~3TB of free disk space: each checkpoint clocks in at ~720GB, so together with the base +model and the final model output you need roughly 3TB to keep at least 2 checkpoints. ```bash # FFT SFT with offloading (8x80GB @ ~49GiB/GPU) axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml ``` +ERRATA: Transformers saves the model architecture prefixed with `FSDP`, which needs to be manually renamed in `config.json`. +See https://github.com/huggingface/transformers/pull/40207 for the status of this issue. 
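+Until that fix lands upstream, you can rename the architecture in place, e.g.: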
+ +```bash +sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json +``` + +When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your +configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to +merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded +weights to `{output_dir}/merged`. + +```bash +axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml +mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/ +``` + + +### Running inference with your fine-tuned model + +GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425 +for more information about using a special vllm-openai Docker image for inference with vLLM. + +SGLang has day-0 support on its main branch; see https://github.com/sgl-project/sglang/issues/8833 for information on installing +SGLang from source. Once you've installed SGLang, run the following command to launch an SGLang server: + +```bash +python3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8 +``` + ### Tool use GPT-OSS has a comprehensive tool understanding. Axolotl supports tool calling datasets for Supervised Fine-tuning. diff --git a/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml b/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml index 4a9d51fdf..4b4fbd89b 100644 --- a/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml +++ b/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml @@ -20,6 +20,7 @@ datasets: dataset_prepared_path: last_run_prepared val_set_size: 0 output_dir: ./outputs/gpt-oss-out/ +save_total_limit: 2 # the 120B model can use up to 720GB of disk space per checkpoint, so let's only keep the last 2 sequence_len: 4096 sample_packing: true diff --git a/examples/lfm2/README.md b/examples/lfm2/README.md deleted file mode 100644 index eb9ca911f..000000000 --- a/examples/lfm2/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# Liquid Foundation Models 2 - -LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release. - -```bash -pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git -``` diff --git a/examples/smolvlm2/README.md b/examples/smolvlm2/README.md new file mode 100644 index 000000000..9c0ae4836 --- /dev/null +++ b/examples/smolvlm2/README.md @@ -0,0 +1,49 @@ +# Finetune SmolVLM2 with Axolotl + +[SmolVLM2](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7) is a family of lightweight, open-source multimodal models from Hugging Face designed to analyze and understand video, image, and text content. + +These models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M. + +This guide shows how to fine-tune SmolVLM2 models with Axolotl. + +## Getting Started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + + Here is an example of how to install with pip: + ```bash + # Ensure you have a compatible version of PyTorch installed + pip3 install packaging setuptools wheel ninja + pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' + ``` + +2. 
Install an extra dependency: + + ```bash + pip3 install num2words==0.5.14 + ``` + +3. Run the finetuning example: + + ```bash + # LoRA SFT (1x48GB @ 6.8GiB) + axolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml + ``` + +## TIPS + +- **Dataset Format**: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on [Multimodal Formats](https://docs.axolotl.ai/docs/multimodal.html#dataset-format). +- **Dataset Loading**: Read more on how to prepare and load your own datasets in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html). + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) + +## Related Resources + +- [SmolVLM2 Blog](https://huggingface.co/blog/smolvlm2) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/smolvlm2/smolvlm2-2B-lora.yaml b/examples/smolvlm2/smolvlm2-2B-lora.yaml new file mode 100644 index 000000000..1aeff408d --- /dev/null +++ b/examples/smolvlm2/smolvlm2-2B-lora.yaml @@ -0,0 +1,56 @@ +base_model: HuggingFaceTB/SmolVLM2-2.2B-Instruct +trust_remote_code: true +processor_type: AutoProcessor + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 8192 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.text_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +logging_steps: 1 +flash_attention: true +eager_attention: + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/requirements.txt b/requirements.txt index 5f7767812..c2552002f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ packaging==23.2 huggingface_hub>=0.33.0 peft==0.17.0 -transformers==4.55.1 +transformers==4.55.2 tokenizers>=0.21.1 accelerate==1.10.0 datasets==4.0.0 diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py index 31d854d41..9bb544aff 100644 --- a/src/axolotl/cli/args.py +++ b/src/axolotl/cli/args.py @@ -40,6 +40,12 @@ class VllmServeCliArgs: default=None, metadata={"help": "Number of tensor parallel workers to use."}, ) + data_parallel_size: Optional[int] = field( + default=None, + metadata={ + "help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference." 
}, + ) host: Optional[str] = field( default=None, # nosec B104 metadata={"help": "Host address to run the server on."}, ) diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py index c08d30ec8..c99f37fb1 100644 --- a/src/axolotl/cli/merge_sharded_fsdp_weights.py +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -10,6 +10,7 @@ import fire import torch import torch.distributed.checkpoint as dist_cp import torch.distributed.checkpoint.format_utils as dist_cp_format_utils +from accelerate import PartialState from accelerate.utils import ( SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, @@ -23,6 +24,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner from axolotl.cli.config import load_cfg from axolotl.utils.logging import get_logger +from axolotl.utils.train import determine_last_checkpoint LOG = get_logger(__name__) @@ -143,7 +145,6 @@ def merge_fsdp_weights( ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist. """ checkpoint_dir_ = Path(checkpoint_dir) - from accelerate.state import PartialState if not is_torch_version(">=", "2.3.0"): raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`") @@ -180,7 +181,6 @@ def merge_fsdp_weights( if remove_checkpoint_dir: LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}") shutil.rmtree(checkpoint_dir_) - state.wait_for_everyone() def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): @@ -195,11 +195,32 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): parsed_cfg = load_cfg(config, **kwargs) fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0" + if not fsdp_dir.exists(): + checkpoint_dir = determine_last_checkpoint(parsed_cfg, update=False) + if checkpoint_dir: + fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0" + if not fsdp_dir.exists(): + raise ValueError( + f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}" + ) + + output_path = str(Path(parsed_cfg.output_dir) / "merged") merge_fsdp_weights( checkpoint_dir=str(fsdp_dir), - output_path=str(Path(parsed_cfg.output_dir) / "merged"), + output_path=output_path, safe_serialization=True, ) + state = PartialState() + state.wait_for_everyone() + LOG.info( + f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}", + main_process_only=True, + ) + LOG.info( + "The merged output contains only the safetensors weights; the model configuration " + f"and tokenizer can be found in {parsed_cfg.output_dir}.", + main_process_only=True, + ) if __name__ == "__main__": diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py index 5f97e387a..a9cda4efc 100644 --- a/src/axolotl/core/trainers/__init__.py +++ b/src/axolotl/core/trainers/__init__.py @@ -5,7 +5,6 @@ from .base import AxolotlTrainer from .dpo.trainer import AxolotlDPOTrainer -from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer from .mamba import AxolotlMambaTrainer from .trl import ( AxolotlCPOTrainer, diff --git a/src/axolotl/loaders/constants.py b/src/axolotl/loaders/constants.py index 3fabf9d94..4939cb28d 100644 --- a/src/axolotl/loaders/constants.py +++ b/src/axolotl/loaders/constants.py @@ -1,26 +1,13 @@ """Shared constants for axolotl.loaders module""" -from transformers import ( - Gemma3ForConditionalGeneration, - Gemma3nForConditionalGeneration, - Llama4ForConditionalGeneration, - LlavaForConditionalGeneration, - Mistral3ForConditionalGeneration, - 
MllamaForConditionalGeneration, - Qwen2_5_VLForConditionalGeneration, - Qwen2VLForConditionalGeneration, +from transformers import AutoModelForImageTextToText +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, ) -MULTIMODAL_AUTO_MODEL_MAPPING = { - "mllama": MllamaForConditionalGeneration, - "llama4": Llama4ForConditionalGeneration, - "llava": LlavaForConditionalGeneration, - "qwen2_vl": Qwen2VLForConditionalGeneration, - "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration, - "mistral3": Mistral3ForConditionalGeneration, - "gemma3": Gemma3ForConditionalGeneration, - "gemma3n": Gemma3nForConditionalGeneration, -} +MULTIMODAL_AUTO_MODEL_MAPPING = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES) + +MULTIMODAL_AUTO_MODEL_MAPPING["lfm2-vl"] = AutoModelForImageTextToText try: from transformers import VoxtralForConditionalGeneration diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 1cb33d13c..384fbdf25 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -25,6 +25,7 @@ from peft import ( from torch.distributed import DeviceMesh from transformers import ( AutoModelForCausalLM, + AutoModelForImageTextToText, AutoModelForVision2Seq, AwqConfig, BitsAndBytesConfig, @@ -212,6 +213,7 @@ class ModelLoader: self.model_kwargs["use_kernels"] = self.cfg.use_kernels self._set_quantization_config() self._set_attention_config() + self._check_model_requirements() def _apply_post_model_load_setup(self): """Configure the model after it has been loaded.""" @@ -432,6 +434,8 @@ class ModelLoader: self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get( self.model_config.model_type, AutoModelForVision2Seq ) + if isinstance(self.auto_model_loader, str): + self.auto_model_loader = AutoModelForImageTextToText def _set_device_map_config(self): """Setup `device_map` according to config""" @@ -628,6 +632,16 @@ class ModelLoader: if self.cfg.low_cpu_mem_usage: self.model_kwargs["low_cpu_mem_usage"] = True + def _check_model_requirements(self): + if self.cfg.model_config_type in ["lfm2-vl", "lfm2"]: + from transformers.utils.import_utils import is_causal_conv1d_available + + if is_causal_conv1d_available(): + raise ImportError( + "The 'causal-conv1d' package is installed but causes compatibility issues with LFM2 models. 
" + "Please uninstall it by running: `pip uninstall -y causal-conv1d`" + ) + def _configure_zero3_memory_efficient_loading( self, ) -> HfTrainerDeepSpeedConfig | None: diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py index 4cc5e85a1..31597d5a6 100644 --- a/src/axolotl/processing_strategies.py +++ b/src/axolotl/processing_strategies.py @@ -6,7 +6,7 @@ from typing import Optional from PIL import Image, ImageOps from PIL.Image import Resampling from torch import Tensor, zeros_like -from transformers import ProcessorMixin, VoxtralProcessor +from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor from transformers.image_utils import load_image from axolotl.utils.dict import remove_none_values @@ -138,7 +138,7 @@ class ProcessingStrategy: image_key = key break - # if the image key exists, add the image to the first message + # if the image key exists, add the image to the first user message if image_key is not None and processed_example[image_key] is not None: # TODO: check if it's normal to be single image only for common datasets # From observation, it's usually a list of single image but some datasets may have several columns for images @@ -179,26 +179,34 @@ class ProcessingStrategy: # Look for any image type in the first message # some dataset have an {type: "image"} in the first message + msg_ind_to_add = None ind_to_add = None + first_user_idx = None - for i, content in enumerate( - processed_example["messages"][0]["content"] - ): - # Usually datasets created with image columns, don't have it in the messages itself - if content["type"] == "image" and all( - k not in content for k in ["image", "url", "path", "base64"] + for msg_idx, msg_content in enumerate(processed_example["messages"]): + if first_user_idx is None and msg_content["role"] == "user": + first_user_idx = msg_idx + for i, content in enumerate( + processed_example["messages"][msg_idx]["content"] ): - ind_to_add = i - break + # Usually datasets created with image columns, don't have it in the messages itself + if content["type"] == "image" and all( + k not in content for k in ["image", "url", "path", "base64"] + ): + msg_ind_to_add = msg_idx + ind_to_add = i + break # If an image type is found, add the image to that index - if ind_to_add is not None: - processed_example["messages"][0]["content"][ind_to_add][ - "image" - ] = image_value + if ind_to_add is not None and msg_ind_to_add is not None: + processed_example["messages"][msg_ind_to_add]["content"][ + ind_to_add + ]["image"] = image_value else: - # if no image type is found, add it to end of the first message - processed_example["messages"][0]["content"].append( + # if no image type is found, add it to end of the first user message + if first_user_idx is None: + first_user_idx = 0 + processed_example["messages"][first_user_idx]["content"].append( { "type": "image", "image": image_value, @@ -395,6 +403,24 @@ class VoxtralProcessingStrategy(ProcessingStrategy): return labels +class SmolVLM2ProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for SmolVLM2""" + + def __init__( + self, + processor: ProcessorMixin, + chat_template: Optional[str] = None, + image_size: int | tuple[int, int] | None = None, + image_resize_algorithm: Resampling | None = None, + ): + super().__init__(processor, chat_template, image_size, image_resize_algorithm) + self.image_token = "" # nosec + + self.image_token_id = processor.tokenizer.additional_special_tokens_ids[ + 
processor.tokenizer.additional_special_tokens.index(self.image_token) + ] + + def get_processing_strategy( processor: ProcessorMixin, chat_template, @@ -402,32 +428,43 @@ def get_processing_strategy( image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, ): + processing_kwargs = { + "processor": processor, + "chat_template": chat_template, + "image_size": image_size, + "image_resize_algorithm": image_resize_algorithm, + } + + if chat_template_type in [None, "tokenizer_default"] and hasattr( + processor.tokenizer, "chat_template" + ): + processing_kwargs["chat_template"] = processor.tokenizer.chat_template + if chat_template_type == "qwen2_vl": return Qwen2VLProcessingStrategy( - processor, chat_template, image_size, image_resize_algorithm + **processing_kwargs, ) if chat_template_type == "gemma3": return Gemma3ProcessingStrategy( - processor, chat_template, image_size, image_resize_algorithm + **processing_kwargs, ) if chat_template_type == "gemma3n": return Gemma3nProcessingStrategy( - processor, chat_template, image_size, image_resize_algorithm - ) - if chat_template_type in [ - "llama3_2_vision", - "llama4", - "llava", - "mistral_v7_tekken", - "pixtral", - ]: - return ProcessingStrategy( - processor, chat_template, image_size, image_resize_algorithm + **processing_kwargs, ) if isinstance(processor, VoxtralProcessor): return VoxtralProcessingStrategy( - processor, chat_template, image_size, image_resize_algorithm + **processing_kwargs, ) - raise ValueError(f"Unsupported chat template type: {chat_template_type}") + if isinstance(processor, SmolVLMProcessor): + return SmolVLM2ProcessingStrategy( + **processing_kwargs, + ) + + # llama3_2_vision, llama4, llava + # mistral_v7_tekken, pixtral, lfm2vl + return ProcessingStrategy( + **processing_kwargs, + ) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 8241dd385..f927b7fcb 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -129,13 +129,21 @@ class ChatTemplatePrompter(Prompter): images=images, return_tensors="pt", ) + if hasattr(batch, "to_dict"): + batch = batch.to_dict() + else: + batch = dict(batch) + # workaround since processor works in batches instead of single examples + out = {} for k, val in batch.items(): - if k in ["pixel_values"]: - batch[k] = val.tolist() + if hasattr(val, "tolist"): + out[k] = ( + val.tolist() if k == "pixel_values" else val.squeeze(0).tolist() + ) else: - batch[k] = val.squeeze().tolist() - return batch + out[k] = val + return out return self.tokenizer.apply_chat_template( conversation, @@ -433,10 +441,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): tokenized_prompt["attention_mask"] = [1] * len(input_ids) else: input_ids = tokenized_res["input_ids"] - tokenized_prompt = tokenized_res + tokenized_prompt = dict(tokenized_res) if not self.train_on_inputs: - user_prompt_len = len(prompt_ids) + if isinstance(prompt_ids, dict): + user_prompt_len = len(prompt_ids["input_ids"]) + else: + user_prompt_len = len(prompt_ids) labels = [-100] * user_prompt_len + input_ids[user_prompt_len:] else: labels = input_ids diff --git a/src/axolotl/train.py b/src/axolotl/train.py index e8a2cbabe..8005389f1 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -4,11 +4,14 @@ from __future__ import annotations import importlib import inspect +import json import os +import shutil import signal import sys import typing import weakref +from 
collections import OrderedDict from contextlib import ExitStack from pathlib import Path from typing import Any, Dict @@ -38,6 +41,7 @@ from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import RLType +from axolotl.utils.train import determine_last_checkpoint from axolotl.utils.trainer import setup_trainer try: @@ -46,7 +50,7 @@ except ImportError: BetterTransformer = None if typing.TYPE_CHECKING: - from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder + from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder LOG = get_logger(__name__) @@ -124,32 +128,6 @@ def setup_reference_model( return model_ref -def determine_resume_checkpoint(cfg: DictDefault) -> str | None: - """ - Determine the checkpoint to resume from based on configuration. - - Args: - cfg: Dictionary mapping `axolotl` config keys to values. - - Returns: - Path to the checkpoint to resume from, or `None` if not resuming. - """ - if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: - possible_checkpoints = [ - str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*") - ] - if len(possible_checkpoints) > 0: - sorted_paths = sorted( - possible_checkpoints, - key=lambda path: int(path.split("-")[-1]), - ) - cfg.resume_from_checkpoint = sorted_paths[-1] - LOG.info( - f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}" - ) - return cfg.resume_from_checkpoint - - def setup_signal_handler( cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool ): @@ -282,12 +260,49 @@ def save_trained_model( else: state_dict_type = cfg.fsdp_config.state_dict_type trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type) - trainer.save_model(cfg.output_dir) + trainer.save_model(cfg.output_dir) # only handles FULL_STATE_DICT if state_dict_type == "SHARDED_STATE_DICT": LOG.info( "The final model was saved with a sharded state dict. Please ensure you merge " "the sharded weights with `merge-sharded-fsdp-weights`." 
) + checkpoint_dir = determine_last_checkpoint(cfg, update=False) + if ( + not (Path(cfg.output_dir) / "model.safetensors.index.json").exists() + and checkpoint_dir + ): + # import here to prevent circular import + from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights + + fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0" + merged_path = str(Path(cfg.output_dir) / "merged") + merge_fsdp_weights( + checkpoint_dir=str(fsdp_dir), + output_path=merged_path, + safe_serialization=True, + ) + trainer.accelerator.wait_for_everyone() + if trainer.accelerator.is_main_process: + # move all files in merged_path to cfg.output_dir + for merged_file in Path(merged_path).iterdir(): + shutil.move(str(merged_file), cfg.output_dir) + shutil.rmtree(merged_path) # remove what should be an empty dir + # TODO(wing): see https://github.com/huggingface/transformers/pull/40207 + # cleanup the FSDP prefix in the model config.json + if trainer.accelerator.is_main_process: + with open( + Path(cfg.output_dir) / "config.json", "r", encoding="utf-8" + ) as config_file_io: + # read the model config as an OrderedDict + config = json.load(config_file_io, object_pairs_hook=OrderedDict) + config["architectures"] = [ + name.removeprefix("FSDP") for name in config["architectures"] + ] + # write the updated model config back + with open( + os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8" + ) as config_file_io: + json.dump(config, config_file_io, indent=2) elif cfg.deepspeed and is_deepspeed_zero3_enabled(): # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading trainer.accelerator.wait_for_everyone() @@ -564,7 +579,7 @@ def train( setup_model_card(cfg) # Execute the training - resume_from_checkpoint = determine_resume_checkpoint(cfg) + resume_from_checkpoint = determine_last_checkpoint(cfg) execute_training(cfg, trainer, resume_from_checkpoint) # clear cache diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index 0075d4830..542918527 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -5,7 +5,6 @@ Collators for multi-modal chat messages and packing from dataclasses import dataclass from typing import Any, Optional, Union -import torch from torch import Tensor from transformers import PreTrainedTokenizerBase from transformers.data.data_collator import DataCollatorMixin @@ -42,62 +41,19 @@ class MultiModalChatDataCollator(DataCollatorMixin): examples = self.processing_strategy(examples) # Initialize batch - batch: dict[str, Any] = {} + messages = [ex["messages"] for ex in examples] - # Process each example - for example in examples: - # Apply chat template to process the example - # This method requires transformers>=4.49.0 - result = self.processing_strategy.processor.apply_chat_template( - example["messages"], - add_generation_prompt=False, - tokenize=True, - return_tensors="pt", - padding=True, - return_dict=True, - chat_template=self.processing_strategy.chat_template, - ) - - # TODO: Check if need handling for len(input_ids) > sequence_len - - # Add the processed tensors to our batch - for key in result.keys(): - if key not in batch: - batch[key] = [] - - batch[key].append(result[key].squeeze(0)) - - # Pad sequences to the same length - input_ids = torch.nn.utils.rnn.pad_sequence( - batch["input_ids"], - batch_first=True, - padding_value=self.tokenizer.pad_token_id, + batch = 
self.processing_strategy.processor.apply_chat_template( + messages, + add_generation_prompt=False, + tokenize=True, + return_tensors="pt", + padding=True, + return_dict=True, + chat_template=self.processing_strategy.chat_template, ) - attention_mask = torch.nn.utils.rnn.pad_sequence( - batch["attention_mask"], batch_first=True, padding_value=0 - ) - - # Create the final batch - final_batch = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - - for key, val in batch.items(): - if key in ["input_ids", "attention_mask"]: - continue - - if key in ["token_type_ids", "cross_attention_mask"]: - final_batch[key] = torch.nn.utils.rnn.pad_sequence( - val, batch_first=True, padding_value=0 - ) - else: - final_batch[key] = torch.stack(val) - # Process the labels - final_batch["labels"] = self.processing_strategy.process_labels( - final_batch["input_ids"] - ) + batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"]) - return final_batch + return batch diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 975f26e71..2ae7d9052 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -28,7 +28,7 @@ from axolotl.utils.data.shared import ( ) from axolotl.utils.data.utils import ( deduplicate_and_log_datasets, - drop_long_seq_in_dataset, + handle_long_seq_in_dataset, retry_on_request_exceptions, ) from axolotl.utils.data.wrappers import get_dataset_wrapper @@ -339,9 +339,9 @@ def _load_raw_datasets( if not cfg.skip_prepare_dataset: if split == "test" and cfg.eval_sequence_len: - dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg) + dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg) else: - dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg) + dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg) if cfg.sample_packing: dataset, _ = process_datasets_for_packing(cfg, dataset, None) diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index c0efb7a42..856a609c7 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -148,7 +148,36 @@ def deduplicate_and_log_datasets( return dataset, other_dataset -def drop_long_seq_in_dataset( +def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2): + """ + Truncate samples whose sequence length is too long (> sequence_len) + or drop those too short (< min_sequence_len). + """ + min_sequence_len = min_sequence_len or 2 + + input_ids = sample["input_ids"] + results = [] + + # Batched (input_ids is a list of lists) + for i, seq in enumerate(input_ids): + length = len(seq) + if length < min_sequence_len: + results.append(False) + elif length > sequence_len: + sample["input_ids"][i] = seq[:sequence_len] + if "attention_mask" in sample: + sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len] + if "labels" in sample: + sample["labels"][i] = sample["labels"][i][:sequence_len] + if "position_ids" in sample: + sample["position_ids"][i] = sample["position_ids"][i][:sequence_len] + results.append(True) + else: + results.append(True) + return results + + +def handle_long_seq_in_dataset( dataset: Dataset, sequence_len: int, cfg: DictDefault ) -> Dataset: """Remove sequences longer than configured maximum from dataset. 
@@ -192,8 +221,21 @@ def drop_long_seq_in_dataset( if filter_map_kwargs: drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})" + excess_length_strategy = (cfg.excess_length_strategy or "drop").lower() + if excess_length_strategy == "truncate": + process_fn = functools.partial( + truncate_long_seq, + sequence_len=sequence_len, + min_sequence_len=cfg.min_sample_len, + ) + drop_long_kwargs["desc"] = ( + f"Truncating/Filtering Sequences (target_len={sequence_len})" + ) + else: + process_fn = drop_long + dataset = dataset.filter( - drop_long, + process_fn, batched=True, **filter_map_kwargs, **drop_long_kwargs, @@ -201,6 +243,11 @@ def drop_long_seq_in_dataset( if prior_len: dropped = prior_len - len(dataset) if dropped: - LOG.warning(f"Dropped {dropped} long samples from dataset") + action = ( + "truncated/filtered" + if excess_length_strategy == "truncate" + else "dropped" + ) + LOG.warning(f"{action.title()} {dropped} samples from dataset") return dataset diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 9cd98a4b2..b97944b48 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -420,6 +420,12 @@ class AxolotlInputConfig( "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048" }, ) + excess_length_strategy: Literal["drop", "truncate"] | None = Field( + default=None, + json_schema_extra={ + "description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility." + }, + ) eval_sequence_len: int | None = Field( default=None, json_schema_extra={ diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 0d6d05a0e..217244b01 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -370,10 +370,10 @@ class TrainingValidationMixin: "see speed improvements. Please consider setting `torch_compile: " "true` in your config." ) + fsdp_config = data.get("fsdp_config") or {} if data.get("fp8") and ( - data.get("fsdp_config", {}).get("activation_checkpointing", False) is True - or data.get("fsdp_config", {}).get("fsdp_activation_checkpointing", False) - is True + fsdp_config.get("activation_checkpointing", False) is True + or fsdp_config.get("fsdp_activation_checkpointing", False) is True ): LOG.warning( "FP8 + FSDP2 + activation checkpointing may be slower than BF16 " @@ -818,13 +818,13 @@ class OptimizationValidationMixin: @model_validator(mode="before") @classmethod def check_fsdp_version_in_fsdp_config(cls, data): - if data.get("fsdp_config"): - if data.get("fsdp_config", {}).get("fsdp_version"): - LOG.warning( - "Configuring `fsdp_version` in `fsdp_config` is deprecated. " - "Please configure `fsdp_version` as a top-level field." - ) - data["fsdp_version"] = data.get("fsdp_config").pop("fsdp_version") + fsdp_config = data.get("fsdp_config") or {} + if fsdp_config and fsdp_config.get("fsdp_version"): + LOG.warning( + "Configuring `fsdp_version` in `fsdp_config` is deprecated. " + "Please configure `fsdp_version` as a top-level field." 
+ ) + data["fsdp_version"] = fsdp_config.pop("fsdp_version") return data @model_validator(mode="before") @@ -1152,10 +1152,8 @@ class ModelCompatibilityValidationMixin: @classmethod def check_gpt_oss_fsdp_loading(cls, data): if data.get("model_quantization_config", "") == "Mxfp4Config": - if ( - data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False) - is True - ): + fsdp_config = data.get("fsdp_config") or {} + if fsdp_config.get("cpu_ram_efficient_loading", False) is True: raise ValueError( "FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization." ) diff --git a/src/axolotl/utils/train.py b/src/axolotl/utils/train.py new file mode 100644 index 000000000..1393459d9 --- /dev/null +++ b/src/axolotl/utils/train.py @@ -0,0 +1,45 @@ +"""Training utils for checkpoints""" + +from pathlib import Path + +from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | None: + """ + Determine the checkpoint to resume from based on configuration. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + update: Whether to update the config with the determined checkpoint + + Returns: + Path to the checkpoint to resume from, or `None` if not resuming. + """ + last_checkpoint = None + checkpoints = sorted( + ( + p + for p in Path(cfg.output_dir).glob("checkpoint-*") + if p.name.split("-")[-1].isdigit() + ), + key=lambda p: int(p.name.split("-")[-1]), + ) + if checkpoints: + last_checkpoint = str(checkpoints[-1]) + if not update: + return last_checkpoint + + if ( + cfg.resume_from_checkpoint is None + and cfg.auto_resume_from_checkpoints + and last_checkpoint is not None + ): + cfg.resume_from_checkpoint = last_checkpoint + LOG.info( + f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}" + ) + return cfg.resume_from_checkpoint diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 5931fe148..939ed5c1c 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -147,7 +147,11 @@ def require_hopper(test_case): def check_tensorboard( - temp_run_dir: str, tag: str, lt_val: float, assertion_err: str + temp_run_dir: str, + tag: str, + lt_val: float, + assertion_err: str, + rtol: float = 0.02, ) -> None: """ helper function to parse and check tensorboard logs @@ -157,6 +161,7 @@ def check_tensorboard( reader = SummaryReader(event_file) df = reader.scalars # pylint: disable=invalid-name df = df[(df.tag == tag)] # pylint: disable=invalid-name + lt_val = (1 + rtol) * lt_val if "%s" in assertion_err: assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1] else: diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py index 7cb645db7..47894a35b 100644 --- a/tests/test_packed_batch_sampler.py +++ b/tests/test_packed_batch_sampler.py @@ -8,7 +8,7 @@ from transformers import AutoTokenizer from axolotl.datasets import TokenizedPromptDataset from axolotl.prompt_strategies.completion import load from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq -from axolotl.utils.data.utils import drop_long_seq_in_dataset +from axolotl.utils.data.utils import handle_long_seq_in_dataset from axolotl.utils.dict import DictDefault from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths @@ -70,7 +70,7 @@ class TestBatchedSamplerPacking: ) train_dataset = concatenate_datasets([dataset_wrapper]) - train_dataset = 
drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg) + train_dataset = handle_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg) lengths = get_dataset_lengths(train_dataset) batch_sampler = MultipackBatchSampler( diff --git a/tests/utils/test_train.py b/tests/utils/test_train.py new file mode 100644 index 000000000..a1f6f6088 --- /dev/null +++ b/tests/utils/test_train.py @@ -0,0 +1,24 @@ +"""test for train checkpoint utils""" + +import os + +from axolotl.utils.dict import DictDefault +from axolotl.utils.train import determine_last_checkpoint + + +def test_determine_last_checkpoint(temp_dir): + cfg = DictDefault( + output_dir=temp_dir, + ) + for cpt_idx in [1, 9, 10, 20]: + os.makedirs( + os.path.join(cfg.output_dir, f"checkpoint-{cpt_idx}"), exist_ok=True + ) + + last_checkpoint = determine_last_checkpoint(cfg, update=False) + assert last_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20") + + cfg.resume_from_checkpoint = None + cfg.auto_resume_from_checkpoints = True + determine_last_checkpoint(cfg, update=True) + assert cfg.resume_from_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20")