Commit Graph

27 Commits

Author SHA1 Message Date
Wing Lian
2ea70ebbd8 ORPO (#1419)
* orpo trainer

* rl handling for orpo

* support for remove_unused_columns

* orpo fixes

* fix loader for orpo

* chore: lint

* fix default for remove_unused_columns

* roll ORPO into the main AxolotlTrainer so it can be compatible with some of the other techniques like relora

* better handling of system message for orpo

* revert system prompt changes for chat templtes

* no need for else condition

* split dataset parsing into it's own component
2024-03-18 13:10:00 -04:00
NanoCode012
d485a08393 chore(script): remove redundant setting (#1411) 2024-03-16 21:10:38 +09:00
Seungduk Kim
05bcc9ea56 Train parameters exclusively in specific ranges (#1390)
* Train parameters exclusively in specific ranges

* Fix the style and update docs

* Update yaml example
2024-03-14 11:05:42 -04:00
Wing Lian
ea00dd0852 don't use load and push together (#1284) 2024-02-09 14:54:31 -05:00
Wing Lian
00568c1539 support for true batches with multipack (#1230)
* support for true batches with multipack

* patch the map dataset fetcher to handle batches with packed indexes

* patch 4d mask creation for sdp attention

* better handling for BetterTransformer

* patch general case for 4d mask

* setup forward patch. WIP

* fix patch file

* support for multipack w/o flash attention for llama

* cleanup

* add warning about bf16 vs fp16 for multipack with sdpa

* bugfixes

* add 4d multipack tests, refactor patches

* update tests and add warnings

* fix e2e file check

* skip sdpa test if not at least torch 2.1.1, update docs
2024-02-01 10:18:42 -05:00
Wing Lian
c67fb71583 Peft deepspeed resume (#1227)
* import deepspeed integration

* monkeypatch peft adapater with deepspeed for resume from checkpoint

* fix patch

* fix patches attempt 2

* make sure to set lora_model_dir

* skip pylint for deepspeed.utils

* pick up upstream fix in transformers

* remove monkeypatch for deepspeed/peft fix

* no need to set the lora_model_dir on resume

* unset load_in_*bit when using quant config

* guard before del

* better handling of load_in* kwargs
2024-01-31 18:13:29 -05:00
Wing Lian
ba944e6554 workaround for transformers bug requireing do_sample for saveing pretrained (#1206) 2024-01-25 11:34:41 -05:00
Wing Lian
54d2ac155b Mixtral fixes 20240124 (#1192) [skip ci]
* mixtral nccl fixes

* make sure to patch for z3
2024-01-24 14:59:57 -05:00
Wing Lian
7523d1f557 DPO cleanup (#1126)
* cleanup dpo to be a little more extensible, add zephyr/nectar strategy

* fix eos slash

* support for eval split

* fix kwargs

* handle empty evals

* don't load peft model for dpo

* ensure dpo traning args gets bf16 for peft if applicable

* fix duplicate kwargs for bf16

* make sure to respect the configured lr scheduler

* supprt trainer callback to push config to wandb

* set dataloader preload args

* ensure that we are loading the lora when merging

* Update src/axolotl/utils/data.py

Co-authored-by: Agus <agustin.piqueres@gmail.com>

* support local datasets for dpo

Co-authored-by: Agus <agustin.piqueres@gmail.com>

* chore: lint

* dpo/kto/ipo smoke tests w lora, simplify dpo dataset type names

* add split to dpo tests

* fix rebase/merging error

* handle edge case w logging

* use accelerator for dpo datasets so it doesn't break the logger

* missing args

* validate checkpoint is an adapter for now

* log warning when dataset strategy is not loadable

---------

Co-authored-by: Agus <agustin.piqueres@gmail.com>
2024-01-23 00:40:37 -05:00
Wing Lian
da97285e63 keep gate in fp32 for 16 bit loras (#1105)
* keep gate in fp32 for loras

* add e2e check for lora w/o flash attention for mixtral to check gate

* add checks for gate in fp32 for mixtral, add typehints to train outputs

* mixtral doesn't support basic lora 🤦

add lora tests @ 16bit and fix gate layer check
fix the parameter name, was using the old disco name
don't lora over the gate so we can check that is in fp32
fix dtype check

* ensure we're using fp16/bf16 for 16bit and qlora is always going to be in uint8
2024-01-12 14:58:21 -05:00
NanoCode012
b432889256 feat: enable trl's autounwrap (#1060)
* feat: test trl's autounwrap

* fix: add check for adapter

* feat: add config to disable autounwrap

* chore: fix lint
2024-01-11 08:43:41 -05:00
Hamel Husain
31d23504a5 fix model card upload for PEFT models (#1043) 2024-01-04 18:13:54 -08:00
Wing Lian
f243c2186d RL/DPO (#935)
* ipo-dpo trainer

* fix missing abstract method

* chatml template, grad checkpointing kwargs support

* fix steps calc for RL and add dataloader kwargs

* wip to fix dpo and start ppo

* more fixes

* refactor to generalize map fn

* fix dataset loop and handle argilla pref dataset

* set training args

* load reference model on seperate gpu if more than one device

* no auto upload to hub for dpo, don't add lora adapters to ref model for dpo

* fixes for rl training

* support for ipo from yaml

* set dpo training args from the config, add tests

* chore: lint

* set sequence_len for model in test

* add RLHF docs
2024-01-04 18:22:55 -05:00
Hamel Husain
85dd4d525b add config to model card (#1005)
* add config to model card

* rm space

* apply black formatting

* apply black formatting

* fix formatting

* check for cfg attribute

* add version

* add version

* put the config in a collapsible element

* put the config in a collapsible element
2023-12-27 21:25:33 -06:00
kallewoof
ef24342538 fix: switch to using the HuggingFace Transformers NEFT implementation (#941)
* fix: switch to using the HuggingFace Transformers NEFT implementation

* linter

* add support for noisy_embedding_alpha with a warning about it being renamed

* restore pre/posttrain_hooks

* move validation of NEFT noise alpha into validate_config()

* linter
2023-12-13 17:15:34 -05:00
Wing Lian
5ea3aa31f0 Fix Deepspeed loading (#950)
* add check for zero3

* freeze parameters

* fixes for deepspeed loading

* fix model parameter check

* unfrozen parameters in example mixtral and logging when unfreezing
2023-12-13 16:03:23 -05:00
Wing Lian
40a6362c92 support for mamba (#915)
* support for mamba

* more mamba fixes

* use fork for mamba kwargs fix

* grad checkpointing doesn't work

* fix extras for mamaba

* mamba loss fix

* use fp32 and remove verbose logging

* mamba fixes

* fix collator for mamba

* set model_type on training_args

* don't save safetensors for mamba

* update mamba config to disable safetensor checkpooints, install for tests

* no evals for mamba tests

* handle save_pretrained

* handle unused safetensors arg
2023-12-09 12:10:41 -05:00
Wing Lian
b2430ce670 use accelerate logging for zero/main loggin only 2023-11-06 18:32:26 -05:00
Wing Lian
4c834bf25d cleanup verbosity a bit 2023-11-06 18:32:26 -05:00
Wing Lian
827ec3d274 refactor neft patch to be more re-usable similar to trl's impl (#796) 2023-10-29 04:33:13 -04:00
Casper
15d3a654bf Implement fused modules (#747)
* MLP: Memory saving

* Remove RMSNorm restrictions

* Map packed weights to original

* FusedAttention module

* Simplify code

* Move fused modules

* Fix critical typo

* Split inplace

* Add FFT config

* Add validation of fused arguments

* Add fused arguments to config

* Update docs

* Fix validation logic

* Add fused modules to flash attn

* Only fuse during training

* Remove timing

* Formatting

* Formatting

* Formatting

* chore: lint

* chore: lint

* add e2e tests for fused llama

* no lora for tests

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-10-21 16:08:25 -04:00
Motoki Wu
e4d1585c4e Fix DeepSpeed Zero 3 Saving (#709)
* Update train.py

* add zero3 check

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-10-19 19:18:24 -04:00
Wing Lian
501958bb6f create a model card with axolotl badge (#624) 2023-09-22 16:13:26 -04:00
Jan Philipp Harries
be75668400 set fsdp state dict (#584)
Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
2023-09-15 17:47:36 -04:00
Wing Lian
a4e1bb6606 let hf trainer handle torch compile (#516)
* let hf trainer handle torch compile

* remove torch compile checks, include option for backend

* suppress torch errors to get further

* require min torch version of 2.1.0 for torch compile to work

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
2023-09-13 11:42:12 -04:00
Wing Lian
a546ca2813 misc fixes/improvements (#513)
fix per pr feedback
2023-09-05 16:40:13 -04:00
Wing Lian
b21e4a20fe split train from other cli options (#503) 2023-08-30 22:01:47 -07:00