Commit Graph

8 Commits

Author SHA1 Message Date
Dan Saunders
b5f1e53a0f models.py -> loaders/ module refactor (#2680)
* models.py -> loaders/ module refactor

* refactor ModelLoader class

* plugin manager changes

* circular import fix

* pytest

* pytest

* minor improvements

* fix

* minor changes

* fix test

* remove dead code

* coderabbit comments

* lint

* fix

* coderabbit suggestion I liked

* more coderabbit

* review comments, yak shaving

* lint

* updating in light of SP ctx manager changes

* review comment

* review comment 2
2025-05-23 15:51:11 -04:00
NanoCode012
798b5f5cfd fix(RL): address plugin rl overwriting trainer_cls (#2697) [skip ci]
* fix: plugin rl overwrite trainer_cls

* feat(test): add test to catch trainer_cls is not None
2025-05-22 19:19:12 +07:00
salman
ac471a697a updating to fused (#2293) 2025-01-30 11:45:56 -05:00
Wing Lian
ce5bcff750 various tests fixes for flakey tests (#2110)
* add mhenrichsen/alpaca_2k_test with revision dataset download fixture for flaky tests

* log slowest tests

* pin pynvml==11.5.3

* fix load local hub path

* optimize for speed w smaller models and val_set_size

* replace pynvml

* make the resume from checkpoint e2e faster

* make tests smaller
2024-12-02 17:28:58 -05:00
Wing Lian
7d1d22f72f ORPO Trainer replacement (#1551)
* WIP use trl ORPOTrainer

* fixes to make orpo work with trl

* fix the chat template laoding

* make sure to handle the special tokens and add_generation for assistant turn too
2024-04-19 17:25:36 -04:00
NanoCode012
ff939d8a64 fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path (#1298)
* fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path

* fix: normalize config
2024-03-25 15:34:54 +09:00
Wing Lian
78c5b1979e add gptneox embeddings, fix phi2 inputs, also fix the casting (#1083) 2024-01-10 22:32:43 -05:00
Wing Lian
f243c2186d RL/DPO (#935)
* ipo-dpo trainer

* fix missing abstract method

* chatml template, grad checkpointing kwargs support

* fix steps calc for RL and add dataloader kwargs

* wip to fix dpo and start ppo

* more fixes

* refactor to generalize map fn

* fix dataset loop and handle argilla pref dataset

* set training args

* load reference model on seperate gpu if more than one device

* no auto upload to hub for dpo, don't add lora adapters to ref model for dpo

* fixes for rl training

* support for ipo from yaml

* set dpo training args from the config, add tests

* chore: lint

* set sequence_len for model in test

* add RLHF docs
2024-01-04 18:22:55 -05:00