Commit Graph

5 Commits

Author SHA1 Message Date
Wing Lian
00568c1539 support for true batches with multipack (#1230)
* support for true batches with multipack

* patch the map dataset fetcher to handle batches with packed indexes

* patch 4d mask creation for sdp attention

* better handling for BetterTransformer

* patch general case for 4d mask

* setup forward patch. WIP

* fix patch file

* support for multipack w/o flash attention for llama

* cleanup

* add warning about bf16 vs fp16 for multipack with sdpa

* bugfixes

* add 4d multipack tests, refactor patches

* update tests and add warnings

* fix e2e file check

* skip sdpa test if not at least torch 2.1.1, update docs
2024-02-01 10:18:42 -05:00
Wing Lian
797f3dd1de don't train if eval split is too small (#873)
* allow zero len dataset

* better handling and warning of small eval splits

* raise error if eval split is too small

* don't mess with calculating total num steps in distributed context

* fix eval_sample_packing training args logic
2023-11-16 11:35:42 -05:00
Wing Lian
0c2a630326 multipack len should use max, not min (#863) 2023-11-15 12:52:32 -05:00
Wing Lian
14706504e3 various bugfixes (#856)
* various bugfixes

use latest tinyllama release
check if val_set_size is empty first
update sdp and xformers llama patches for updated upstream transformers
fix system prompt when no input
calculate total and total supervised tokens even when not sample packing

* add fix for when eval size is estimated to be too small

* should be len 1 for dataset length

* add catchall kwargs
2023-11-15 12:23:18 -05:00
Wing Lian
641e6f7e51 multipack w batch sampler (#795)
* test batch sampler w varying batch lens

* wip

* multipack batchsampler wip

* wip

* fix for prepare data loader to get correct # of steps based on gpues

* lint and clean up

* calculate len estimate

* fix total num steps calc

* add options for dataloader_num_workers and dataloader_pin_memory

* remove gitbook

* support prefetch_factor for dataloader optimization

* fix the kwarg
2023-11-07 20:27:40 -05:00