diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index d51e2dd99..71edc8d0c 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -24,5 +24,5 @@ if cce_spec: print( UNINSTALL_PREFIX - + 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"' + + 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"' ) diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index 14449a144..b166a3004 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -17,7 +17,7 @@ Run the following command to install `cut_cross_entropy[transformers]` if you do python scripts/cutcrossentropy_install.py | sh # if you are not in dev environment -pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"' +pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c" ``` ## Usage diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 97517bccd..516e9a2ae 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy") _CCE_INSTALL_MESSAGE = ( "Please install cut_cross_entropy with transformers support using " - '`pip install "cut-cross-entropy[transformers]==24.11.4"`' + '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`' ) diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py index f65d65ee4..25e36b5eb 100644 --- a/tests/e2e/integrations/test_cut_cross_entropy.py +++ b/tests/e2e/integrations/test_cut_cross_entropy.py @@ -69,6 +69,51 @@ class TestCutCrossEntropyIntegration: train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) + # pylint: disable=redefined-outer-name + def test_qwen2_w_cce(self, temp_dir): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "plugins": [ + "axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin", + ], + "cut_cross_entropy": True, + "sequence_len": 1024, + "val_set_size": 0.1, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "output_dir": temp_dir, + "lr_scheduler": "cosine", + "save_safetensors": True, + "max_steps": 10, + "bf16": "auto", + } + ) + prepare_plugins(cfg) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + major, minor, _ = get_pytorch_version() + if (major, minor) < (2, 4): + with pytest.raises(ImportError): + train(cfg=cfg, dataset_meta=dataset_meta) + else: + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + @pytest.mark.parametrize( "attention_type", [