Compare commits
15 Commits
enable_tp
...
djsaunde-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fae6b2df10 | ||
|
|
bd2a594b89 | ||
|
|
3798229d85 | ||
|
|
10cfecf02e | ||
|
|
339f3c67e2 | ||
|
|
d91feaffc8 | ||
|
|
e246ceffa4 | ||
|
|
8ddc18ec8d | ||
|
|
1c14c4a15c | ||
|
|
1f623e6cc8 | ||
|
|
f865464ae5 | ||
|
|
33090486d7 | ||
|
|
effc4dc409 | ||
|
|
02629c7cdf | ||
|
|
78a4aa86d6 |
5
.github/workflows/tests-nightly.yml
vendored
5
.github/workflows/tests-nightly.yml
vendored
@@ -44,6 +44,11 @@ jobs:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
@@ -5,6 +5,6 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/
|
||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
|
||||
@@ -127,34 +127,40 @@ datasets:
|
||||
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
|
||||
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
|
||||
chat_template: tokenizer_default
|
||||
# Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
|
||||
|
||||
# Custom jinja chat template. Used only if `chat_template: jinja` or empty.
|
||||
chat_template_jinja:
|
||||
# The key in the data example that contains the messages. Default is "messages".
|
||||
|
||||
# Key containing the messages (default: "messages")
|
||||
field_messages: messages
|
||||
# The key in the message turn that contains the role. Default is "role".
|
||||
# Key for role in each message (default: "role")
|
||||
message_field_role: role
|
||||
# The key in the message turn that contains the content. Default is "content".
|
||||
# Key for content in each message (default: "content")
|
||||
message_field_content: content
|
||||
# Optional[Dict[str, List]]. Roles mapping for the messages.
|
||||
|
||||
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
||||
roles:
|
||||
user: ["human", "user"]
|
||||
assistant: ["gpt", "assistant", "ai"]
|
||||
assistant: ["gpt", "assistant"]
|
||||
system: ["system"]
|
||||
tool: ["tool"]
|
||||
|
||||
## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.
|
||||
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
||||
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
|
||||
# See examples at `docs/dataset-formats/conversation.qmd`
|
||||
# Note: If the below 4 fields are empty, defaults to training only on the last message.
|
||||
|
||||
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
|
||||
roles_to_train: ["gpt", "assistant"]
|
||||
roles_to_train: ["assistant"] # default
|
||||
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
|
||||
# - all: train on all EOS tokens
|
||||
# - turn: train on the EOS token at the end of each trainable turn
|
||||
# - turn (default): train on the EOS token at the end of each trainable turn
|
||||
# - last: train on the last EOS token in the conversation
|
||||
train_on_eos: last
|
||||
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
|
||||
message_field_training: training
|
||||
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
|
||||
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
|
||||
# See example at `docs/dataset-formats/conversation.qmd`
|
||||
message_field_training_detail: train_detail
|
||||
|
||||
|
||||
@@ -239,6 +245,9 @@ sample_packing_group_size: 100000
|
||||
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
|
||||
sample_packing_bin_size: 200
|
||||
|
||||
# Use batch flattening for speedups when not using sample_packing
|
||||
batch_flattening:
|
||||
|
||||
# Passed through to transformers when loading the model when launched without accelerate
|
||||
# Use `sequential` when training w/ model parallelism to limit memory
|
||||
device_map:
|
||||
@@ -331,7 +340,8 @@ comet_experiment_config: # Dictionary for additional configuration settings, see
|
||||
output_dir: ./completed-model
|
||||
|
||||
# Whether to use torch.compile and which backend to use
|
||||
torch_compile: # bool
|
||||
# setting to `auto` will enable torch compile when torch>=2.5.1
|
||||
torch_compile: # Optional[Union[Literal["auto"], bool]]
|
||||
torch_compile_backend: # Optional[str]
|
||||
|
||||
# Training hyperparameters
|
||||
@@ -363,6 +373,10 @@ eval_table_size: # Approximate number of predictions sent to wandb depending on
|
||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
|
||||
|
||||
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
|
||||
# see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information
|
||||
# snapshots can be visualized @ https://pytorch.org/memory_viz
|
||||
|
||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
||||
|
||||
|
||||
@@ -68,6 +68,8 @@ We recommend checking the below examples for other usecases.
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
roles_to_train:
|
||||
train_on_eos:
|
||||
```
|
||||
|
||||
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||
@@ -77,7 +79,7 @@ chat_template: gemma # this overwrites the tokenizer's chat_template
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
roles_to_train: ["assistant"]
|
||||
roles_to_train: ["assistant"] # default value
|
||||
```
|
||||
|
||||
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
|
||||
@@ -87,7 +89,6 @@ chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
roles_to_train: ["assistant"]
|
||||
```
|
||||
|
||||
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
|
||||
@@ -99,7 +100,6 @@ chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
roles_to_train: ["assistant"]
|
||||
```
|
||||
|
||||
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
base_model: cerebras/btlm-3b-8k-base
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: GPT2Tokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
tokenizer_use_fast: true
|
||||
tokenizer_legacy: true
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
base_model: cerebras/Cerebras-GPT-1.3B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: codellama/CodeLlama-13b-hf
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: codellama/CodeLlama-13b-hf
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: codellama/CodeLlama-34b-hf
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: codellama/CodeLlama-34b-hf
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: codellama/CodeLlama-7b-hf
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: codellama/CodeLlama-7b-hf
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
base_model: LnL-AI/dbrx-base-converted-v2
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
base_model: LnL-AI/dbrx-base-converted-v2
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: true
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
base_model: LnL-AI/dbrx-base-converted-v2
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
base_model: deepseek-ai/DeepSeek-V2-Lite
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
base_model: axolotl-quants/DeepSeek-V2.5-bnb-nf4-bf16
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
base_model: tiiuae/falcon-7b
|
||||
trust_remote_code: true
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
# 1b: tiiuae/falcon-rw-1b
|
||||
# 40b: tiiuae/falcon-40b
|
||||
base_model: tiiuae/falcon-7b
|
||||
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
|
||||
trust_remote_code: true
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
|
||||
trust_remote_code: true
|
||||
|
||||
|
||||
load_in_8bit: false
|
||||
# enable 4bit for QLoRA
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
base_model: tiiuae/falcon-7b
|
||||
trust_remote_code: true
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
# use google/gemma-7b if you have access
|
||||
base_model: mhenrichsen/gemma-7b
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: google/gemma-2-9b
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: google/gemma-2-2b
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForSequenceClassification
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
base_model: EleutherAI/gpt-j-6b
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
base_model: ai21labs/Jamba-v0.1
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
base_model: ai21labs/Jamba-v0.1
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
base_model: ai21labs/AI21-Jamba-1.5-Large
|
||||
# optionally might have model_type or tokenizer_type
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
base_model: huggyllama/llama-7b
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
- path: openaccess-ai-collective/jeopardy
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
base_model: TheBloke/Llama-2-7B-GPTQ
|
||||
gptq: true
|
||||
gptq_disable_exllama: true
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
gptq: true
|
||||
gptq_disable_exllama: true
|
||||
|
||||
tokenizer_use_fast: true
|
||||
tokenizer_legacy: true
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
|
||||
# optionally might have model_type or tokenizer_type or processor_type
|
||||
processor_type: AutoProcessor
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
strict: false
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
base_model: NousResearch/Meta-Llama-3.1-8B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
base_model: NousResearch/Meta-Llama-3.1-8B
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
tensor_parallel: 'auto'
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 2
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
@@ -1,4 +1,6 @@
|
||||
base_model: NousResearch/Meta-Llama-3.1-8B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: NousResearch/Meta-Llama-3-8B-Instruct
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
base_model: NousResearch/Llama-3.2-1B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
base_model: NousResearch/Meta-Llama-3.1-8B
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_modules_to_save:
|
||||
- embed_tokens
|
||||
- lm_head
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
tensor_parallel: 'auto'
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
s2_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: NousResearch/Meta-Llama-3-8B
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
base_model: NousResearch/Llama-3.2-1B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16
|
||||
# optionally might have model_type or tokenizer_type
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: casperhansen/llama-3-70b-fp16
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer # PreTrainedTokenizerFast
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: NousResearch/Meta-Llama-3-8B
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
base_model: state-spaces/mamba-2.8b
|
||||
# optionally might have model_type or tokenizer_type or tokenizer_config
|
||||
model_type: MambaLMHeadModel
|
||||
tokenizer_type: AutoTokenizer
|
||||
tokenizer_config: EleutherAI/gpt-neox-20b
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
base_model: mistral-community/Mixtral-8x22B-v0.1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -4,8 +4,11 @@
|
||||
#face problems with the special tokens.
|
||||
|
||||
base_model: mistralai/Mistral-7B-Instruct-v0.2
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
base_model: mistralai/Mixtral-8x7B-v0.1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: mistral-community/Mixtral-8x22B-v0.1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
base_model: mistralai/Mixtral-8x7B-v0.1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
base_model: mistralai/Mixtral-8x7B-v0.1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
base_model: mistral-community/Mixtral-8x22B-v0.1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
base_model: mosaicml/mpt-7b
|
||||
# optionally might have model_type or tokenizer_type
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true # required for mpt as their model class is not merged into transformers yet
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: microsoft/Phi-3.5-mini-instruct
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: microsoft/phi-1_5
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: microsoft/phi-1_5
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: microsoft/phi-2
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: microsoft/Phi-3-mini-4k-instruct
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
base_model: microsoft/Phi-3-mini-4k-instruct
|
||||
# optionally might have model_type or tokenizer_type
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
chat_template: phi_3
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
base_model: EleutherAI/pythia-12b-deduped
|
||||
base_model_ignore_patterns: pytorch* # prefer safetensors
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
gptq: false
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
base_model: EleutherAI/pythia-1.4b-deduped
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: Qwen/Qwen-7B
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: Qwen/Qwen-7B
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
base_model: Qwen/Qwen1.5-MoE-A2.7B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
base_model: Qwen/Qwen1.5-MoE-A2.7B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
base_model: Qwen/Qwen2.5-0.5B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
strict: false
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
base_model: Qwen/Qwen2-7B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code:
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
base_model: replit/replit-code-v1-3b
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
base_model: stabilityai/stablelm-2-1_6b
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
base_model: stabilityai/stablelm-2-1_6b
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: true
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
base_model: bigcode/starcoder2-3b
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: TinyLlama/TinyLlama_v1.1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
base_model: TinyLlama/TinyLlama_v1.1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
||||
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: TinyLlama/TinyLlama_v1.1
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
# An example finetuning Saleforce's XGen-7b model with 8k context using qlora
|
||||
# on Tim Dettmer's Guanaco dataset.
|
||||
base_model: Salesforce/xgen-7b-8k-base
|
||||
trust_remote_code: true
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
# enable 4bit for QLoRA
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: 01-ai/Yi-34B-Chat
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -7,26 +7,31 @@ mamba-ssm==1.2.0.post1
|
||||
flash-attn==2.7.0.post2
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.4.2
|
||||
liger-kernel==0.5.2
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
|
||||
peft==0.14.0
|
||||
transformers>=4.46.3
|
||||
transformers==4.47.1
|
||||
tokenizers>=0.20.1
|
||||
accelerate==1.2.0
|
||||
accelerate==1.2.1
|
||||
datasets==3.1.0
|
||||
deepspeed==0.16.1
|
||||
trl==0.12.1
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
sentencepiece
|
||||
gradio==3.50.2
|
||||
|
||||
pydantic==2.6.3
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
requests
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
colorama
|
||||
numba
|
||||
numpy>=1.24.4,<=2.0.1
|
||||
@@ -36,7 +41,6 @@ scipy
|
||||
scikit-learn==1.4.2
|
||||
nvidia-ml-py==12.560.30
|
||||
art
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
python-dotenv==1.0.1
|
||||
|
||||
@@ -45,7 +49,6 @@ s3fs>=2024.5.0
|
||||
gcsfs>=2024.5.0
|
||||
# adlfs
|
||||
|
||||
trl==0.12.1
|
||||
zstandard==0.22.0
|
||||
fastcore
|
||||
|
||||
@@ -55,5 +58,7 @@ langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.5.0
|
||||
torchao==0.7.0
|
||||
schedulefree==1.3.0
|
||||
|
||||
axolotl-contribs-lgpl==0.0.1b2
|
||||
|
||||
@@ -32,5 +32,5 @@ else:
|
||||
raise RuntimeError(f"Torch = {v} too new!")
|
||||
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
||||
print(
|
||||
f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
|
||||
f'pip install unsloth-zoo==2024.12.1 && pip install --no-deps "unsloth[{x}]==2024.12.4"'
|
||||
)
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
"""Axolotl - Train and fine-tune large language models"""
|
||||
|
||||
import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.6.0"
|
||||
|
||||
52
src/axolotl/cli/evaluate.py
Normal file
52
src/axolotl/cli/evaluate.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
check_user_token,
|
||||
load_cfg,
|
||||
load_datasets,
|
||||
load_rl_datasets,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.evaluate import evaluate
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.evaluate")
|
||||
|
||||
|
||||
def do_evaluate(cfg, cli_args) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
|
||||
if cfg.rl: # and cfg.rl != "orpo":
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = HfArgumentParser(TrainerCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
do_evaluate(parsed_cfg, parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
@@ -12,7 +12,8 @@ from axolotl.cli.utils import (
|
||||
build_command,
|
||||
fetch_from_github,
|
||||
)
|
||||
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
|
||||
|
||||
@@ -48,6 +49,9 @@ def train(config: str, accelerate: bool, **kwargs):
|
||||
"""Train or fine-tune a model."""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
|
||||
if accelerate:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
|
||||
if config:
|
||||
@@ -60,6 +64,31 @@ def train(config: str, accelerate: bool, **kwargs):
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
help="Use accelerate launch for multi-GPU training",
|
||||
)
|
||||
@add_options_from_dataclass(EvaluateCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
def evaluate(config: str, accelerate: bool, **kwargs):
|
||||
"""Evaluate a model."""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
if accelerate:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
else:
|
||||
from axolotl.cli.evaluate import do_cli
|
||||
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
|
||||
@@ -15,6 +15,19 @@ configure_logging()
|
||||
LOG = logging.getLogger("axolotl.common.cli")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessCliArgs:
|
||||
"""
|
||||
dataclass representing arguments for preprocessing only
|
||||
"""
|
||||
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=1)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
download: Optional[bool] = field(default=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainerCliArgs:
|
||||
"""
|
||||
@@ -31,16 +44,14 @@ class TrainerCliArgs:
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessCliArgs:
|
||||
class EvaluateCliArgs:
|
||||
"""
|
||||
dataclass representing arguments for preprocessing only
|
||||
dataclass representing the various evaluation arguments
|
||||
"""
|
||||
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=1)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
download: Optional[bool] = field(default=True)
|
||||
debug_num_examples: int = field(default=0)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
@@ -50,7 +61,9 @@ def load_model_and_tokenizer(
|
||||
):
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
LOG.info("loading model and (optionally) peft_config...")
|
||||
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||
inference = getattr(cli_args, "inference", False)
|
||||
model, _ = load_model(cfg, tokenizer, inference=inference)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
"""
|
||||
helper functions for fixing the embeddings/tokenizer
|
||||
"""
|
||||
|
||||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
# GNU LESSER GENERAL PUBLIC LICENSE
|
||||
# Version 3, 29 June 2007
|
||||
#
|
||||
# Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
# Everyone is permitted to copy and distribute verbatim copies
|
||||
# of this license document, but changing it is not allowed.
|
||||
|
||||
import gc
|
||||
import itertools
|
||||
import logging
|
||||
from collections import Counter
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
LOG = logging.getLogger("axolotl.core.tokenizer_utils")
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def fix_untrained_tokens( # pylint: disable=too-many-return-statements
|
||||
model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16
|
||||
):
|
||||
"""
|
||||
Llama-3 for eg has untrained vectors in the base model.
|
||||
These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
|
||||
We reset them to the mean of the rest of the tokens
|
||||
"""
|
||||
# Code licensed under LGPL
|
||||
embedding_matrix = model.get_input_embeddings().weight
|
||||
lm_head_matrix = model.get_output_embeddings().weight
|
||||
chat_template = getattr(tokenizer, "chat_template", None)
|
||||
tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer
|
||||
|
||||
# Ignore some model checks for now
|
||||
if not ignored_tokenizer_names:
|
||||
ignored_tokenizer_names = []
|
||||
if (
|
||||
model.config._name_or_path # pylint: disable=protected-access
|
||||
in ignored_tokenizer_names
|
||||
):
|
||||
return
|
||||
|
||||
# Sometimes the sizes can be different like in vision models
|
||||
# Ie <image> is in input, but not in output
|
||||
min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
|
||||
embedding_matrix = embedding_matrix[:, :min_size]
|
||||
lm_head_matrix = lm_head_matrix[:, :min_size]
|
||||
|
||||
# Get untrained tokens
|
||||
indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
|
||||
# Check lm_head as well
|
||||
|
||||
# Does NOT work for Llama 3.1!!
|
||||
indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps
|
||||
|
||||
# We instead check for repeated vectors
|
||||
lm_head_where = torch.where(indicator_untrained1)[0]
|
||||
lm_head_bad = lm_head_matrix[lm_head_where]
|
||||
lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
|
||||
counter = Counter()
|
||||
for row in lm_head_bad:
|
||||
counter[hash(row.data.tobytes())] += 1
|
||||
counter = Counter({k: c for k, c in counter.items() if c >= 2})
|
||||
|
||||
lm_head_where = lm_head_where.cpu().numpy()
|
||||
final_bad_lm_head = []
|
||||
for j, row in enumerate(lm_head_bad):
|
||||
if hash(row.data.tobytes()) in counter:
|
||||
final_bad_lm_head.append(lm_head_where[j])
|
||||
indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)
|
||||
indicator_untrained2[final_bad_lm_head] = True
|
||||
|
||||
# Combine both checks
|
||||
indicator_untrained = indicator_untrained1 & indicator_untrained2
|
||||
|
||||
# Remove pad token possibility
|
||||
if hasattr(tokenizer, "pad_token_id"):
|
||||
pad_token_id = tokenizer.pad_token_id
|
||||
if pad_token_id is not None and pad_token_id < indicator_untrained.shape[0]:
|
||||
indicator_untrained[pad_token_id] = False
|
||||
|
||||
where_untrained = torch.where(indicator_untrained)[0]
|
||||
n_untrained = where_untrained.shape[0]
|
||||
n_trained = embedding_matrix.shape[0] - n_untrained
|
||||
|
||||
# Get set and actual tokens
|
||||
where_untrained = where_untrained.tolist()
|
||||
if len(where_untrained) == 0:
|
||||
return
|
||||
|
||||
# Remove untrained indices where it's longer
|
||||
where_untrained_set = frozenset(where_untrained)
|
||||
actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
|
||||
# Remove None items in actual_bad_tokens
|
||||
actual_bad_tokens = [x for x in actual_bad_tokens if x is not None]
|
||||
|
||||
# Check if tokenizer and training datasets have bad tokens
|
||||
if_bad_first = False
|
||||
if_bad_second = False
|
||||
# Check tokenizer's chat template for any untrained tokens
|
||||
if chat_template is not None:
|
||||
if_bad_first = any(x in chat_template for x in actual_bad_tokens)
|
||||
|
||||
if isinstance(train_dataset, datasets.IterableDataset):
|
||||
# Skip the check, since the code below assumes
|
||||
# an indexable dataset
|
||||
return
|
||||
|
||||
# Check the first 250, last 250 input_ids
|
||||
size_dataset = len(train_dataset)
|
||||
size = min(size_dataset, 250)
|
||||
for j in range(size):
|
||||
input_ids = train_dataset[j]
|
||||
if "input_ids" in input_ids:
|
||||
input_ids = input_ids["input_ids"]
|
||||
if_bad = any(item in where_untrained_set for item in input_ids)
|
||||
if if_bad:
|
||||
if_bad_second = True
|
||||
break
|
||||
|
||||
# Check last 250
|
||||
if not if_bad_second:
|
||||
left = max(size_dataset - 250, 0)
|
||||
for j in range(left, size_dataset):
|
||||
input_ids = train_dataset[j]
|
||||
if "input_ids" in input_ids:
|
||||
input_ids = input_ids["input_ids"]
|
||||
if_bad = any(item in where_untrained_set for item in input_ids)
|
||||
if if_bad:
|
||||
if_bad_second = True
|
||||
break
|
||||
|
||||
# Check if bad tokens exists!
|
||||
if not if_bad_first and not if_bad_second:
|
||||
return
|
||||
|
||||
# Check if lm_head / embed_token are trainable!
|
||||
bad_not_trainable = False
|
||||
if not embedding_matrix.requires_grad:
|
||||
bad_not_trainable = True
|
||||
if not lm_head_matrix.requires_grad:
|
||||
bad_not_trainable = True
|
||||
|
||||
if bad_not_trainable: # pylint: disable=too-many-nested-blocks
|
||||
final_bad_items = []
|
||||
|
||||
# Re-check the first 250, last 250 input_ids
|
||||
size_dataset = len(train_dataset)
|
||||
size = min(size_dataset, 250)
|
||||
for j in range(size):
|
||||
input_ids = train_dataset[j]
|
||||
if "input_ids" in input_ids:
|
||||
input_ids = input_ids["input_ids"]
|
||||
for item in input_ids:
|
||||
if item in where_untrained_set:
|
||||
final_bad_items.append(item)
|
||||
|
||||
# Re-check last 250
|
||||
left = max(size_dataset - 250, 0)
|
||||
for j in range(left, size_dataset):
|
||||
input_ids = train_dataset[j]
|
||||
if "input_ids" in input_ids:
|
||||
input_ids = input_ids["input_ids"]
|
||||
for item in input_ids:
|
||||
if item in where_untrained_set:
|
||||
final_bad_items.append(item)
|
||||
|
||||
# If no bad tokens, possibly chat template itself has issues?
|
||||
if len(final_bad_items) == 0:
|
||||
# Recheck 2000 and last 2000 items
|
||||
size_dataset = len(train_dataset)
|
||||
size = min(size_dataset, 2000)
|
||||
for j in range(size):
|
||||
input_ids = train_dataset[j]
|
||||
if "input_ids" in input_ids:
|
||||
input_ids = input_ids["input_ids"]
|
||||
for item in input_ids:
|
||||
if item in where_untrained_set:
|
||||
final_bad_items.append(item)
|
||||
|
||||
# Re-check last 2000
|
||||
left = max(size_dataset - 2000, 0)
|
||||
for j in range(left, size_dataset):
|
||||
input_ids = train_dataset[j]
|
||||
if "input_ids" in input_ids:
|
||||
input_ids = input_ids["input_ids"]
|
||||
for item in input_ids:
|
||||
if item in where_untrained_set:
|
||||
final_bad_items.append(item)
|
||||
|
||||
# Most likely false signal!
|
||||
if len(final_bad_items) == 0:
|
||||
return
|
||||
|
||||
raise ValueError(
|
||||
f"Untrained tokens of [{list(set(final_bad_items))}] found, but embed_tokens & lm_head not trainable, causing NaNs. "
|
||||
)
|
||||
|
||||
# Count all the possible bad tokens
|
||||
final_counts = np.zeros(
|
||||
max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
|
||||
)
|
||||
|
||||
def mapping(examples):
|
||||
input_ids = examples["input_ids"]
|
||||
counter = np.fromiter(itertools.chain.from_iterable(input_ids), dtype=np.int32)
|
||||
np.add.at(final_counts, counter, 1)
|
||||
|
||||
train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
|
||||
|
||||
# Get counts for untrained tokens
|
||||
counts_untrained = final_counts[where_untrained]
|
||||
# Identify untrained tokens seen in train_dataset
|
||||
indices_seen_in_train = np.where(counts_untrained > 0)[0]
|
||||
tokens_to_update = [where_untrained[i] for i in indices_seen_in_train]
|
||||
|
||||
if len(tokens_to_update) == 0:
|
||||
LOG.info(
|
||||
"No untrained tokens found in train_dataset. No embeddings were modified."
|
||||
)
|
||||
return
|
||||
|
||||
# Log the token IDs that are being rescaled
|
||||
LOG.info(
|
||||
f"Rescaling embeddings for tokens seen in train_dataset: {tokens_to_update}"
|
||||
)
|
||||
|
||||
# Get sum of all items
|
||||
sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
|
||||
sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
|
||||
|
||||
# Remove bad tokens
|
||||
sum_embedding -= torch.sum(
|
||||
embedding_matrix[where_untrained], dtype=torch.float32, axis=0
|
||||
)
|
||||
sum_lm_head -= torch.sum(
|
||||
lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
|
||||
)
|
||||
|
||||
# Find correct average by dividing by sum of trained tokens
|
||||
mean_embedding = sum_embedding / n_trained
|
||||
mean_lm_head = sum_lm_head / n_trained
|
||||
|
||||
# Compute scaling for tokens to update
|
||||
scaling = counts_untrained[indices_seen_in_train] / max(final_counts.max(), 1)
|
||||
scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
|
||||
|
||||
# Prepare mean embeddings for tokens to update
|
||||
mean_embedding_repeated = (
|
||||
mean_embedding.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
|
||||
)
|
||||
mean_lm_head_repeated = (
|
||||
mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
|
||||
)
|
||||
|
||||
# Update embeddings only for tokens seen in train_dataset
|
||||
embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
|
||||
embedding_matrix.dtype
|
||||
)
|
||||
lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)
|
||||
|
||||
# Clean up
|
||||
for _ in range(3):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user