Commit Graph

1876 Commits

Author SHA1 Message Date
Wing Lian
35a84f2cb8 more fixes 2025-01-14 22:47:49 -05:00
Wing Lian
510cf45317 improve logprob masking and shift in trainer 2025-01-14 22:47:48 -05:00
Wing Lian
7232cbdeab chore: lint 2025-01-14 22:47:48 -05:00
Wing Lian
e8fceb7091 chore: lint 2025-01-14 22:47:48 -05:00
Wing Lian
a5e0671738 make sure to use tensorboard to capture loss for checks 2025-01-14 22:47:48 -05:00
Wing Lian
b9847553af fix adapter model check 2025-01-14 22:47:48 -05:00
Wing Lian
513ec9e03b make sure to use the correct tokenizer 2025-01-14 22:47:48 -05:00
Wing Lian
530347856d make sure to set tokenizer from l3 70b and save safetensors 2025-01-14 22:47:47 -05:00
Wing Lian
261e4fb619 lower lr 2025-01-14 22:47:47 -05:00
Wing Lian
158071e95f set lora_dropout explicitly 2025-01-14 22:47:47 -05:00
Wing Lian
432f65f5e6 make the kd e2e fit in vram for ci and add lora version 2025-01-14 22:47:47 -05:00
Wing Lian
1d039f5486 rename test files so it gets picked up 2025-01-14 22:47:47 -05:00
Wing Lian
b9a42b396f linting 2025-01-14 22:47:47 -05:00
Wing Lian
ff2fb0fc1b add kd trainer e2e test 2025-01-14 22:47:47 -05:00
Wing Lian
317f290186 reward model doesn't work well with batched 2025-01-14 22:47:46 -05:00
Wing Lian
ab690f3f01 improve check for batched 2025-01-14 22:47:46 -05:00
Wing Lian
47932f21c4 fix reward trainer calls for tokenization 2025-01-14 22:47:46 -05:00
Wing Lian
808328e041 reward can use same batch check 2025-01-14 22:47:46 -05:00
Wing Lian
6784822cfb tweak check for batched prompt data 2025-01-14 22:47:46 -05:00
Wing Lian
684b38291f ensure that batch vs single is done properly 2025-01-14 22:47:46 -05:00
Wing Lian
01896b1bde improve iterable support 2025-01-14 22:47:46 -05:00
Wing Lian
e659c01646 support streaming for processing sft datasts? 2025-01-14 22:47:45 -05:00
Wing Lian
204d6c43b4 make loss torch script compat 2025-01-14 22:47:45 -05:00
Wing Lian
d3c2b7ce9d kd sample packing 2025-01-14 22:47:45 -05:00
Wing Lian
93dfff92f1 be a bit pickier about loading dynamic prompt strategies 2025-01-14 22:47:45 -05:00
Wing Lian
6e409d2d88 more info on preprocess for kd and fix import 2025-01-14 22:47:45 -05:00
Wing Lian
d5bc214300 remove duplicate code 2025-01-14 22:47:45 -05:00
Wing Lian
92c6c1087e add copyrights 2025-01-14 22:47:45 -05:00
Wing Lian
feed96f95e increase logging around loading plugins 2025-01-14 22:47:44 -05:00
Wing Lian
cba6165ae1 make plugin setup concise 2025-01-14 22:47:44 -05:00
Wing Lian
cdfcd69afa remove moved class from import 2025-01-14 22:47:44 -05:00
Wing Lian
885653d52e move more things to kd plugin 2025-01-14 22:47:44 -05:00
Wing Lian
27faacbf5a refactor kd chat template loader 2025-01-14 22:47:44 -05:00
Wing Lian
c51b0337c1 support for custom trainer classes from plugins 2025-01-14 22:47:44 -05:00
Wing Lian
fa055f9f69 handle token/logprob shifting 2025-01-14 22:47:43 -05:00
Wing Lian
f60c623af0 remove references to triton kd for now 2025-01-14 22:47:43 -05:00
Wing Lian
746891eb5c add license block 2025-01-14 22:47:43 -05:00
Wing Lian
f09b5da60b refactor so we can easily add new loss functions 2025-01-14 22:47:43 -05:00
Wing Lian
689e1c10ba chore: lint 2025-01-14 22:47:43 -05:00
Wing Lian
a5c085e003 var naming and add todo 2025-01-14 22:47:43 -05:00
Wing Lian
63146300b7 fix kd loss so it's causal (fixes repeating tokens) 2025-01-14 22:47:43 -05:00
Wing Lian
ca5e397fc5 use kd_alpha in the correct loss method 2025-01-14 22:47:42 -05:00
Wing Lian
3416302b0d hash for temperature too 2025-01-14 22:47:42 -05:00
Wing Lian
7366efc4ca better rescaling for temperatures 2025-01-14 22:47:42 -05:00
Wing Lian
d8d817eaed don't use triton for now 2025-01-14 22:47:42 -05:00
Wing Lian
c0757e8a20 fix kwarg 2025-01-14 22:47:42 -05:00
Wing Lian
e565694914 v3 2025-01-14 22:47:42 -05:00
Wing Lian
081928e55b no torch.tensor 2025-01-14 22:47:42 -05:00
Wing Lian
dc90c93894 no log etc 2025-01-14 22:47:41 -05:00
Wing Lian
18a46c338a no torch.exp inside triton kernel 2025-01-14 22:47:41 -05:00