Commit Graph

753 Commits

Author SHA1 Message Date
Wing Lian
8619b2d855 add torch_compile_mode options (#1763) [skip ci]
* add torch_compile_mode options

* make sure n_gpu is an int
2024-07-17 15:38:07 -04:00
Wing Lian
976f85195a fixes to accelerator so that iterable pretraining datasets work (#1759)
* fixes to accelerator so that iterable pretraining datasets work

* fix the pretraining test params

* split_batches, not dispatch_batches, needs to be set

* update c4 datasets

* set epochs in pretrain config test

* need to set both split_batches and dispatch_batches to false for pretraining

* fix bool val in comment
2024-07-17 10:58:38 -04:00
Wing Lian
152ab76623 fix num gpu check (#1760) 2024-07-17 10:58:14 -04:00
Wing Lian
5f58555bd0 support for llama multipack using updated code/patches (#1754)
* support for llama multipack using updated code/patches

* also support unsloth patches

* incorrect arg

* add config validation for unsloth

* add missing return to validation

* add another missing return to validation
2024-07-16 17:36:29 -04:00
Wing Lian
cfc533a7f7 torch compile and cuda alloc improvements (#1755)
* enable experimental expandable_segments

* hf trainer seems to be missing torch compile

* disable PYTORCH_CUDA_ALLOC_CONF to see if that fixes cicd
2024-07-16 16:00:23 -04:00
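The experimental expandable-segments mode referenced in the first bullet is controlled by PyTorch's documented `PYTORCH_CUDA_ALLOC_CONF` environment variable; a minimal sketch of enabling it (it must be set before CUDA is first initialized):

```python
import os

# Opt in to PyTorch's expandable-segments CUDA allocator mode, which can
# reduce fragmentation-driven OOMs on long training runs. This must run
# before any CUDA context is created (i.e. before the first .cuda() call).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
```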
Wing Lian
78e12f8ca5 add basic support for the optimi adamw optimizer (#1727)
* add support for optimi_adamw optimizer with Kahan summation

* pydantic validator for optimi_adamw

* workaround for setting optimizer for fsdp

* make sure to install optimizer packages

* make sure to have parity for model parameters passed to optimizer

* add smoke test for optimi_adamw optimizer

* don't use foreach optimi by default
2024-07-14 19:12:57 -04:00
Wing Lian
98af5388ba bump flash attention 2.5.8 -> 2.6.1 (#1738)
* bump flash attention 2.5.8 -> 2.6.1

* use triton implementation of cross entropy from flash attn

* add smoke test for flash attn cross entropy patch

* fix args to xentropy.apply

* handle tuple from triton loss fn

* ensure the patch tests run independently

* use the wrapper already built into flash attn for cross entropy

* mark pytest as forked for patches

* use pytest xdist instead of forked, since cuda doesn't like forking

* limit to 1 process and use dist loadfile for pytest

* change up pytest fixture to reload transformers with monkeypatch
2024-07-14 19:11:31 -04:00
Wing Lian
a4a5bf057f fixes to prevent vram spike when train starts (#1742) 2024-07-13 09:53:13 -04:00
Wing Lian
47e1916484 add tests so CI can catch updates where patches will break with unsloth (#1737) [skip ci] 2024-07-11 16:43:19 -04:00
mhenrichsen
1194c2e0b1 github urls (#1734)
Co-authored-by: Henrichsen, Mads (ext) <mads.henrichsen.ext@siemens-energy.com>
2024-07-11 09:19:29 -04:00
Wing Lian
a159724e44 bump trl and accelerate for latest releases (#1730)
* bump trl and accelerate for latest releases

* ensure that the CI runs on new gh org

* drop kto_pair support since removed upstream
2024-07-10 11:15:44 -04:00
Josh Bleecher Snyder
b3f680d305 sanity check ranges in freeze.py (#1686)
* sanity check ranges in freeze.py

this will catch problems earlier and more clearly.

in my case, it appears that deepspeed zero3 sets layer tensor shapes
to [0], which doesn't play well with automatically inferred ranges.
through a bit of luck, inverting ranges still appears to work correctly.

* simplify chained comparison
2024-07-05 09:24:07 -04:00
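A range sanity check of the kind this commit describes might look like the sketch below; `validate_ranges` and its signature are hypothetical illustrations, not the actual freeze.py code:

```python
def validate_ranges(ranges, num_params):
    """Fail early and clearly on bad (start, end) index ranges.

    Hypothetical sketch: reject empty, inverted, or out-of-bounds ranges
    up front, instead of letting them surface later as confusing errors
    (e.g. when deepspeed zero3 reports a tensor shape of [0]).
    """
    for start, end in ranges:
        if not 0 <= start < end <= num_params:
            raise ValueError(
                f"invalid range ({start}, {end}) for tensor "
                f"with {num_params} elements"
            )
```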
Wing Lian
c69b7eb2b5 full weights fsdp training seems broken with fsdp_cpu_ram_efficient_loading, disabling for now (#1726) 2024-07-05 09:15:36 -04:00
Wing Lian
c6d83a87c4 add support for .env files for env vars (#1724) 2024-07-02 13:17:40 -04:00
Wing Lian
5370cedf0c support for gemma2 w sample packing (#1718) 2024-06-29 01:38:55 -04:00
DavidFarago
559562d790 Allow "weight: 0" in messages to mask them (#1703)
Allow the additional key `weight` in message objects, which can be set
to 0 (or 1) to mask that message out of (or leave it unmasked for)
training (similar to [1]). This is helpful for training the model to be robust and
capable of error recovery after a bad assistant message.
A missing `weight` key defaults to weight 1, to guarantee backward compatibility.

[1]: https://github.com/mistralai/mistral-finetune
2024-06-20 10:05:16 -04:00
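The masking behavior described above can be sketched as follows; `mask_labels`, the `tokenize` callable, and the `-100` ignore index (the conventional value skipped by the cross-entropy loss) are illustrative assumptions, not the actual implementation:

```python
IGNORE_INDEX = -100  # label value excluded from the loss

def mask_labels(messages, tokenize):
    """Build (input_ids, labels), masking messages whose ``weight`` is 0.

    Hypothetical sketch: ``tokenize`` is any callable mapping text to a
    list of token ids. A missing ``weight`` key defaults to 1 (unmasked),
    preserving backward compatibility.
    """
    input_ids, labels = [], []
    for msg in messages:
        ids = tokenize(msg["content"])
        input_ids.extend(ids)
        if msg.get("weight", 1) == 0:
            labels.extend([IGNORE_INDEX] * len(ids))  # masked out of the loss
        else:
            labels.extend(ids)
    return input_ids, labels
```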
Wing Lian
4de4b4089f add support for multipack for deepseek_v2 (#1712) 2024-06-20 10:02:55 -04:00
Wing Lian
3f1f5e3312 drop length column for issues with eval without packing (#1711) 2024-06-18 23:32:29 -04:00
Wing Lian
5783839c6e download model weights on preprocess step (#1693) 2024-06-09 20:10:17 -04:00
Wing Lian
cbbf039a46 verbose failure message (#1694) 2024-06-09 20:09:36 -04:00
Wing Lian
18cabc0c46 fix for when sample_packing and eval_sample_packing are different (#1695) 2024-06-08 09:48:30 -04:00
Wing Lian
ed8ef65371 add back packing efficiency estimate so epochs and multi-gpu works properly (#1697) 2024-06-08 09:48:10 -04:00
Wing Lian
9c1af1a9c0 ensure explicit eval_sample_packing to avoid mismatch issues (#1692) 2024-06-07 11:28:43 -04:00
Brian Fitzgerald
cf64284a04 Phi-3 conversation format, example training script and perplexity metric (#1582)
* phi-3 support and perplexity metric

* phi-3 chat template

* metrics updates

* chore: lint

* fix assertion on Tensor

* fix tests since tokenization happens in the metric

* fix perplexity value of shorter passage

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-06-04 16:11:56 -04:00
Wing Lian
c996881ec2 add support for rpo_alpha (#1681)
* add support for rpo_alpha

* Add smoke test for dpo + nll loss
2024-06-04 16:09:51 -04:00
Wing Lian
1f151c0d52 re-enable DPO for tests in modal ci (#1374)
* re-enable DPO for tests in modal ci

* workaround for training args

* don't mixin AxolotlTrainingArguments

* fix mixin order so the MRO doesn't result in `TypeError: non-default argument follows default argument`

* use smaller datasets for dpo tests
2024-06-03 12:50:44 -04:00
Wing Lian
05b0bd08d2 need to add back drop_last for sampler (#1676) 2024-05-31 13:13:13 -04:00
Wing Lian
d4f6c65e4c cleanup the deepspeed proxy model at the end of training (#1675) 2024-05-30 13:40:35 -04:00
Wing Lian
a944f7b32b load explicit splits on datasets (#1652) 2024-05-29 22:27:59 -04:00
Wing Lian
9d4225a058 set chat_template in datasets config automatically (#1664)
* set chat_template in datasets config automatically

* dynamic chat_template, not just chatml
2024-05-29 22:27:26 -04:00
Wing Lian
f7332ac449 use mixins for orpo and kto configs so they work with axolotl customizations (#1674) 2024-05-29 22:27:00 -04:00
Wing Lian
a6b37bdeb4 revert multipack batch sampler changes (#1672)
* revert multipack batch sampler changes

* fix default val for drop_last
2024-05-29 11:51:18 -04:00
Wing Lian
b7520801a3 handle the system role too for chat templates (#1671) 2024-05-29 10:21:11 -04:00
Wing Lian
fe650dd326 make sure the CI fails when pytest script fails (#1669)
* make sure the pytest script fails

* make sure the defaults come through for tests

* make sure tensorboard is loaded for test assertion
2024-05-29 10:12:11 -04:00
Seungduk Kim
65db903714 Correct name of MixtralBlockSparseTop2MLP (L -> l) (#1667) 2024-05-28 18:10:29 -04:00
Davide Caroselli
6a5a725f10 Fix: ensure correct handling of val_set_size as float or int (#1655)
* Fix: ensure correct handling of val_set_size as float or int

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-28 12:00:32 -04:00
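The float-vs-int handling this fix addresses can be sketched as below; the helper name and exact threshold are assumptions for illustration, not Axolotl's actual API:

```python
def resolve_val_set_size(val_set_size, num_samples):
    """Return the number of validation samples.

    Hypothetical sketch: a float below 1.0 is treated as a fraction of
    the dataset, while an int (or float >= 1.0) is an absolute count.
    """
    if isinstance(val_set_size, float) and val_set_size < 1.0:
        return int(num_samples * val_set_size)
    return int(val_set_size)
```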
Keith Stevens
cc11c6bce2 Generalizing the chat_template prompt strategy (#1660) [skip ci]
The strategy now supports configuring several fields:

* the data field holding message arrays
* the role and content fields for each message
* role mapping from source to target types

Additionally, this adds a sample llama3-8b instruct example using the chat template
2024-05-28 11:24:13 -04:00
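The configurable fields described in this commit can be sketched as a small normalization helper; the function and parameter names here are assumptions for illustration, not the exact config keys:

```python
def normalize_sample(sample, field_messages="conversations",
                     message_field_role="role",
                     message_field_content="content",
                     roles=None):
    """Map a raw dataset sample into [{"role", "content"}, ...] messages.

    Hypothetical sketch of the generalized strategy: the data field
    holding the message array, the per-message role/content fields, and
    a source-to-target role mapping are all configurable.
    """
    roles = roles or {}
    return [
        {
            "role": roles.get(m[message_field_role], m[message_field_role]),
            "content": m[message_field_content],
        }
        for m in sample[field_messages]
    ]
```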
Wing Lian
367b2e879b Switch to parallel FFD bin packing algorithm. (#1619)
* Switch to parallel FFD bin packing algorithm.

Add support for packing in a distributed context.
Add packing efficiency estimate back.

* revert changes to distributed code

* chore: lint

* fix config w new params for packing test

* add sample_packing_group_size and sample_packing_bin_size to cfg schema

* fix lambda function

* fix sampler/dataloader calculations for packing

---------

Co-authored-by: dsesclei <dave@sescleifer.com>
2024-05-23 17:32:14 -04:00
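The core first-fit-decreasing (FFD) idea behind this commit can be sketched serially; the real change parallelizes packing over groups of samples (hence the `sample_packing_group_size` and `sample_packing_bin_size` settings), so this is a simplified illustration:

```python
def ffd_pack(lengths, bin_size):
    """First-fit-decreasing packing of sequence lengths into bins.

    Minimal serial sketch: sort sequences longest-first, then place each
    into the first bin with room, opening a new bin when none fits.
    Returns bins as lists of indices into ``lengths``.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins, remaining = [], []
    for i in order:
        for b, cap in enumerate(remaining):
            if lengths[i] <= cap:
                bins[b].append(i)
                remaining[b] -= lengths[i]
                break
        else:
            bins.append([i])
            remaining.append(bin_size - lengths[i])
    return bins
```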
Wing Lian
bbfed318bc support for custom messages field in sharegpt (#1651) 2024-05-23 13:03:22 -04:00
George Grigorev
a27d5e1f4e enable loraplus setting for dpo trainer (#1646) 2024-05-22 08:29:06 -04:00
Wing Lian
6299eb5919 allow report_to for multiple providers (#1647) 2024-05-22 08:27:44 -04:00
Leonard
7c2bf3091f Fix llama3 chat_template (extra <|eot_id|> on last turn) (#1635)
* Fix llama3 chat_template (the {{eos_token}} leads to an extra <|eot_id|> being added in the last turn). Output now matches the official Llama 3 Instruct model

* add tests

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-21 09:08:53 -04:00
Ben Redmond
22ae21a6c2 Add KTO support (#1640)
* add kto support

* test cleanup

* fix outdated comment

* fix llama3 ultra

* chore: lint

* update to use rl_beta instead of dpo_beta

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-20 16:05:16 -04:00
Wing Lian
ba45531802 fixes to save on fractional save_steps (#1643) 2024-05-20 14:24:45 -04:00
Wing Lian
8a1572a831 Unsloth optims for Llama (#1609)
* WIP for unsloth integrations

* import the unsloth code in the right context

* add unsloth mlp, qkv, o lora optimizations

* apply unsloth mlp and qkv kernels
2024-05-20 09:55:06 -04:00
Jeffrey Quesnelle
702a669cad add save_only_model option (#1634) 2024-05-17 00:23:18 -04:00
bofeng huang
81da7d2531 Fix total_num_steps (#1566)
* Fix `total_num_steps`

* Fix total_num_steps

* lint
2024-05-14 20:10:37 -04:00
Ali Mosavian
1e1921b794 FIX: max_length and max_prompt_length were not being sent to ORPOTrainer (#1584)
* FIX: TRL trainer preprocessing step was running in one process

* FIX: max_length and max_prompt_length were not being sent to ORPOTrainer

* FIX: Change ORPO max prompt length to 1/4 of max length, otherwise we get strange behaviour

* FIX: Removed change from a different PR

* FIX: Black fix

* explicitly set max prompt len for orpo config

---------

Co-authored-by: Ali Mosavian <ali.mosavian@kry.se>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-14 08:51:17 -04:00
Wing Lian
1634ac82e0 make sure to save on the last step (#1615) 2024-05-14 08:48:39 -04:00
Wing Lian
02982733ec fix attention mask collation (#1603) 2024-05-14 08:17:30 -04:00