f3c8a25b300ea9732edbd95888ea3d710e5ec37d
8 Commits
| Author | SHA1 | Message | Date | |
|---|---|---|---|---|
|
|
ccc94da8ad |
KD fix w/ online distillation (#2700) [skip ci]
* kd fixes * fix collator setup * fix input args * better handling to drop string fields for kd with raw dataset * kd trainer has kd temp as part of the init * drop top_k before softmax * simplfy and remove zscore * WIP chunked KD loss with autograd wrapper * more fixes and liger-type chunked loss * collator cls for plugins * remove debugging * additional plugin collator kwargs, don't scale up kd loss by t^2 * don't need temp arg to distill method * online kd wip * add close to comment block * suport sampling params/max new tokens * handle when no custom collator is used in plugins * logsumexp trick: * fix check * shift off the first empty token * fix length of padding * use max not min * temp scale kd loss at end * support for dynamic plugin training args mixins and symmetric kl * chore: lint * fix trainer callback base class * Fix decay * accept compressed responses for smaller wire payload * post-rebase lint * more KD updates * increase hyperparams_count for gradients for added normalize_topk * fix to remove attention_mask * rename vars for consistency * fix rebase issues * default to dropping last batch in multipack batch sampler * improve handling of train len * init collator_cls_and_kwargs * explicit drop_last=False when checking for multipack completeness * use separate v2 loader for kd * fix kd tests to use subprocess so it picks up kd training args * default value for kd_beta arg * use updated dataset for ci * longer timeout for e2e |
||
|
|
00cda8cc70 |
Data loader refactor (#2707)
* data loading refactor (wip) * updates * progress * pytest * pytest fix * lint * zero_first -> filelock, more simplifications * small simplification * import change * nit * lint * simplify dedup * couldnt resist * review comments WIP * continued wip * minor changes * fix; remove contrived test * further refactor * set default seed in pydantic config * lint * continued simplication * lint * renaming and nits * filelock tests * fix * fix * lint * remove nullable arg * remove unnecessary code * moving dataset save fn to shared module * remove debug print * matching var naming * fn name change * coderabbit comments * naming nit * fix test |
||
|
|
c67910fa6f |
bump hf deps (#2735) [skip ci]
* bump hf deps * upgrade liger-kernel too * install cce from fork for transformers fix * fix reference to vocab size in gemma3 patch * use padding_idx instead of pad_token_id * remove fixed gemma3 patch * use updated cce fork * fix local mllama cce patches w docstring * add test for multipack with trainer setup and fix trainer for trainer refactor upstream * bump modal version * guard for iterable datasetS * mllama model arch layout changed in latest transformers * fix batch sampler with drop_last * fix: address upstream vlm changes for lora * fix: update references to old lora target path * fix: remove mllama fa2 patch due to upstream fix * fix: lora kernel patch path for multimodal models * fix: removed mllama from quarto * run test for came optim on 2.6.0+ * fix fsdp2 patch and remove deprecated patch * make sure to set sequence_parallel_degree for grpo * Add SP test for GRPO * add sp to grpo config for trainer * use reward_funcs as kwarg to grpo trainer * fix the comprehension for reward funcs * reward funcs already passed in as args * init sp_group right before training * fix check for adding models to SP context * make sure to pass args to super * upgrade deepspeed * use updated trl and add reasoning flags for vllm * patch the worker --------- Co-authored-by: NanoCode012 <nano@axolotl.ai> |
||
|
|
47e0e71bc8 |
don't sort multipack sampler (#2657)
* don't sort multipack sampler * increased packing efficiency increases loss --------- Co-authored-by: Wing Lian <wing@axolotl.ai> |
||
|
|
cf0c79d52e |
fix: minor patches for multimodal (#2441)
* fix: update chat_template * fix: handle gemma3 showing a lot of no content for turn 0 * fix: remove unknown config from examples * fix: test * fix: temporary disable gemma2 test * fix: stop overwriting config.text_config unnecessarily * fix: handling of set cache to the text_config section * feat: add liger gemma support and bump liger to 0.5.5 * fix: add double use_cache setting * fix: add support for final_logit_softcap in CCE for gemma2/3 * fix: set use_cache before model load * feat: add missing layernorm override * fix: handle gemma3 rmsnorm * fix: use wrapper to pass dim as hidden_size * fix: change dim to positional * fix: patch with wrong mlp * chore: refactor use_cache handling * fix import issues * fix tests.e2e.utils import --------- Co-authored-by: Wing Lian <wing@axolotl.ai> |
||
|
|
c907ac173e |
adding pre-commit auto-update GH action and bumping plugin versions (#2428)
* adding pre-commit auto-update GH action and bumping plugin versions * running updated pre-commit plugins * sorry to revert, but pylint complained * Update .pre-commit-config.yaml Co-authored-by: Wing Lian <wing.lian@gmail.com> --------- Co-authored-by: Dan Saunders <dan@axolotl.ai> Co-authored-by: Wing Lian <wing.lian@gmail.com> |
||
|
|
b194e17c28 |
feat: add config for optional parameters in a chat message (#2260)
* feat: add config for optional parameters in a chat message * chore: cleanup * chore: fix nits and add light docs * docs: update docs/dataset-formats/conversation.qmd Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * feat: configurable message mappings, jinja template analyzer * chore: handle bradley terry * docs: update docs * refactor: change order of mappings, improve message transform * refactor: make chat awware of property mappings * chore: remove .python-version * chore: revert change * chore: add dataset validation to tests where appropriate * chore: add dataset validation to tests where appropriate * chore: clean up handling of ds_cfg * chore: recursively serialize config * make sure to use the return value from validate_config * DefaultDict pickle/unpickle fix * fix super call for override * refactor: message fields * chore: empty commit * tests: validate config before using * chore: add config validation to all e2e tests * chore: add unneeded logging * chore: add missed config validation * chore: pass field_messages to prompter * test: fix borked test * chore: remove uninteded file * chore: add deprecation warning and update chat_datasets script * chore: lint * refactor: message fields * feat: update axolotlinputconfig and test_models - add configdict import in axolotl/utils/config/models/input/v0_4_1/__init__.py - remove unnecessary line breaks in sftdataset, dpodataset, ktodataset, stepwisesuperviseddataset classes - update model_dump method in axolotlinputconfig to exclude none values - correct typo in test_models.py comment * feat: simplify dpodataset and ktodataset classes in config models removed several optional fields from dpodataset and ktodataset classes in axolotl/utils/config/models/input/v0_4_1. this simplifies the configuration subsets for these datasets. * feat: improve readability and structure in dataset configuration models this commit enhances the readability and structure of the dataset configuration models in the `axolotl/utils/config/models/input/v0_4_1` module. it removes unused `configdict` import and adds line breaks to separate class definitions for better clarity. additionally, a minor documentation fix is included to ensure a newline at the end of the `stepwise_supervised.qmd` file. * feat: change log level from info to debug in chattemplatestrategy * feat(prompt_strategies): refactor chattemplateprompter and chattemplatestrategy - Make `chat_template` a required parameter in `ChatTemplatePrompter` constructor - Add default value for `message_property_mappings` in `ChatTemplatePrompter` constructor - Add `messages_array_name` property to `ChatTemplatePrompter` - Change `processor` type to Optional in `ChatTemplatePrompter` - Add TypeError check for `processor` in `ChatTemplatePrompter.build_prompt` - Remove `_messages` property from `ChatTemplateStrategy` - Make `prompter` a required parameter and add type hint in `ChatTemplateStrategy` constructor - Remove `messages` getter and setter from `ChatTemplateStrategy` - Use `prompter.messages_array_name` in `ChatTemplateStrategy.get_conversation_thread` - Remove condition to set `messages` field in `load` function * feat(tests/utils): ignore type check in load_model call in test_models.py * feat: improve type handling and test structure in chat templates - Add return type hint for `get_chat_template` function in `chat_templates.py` - Remove unnecessary assignment of `strategy.messages` in several test cases - Add `messages_array_name` parameter to various test configurations in `test_chat_templates.py` and `test_chat_templates_advanced.py` - Remove redundant `strategy.messages` assignment in `test_chat_templates_advanced.py` * feat(axolotl): enhance chat strategy with datasetconfig support This commit introduces support for DatasetConfig in the ChatTemplateStrategy. It also refines the strategy loader to handle different types of ds_cfg inputs and improves the clarity of the code by formatting and reordering. The key changes include: - Importing Union from typing and BaseModel from pydantic. - Adding DatasetConfig as an optional type for ds_cfg in StrategyLoader. - Adjusting the handling of ds_cfg in StrategyLoader to account for BaseModel instances. - Refactoring the prompter_params and strategy_params for better readability. - Changing the reference from prompt[self.messages] to prompt[self.prompter.messages_array_name] in the is_prompt_batched method. * feat: update message handling in btchattemplatestrategy * Replace `self.messages` with direct string references to "chosen_messages" and "rejected_messages" * Append system, user, and assistant content directly to "chosen_messages" and "rejected_messages" * Add a new attribute "messages_array_name" to the `load` function parameters * Remove the conditional attribute assignment for "field_messages" in the `load` function * feat: add config validation in test_kd.py - Import `validate_config` from `axolotl.utils.config` - Validate the configuration in `test_llama_kd` and another function in `TestKnowledgeDistillation` class * feat: enhance config validation and capabilities handling * Import `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals` * Update `validate_config` function to create `KTODataset` and `SFTDataset` instances using `dict(ds_cfg)` * Replace `capabilities` and `env_capabilities` with instances of `GPUCapabilities` and `EnvCapabilities` respectively in `AxolotlConfigWCapabilities` model dump * feat: update config validation in axolotl utils - Remove import of `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals` - Update `validate_config` function to use `capabilities` and `env_capabilities` directly instead of creating new instances of `GPUCapabilities` and `EnvCapabilities` * feat: refactor strategyloader in chat_template.py - Extracted the creation of strategy parameters into a separate function, `_get_strategy_params(cfg, dataset_config)` - Created a new function, `_get_strategy_cls()`, to obtain the strategy class - Replaced `ChatTemplateStrategy` with `strategy_cls` for strategy instantiation * trigger CI * chore: revert dataset config changes for kto/dpo * subject: refactor: rename 'messages_array_name' to 'field_messages' Body: - Renamed 'messages_array_name' to 'field_messages' in 'ChatTemplatePrompter' class and its usages in 'chat_template.py' - Updated 'load' function in 'bradley_terry/chat_template.py' to reflect the change - Adjusted 'get_chat_template_msg_variables' and 'get_message_vars' methods in 'jinja_template_analyzer.py' to use the new variable name - Modified 'StrategyLoader' in 'chat_template.py' to use 'field_messages' - Updated tests in 'test_chat_templates.py' and 'test_chat_templates_advanced.py' to use 'field_messages' instead of 'messages_array_name' * feat: refactor prompt strategies and update config models * Remove redundant 'return None' in `axolotl/prompt_strategies/__init__.py` * Simplify message handling in `axolotl/prompt_strategies/bradley_terry/chat_template.py` by using a single 'messages' list instead of separate 'chosen_messages' and 'rejected_messages' lists * Update default 'message_property_mappings' in `axolotl/prompt_strategies/bradley_terry/chat_template.py` * Add 'field_messages' field to `axolotl/utils/config/models/input/v0_4_1/__init__.py` configuration model * chore: remove unused input * chore: remove redundant type ignore * fix: remove old configs and update examples * fix: type check * fix: remove loading old config in ChatMessage * fix: update faq with potential new undefinederror * fix: add debug if property mapped is not found * chore: improve explanation for unmapped properties * fix: update docs with new config * chore: add note for deprecation config and del old config from dict --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: NanoCode012 <nano@axolotl.ai> |
||
|
|
78ce268848 |
KD Trainer w logprobs (#2303)
* refactor trainer to prevent circular dependencies later fix loader default KD dataset loading and KD with logprobs filter bad rows make batch smaller handle padding/collation for KD datasets make it work flipped the slice cross entropy loss coefficient during KD make sure to multiply against the correct loss chore: lint triton wip no where support v2 trial no torch.exp inside triton kernel no log etc no torch.tensor v3 fix kwarg don't use triton for now better rescaling for temperatures hash for temperature too use kd_alpha in the correct loss method fix kd loss so it's causal (fixes repeating tokens) var naming and add todo chore: lint refactor so we can easily add new loss functions add license block remove references to triton kd for now handle token/logprob shifting support for custom trainer classes from plugins refactor kd chat template loader move more things to kd plugin remove moved class from import make plugin setup concise increase logging around loading plugins add copyrights remove duplicate code more info on preprocess for kd and fix import be a bit pickier about loading dynamic prompt strategies kd sample packing make loss torch script compat support streaming for processing sft datasts? improve iterable support ensure that batch vs single is done properly tweak check for batched prompt data reward can use same batch check fix reward trainer calls for tokenization improve check for batched reward model doesn't work well with batched add kd trainer e2e test linting rename test files so it gets picked up make the kd e2e fit in vram for ci and add lora version set lora_dropout explicitly lower lr make sure to set tokenizer from l3 70b and save safetensors make sure to use the correct tokenizer fix adapter model check make sure to use tensorboard to capture loss for checks chore: lint chore: lint improve logprob masking and shift in trainer more fixes try tests for kd on l40s don't shift student logits for kd no batching for kd chat templates make sure to truncate logprobs if there are more than top_k change up logic so we always truncate to top_k use iter instead of tuple fix finding the top-k rather than assuming first position has the correct val apply z-score scaling to kd kd loss needs to be calculated in full precision Always re-normalize teacher distribution various fixes * support for configurable top-k/softmax ordering * add attribute check for filter rows and lint * fix logic * handle none case for conversion to int * fix student logit off by one * set kd_temp to 1.0 for test loss * address PR feedback |