feat: add cut_cross_entropy (#2091)
* feat: add cut_cross_entropy * fix: add to input * fix: remove from setup.py * feat: refactor into an integration * chore: ignore lint * feat: add test for cce * fix: set max_steps for liger test * chore: Update base model following suggestion Co-authored-by: Wing Lian <wing.lian@gmail.com> * chore: update special_tokens following suggestion Co-authored-by: Wing Lian <wing.lian@gmail.com> * chore: remove with_temp_dir following comments * fix: plugins aren't loaded * chore: update quotes in error message * chore: lint * chore: lint * feat: enable FA on test * chore: refactor get_pytorch_version * fix: lock cce commit version * fix: remove subclassing UT * fix: downcast even if not using FA and config check * feat: add test to check different attentions * feat: add install to CI * chore: refactor to use parametrize for attention * fix: pytest not detecting test * feat: handle torch lower than 2.4 * fix args/kwargs to match docs * use release version cut-cross-entropy==24.11.4 * fix quotes * fix: use named params for clarity for modal builder * fix: handle install from pip * fix: test check only top level module install * fix: re-add import check * uninstall existing version if no transformers submodule in cce * more dataset fixtures into the cache --------- Co-authored-by: Wing Lian <wing.lian@gmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -27,7 +27,6 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.chat_templates import (
|
||||
@@ -38,6 +37,7 @@ from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.config import (
|
||||
normalize_cfg_datasets,
|
||||
normalize_config,
|
||||
prepare_plugins,
|
||||
validate_config,
|
||||
)
|
||||
from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
|
||||
@@ -426,11 +426,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
|
||||
cfg.axolotl_config_path = config
|
||||
|
||||
if cfg.get("plugins"):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
for plugin_name in cfg["plugins"]:
|
||||
plugin_manager.register(plugin_name)
|
||||
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
||||
@@ -449,6 +444,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
},
|
||||
)
|
||||
|
||||
prepare_plugins(cfg)
|
||||
|
||||
prepare_optim_env(cfg)
|
||||
|
||||
prepare_opinionated_env(cfg)
|
||||
|
||||
Reference in New Issue
Block a user