bump trl and accelerate for latest releases (#1730)

* bump trl and accelerate for latest releases

* ensure that the CI runs on new gh org

* drop kto_pair support since removed upstream
This commit is contained in:
Wing Lian
2024-07-10 11:15:44 -04:00
committed by GitHub
parent b3f680d305
commit a159724e44
11 changed files with 15 additions and 21 deletions

View File

@@ -5,7 +5,7 @@ on:
jobs:
build-base:
if: github.repository_owner == 'OpenAccess-AI-Collective'
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: axolotl-gpu-runner
strategy:

View File

@@ -8,7 +8,7 @@ on:
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
@@ -70,7 +70,7 @@ jobs:
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
@@ -128,7 +128,7 @@ jobs:
build-axolotl-cloud-no-tmux:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:

View File

@@ -7,7 +7,7 @@ on:
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
@@ -70,7 +70,7 @@ jobs:
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:

View File

@@ -58,7 +58,7 @@ jobs:
pytest --ignore=tests/e2e/ tests/
docker-e2e-tests:
if: github.repository_owner == 'OpenAccess-AI-Collective'
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 60

View File

@@ -138,7 +138,7 @@ test_datasets:
data_files:
- /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto_pair'
# use RL training: 'dpo', 'ipo', 'kto'
rl:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing

View File

@@ -4,7 +4,7 @@ peft==0.11.1
transformers==4.42.3
tokenizers==0.19.1
bitsandbytes==0.43.1
accelerate==0.30.1
accelerate==0.32.0
deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b
pydantic==2.6.3
addict
@@ -40,6 +40,6 @@ s3fs
gcsfs
# adlfs
trl @ git+https://github.com/huggingface/trl.git@f18253bf2d747f68acc9cd89da95c85ebf59dbb9
trl==0.9.6
zstandard==0.22.0
fastcore

View File

@@ -1670,8 +1670,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
elif self.cfg.rl == "kto_pair":
dpo_trainer_kwargs["loss_type"] = "kto_pair"
if self.eval_dataset:
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
@@ -1680,7 +1678,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
trainer_cls_args = [self.model, self.model_ref]
@@ -1695,7 +1693,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl == "kto":
elif self.cfg.rl in ["kto"]:
trainer_cls = AxolotlKTOTrainer
trainer_cls_args = [self.model]
else:

View File

@@ -165,7 +165,6 @@ class RLType(str, Enum):
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
kto_pair = "kto_pair" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name

View File

@@ -805,11 +805,7 @@ def load_model(
if not reference_model or cfg.lora_model_dir:
# if we're not loading the reference model, then we're loading the model for training
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
if (
cfg.adapter
and cfg.rl in ["dpo", "ipo", "kto_pair", "kto"]
and not cfg.merge_lora
):
if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto"] and not cfg.merge_lora:
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
else:
model, lora_config = load_adapter(model, cfg, cfg.adapter)

View File

@@ -427,7 +427,7 @@ def prepare_optim_env(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "kto"]:
if cfg.rl in ["dpo", "ipo", "orpo", "kto"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]

View File

@@ -115,6 +115,7 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@pytest.mark.skip("kto_pair no longer supported in trl")
@with_temp_dir
def test_kto_pair_lora(self, temp_dir):
# pylint: disable=duplicate-code