Commit Graph

1858 Commits

Author SHA1 Message Date
Wing Lian
890d85f267 make the kd e2e fit in vram for ci and add lora version 2025-01-09 18:57:28 -05:00
Wing Lian
7dc137ed5b rename test files so it gets picked up 2025-01-09 18:57:28 -05:00
Wing Lian
a31ec4d9b3 linting 2025-01-09 18:57:28 -05:00
Wing Lian
7e7762f40b add kd trainer e2e test 2025-01-09 18:57:27 -05:00
Wing Lian
1ffca753ca reward model doesn't work well with batched 2025-01-09 18:57:27 -05:00
Wing Lian
01d31587fe improve check for batched 2025-01-09 18:57:27 -05:00
Wing Lian
9b7d3894c0 fix reward trainer calls for tokenization 2025-01-09 18:57:27 -05:00
Wing Lian
1baffa54b1 reward can use same batch check 2025-01-09 18:57:27 -05:00
Wing Lian
2045ff2b7a tweak check for batched prompt data 2025-01-09 18:57:27 -05:00
Wing Lian
93903f4aa5 ensure that batch vs single is done properly 2025-01-09 18:57:27 -05:00
Wing Lian
b5b3452b2b improve iterable support 2025-01-09 18:57:27 -05:00
Wing Lian
6bbe3ac641 support streaming for processing sft datasts? 2025-01-09 18:57:27 -05:00
Wing Lian
9ed455ef8c make loss torch script compat 2025-01-09 18:57:26 -05:00
Wing Lian
66823c113c kd sample packing 2025-01-09 18:57:26 -05:00
Wing Lian
e976de4d8f be a bit pickier about loading dynamic prompt strategies 2025-01-09 18:57:26 -05:00
Wing Lian
8eb82bba40 more info on preprocess for kd and fix import 2025-01-09 18:57:26 -05:00
Wing Lian
9fe36db215 remove duplicate code 2025-01-09 18:57:26 -05:00
Wing Lian
9dcc879e04 add copyrights 2025-01-09 18:57:26 -05:00
Wing Lian
1e577a29a8 increase logging around loading plugins 2025-01-09 18:57:26 -05:00
Wing Lian
4037fdb43a make plugin setup concise 2025-01-09 18:57:26 -05:00
Wing Lian
385c60cd9b remove moved class from import 2025-01-09 18:57:26 -05:00
Wing Lian
06370b386a move more things to kd plugin 2025-01-09 18:57:26 -05:00
Wing Lian
3da6a652fa refactor kd chat template loader 2025-01-09 18:57:25 -05:00
Wing Lian
84547c724d support for custom trainer classes from plugins 2025-01-09 18:57:25 -05:00
Wing Lian
51547c656a handle token/logprob shifting 2025-01-09 18:57:25 -05:00
Wing Lian
7c4ae15942 remove references to triton kd for now 2025-01-09 18:57:25 -05:00
Wing Lian
cdb167e7f7 add license block 2025-01-09 18:57:25 -05:00
Wing Lian
52f1d7aee2 refactor so we can easily add new loss functions 2025-01-09 18:57:25 -05:00
Wing Lian
319c3531e7 chore: lint 2025-01-09 18:57:25 -05:00
Wing Lian
87eb6a3324 var naming and add todo 2025-01-09 18:57:25 -05:00
Wing Lian
f03fa703b7 fix kd loss so it's causal (fixes repeating tokens) 2025-01-09 18:57:25 -05:00
Wing Lian
53ec07d44c use kd_alpha in the correct loss method 2025-01-09 18:57:25 -05:00
Wing Lian
8d77dc385e hash for temperature too 2025-01-09 18:57:24 -05:00
Wing Lian
8b0104fa7c better rescaling for temperatures 2025-01-09 18:57:24 -05:00
Wing Lian
546ad007ec don't use triton for now 2025-01-09 18:57:24 -05:00
Wing Lian
868a49cb96 fix kwarg 2025-01-09 18:57:24 -05:00
Wing Lian
4a12b1b22e v3 2025-01-09 18:57:24 -05:00
Wing Lian
973ed841cd no torch.tensor 2025-01-09 18:57:24 -05:00
Wing Lian
9c0470130b no log etc 2025-01-09 18:57:24 -05:00
Wing Lian
0da2b7c7cc no torch.exp inside triton kernel 2025-01-09 18:57:24 -05:00
Wing Lian
7c813a1d27 v2 trial 2025-01-09 18:57:24 -05:00
Wing Lian
0a08bb4f78 no where support 2025-01-09 18:57:24 -05:00
Wing Lian
8075a92a33 triton wip 2025-01-09 18:57:23 -05:00
Wing Lian
ba6eacd167 chore: lint 2025-01-09 18:57:23 -05:00
Wing Lian
e2fae47114 make sure to multiply against the correct loss 2025-01-09 18:57:23 -05:00
Wing Lian
7d281b71dc cross entropy loss coefficient during KD 2025-01-09 18:57:23 -05:00
Wing Lian
b080c53afc flipped the slice 2025-01-09 18:57:23 -05:00
Wing Lian
1ea225129f make it work 2025-01-09 18:57:23 -05:00
Wing Lian
e2aba41939 handle padding/collation for KD datasets 2025-01-09 18:57:23 -05:00
Wing Lian
21caaaa2e9 make batch smaller 2025-01-09 18:57:23 -05:00