Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
080612219b use even if not using sample packing 2023-10-13 17:54:35 -04:00
Wing Lian
f95858d369 alternate impl of NEFT 2023-10-13 17:45:24 -04:00
43 changed files with 149 additions and 509 deletions

View File

@@ -12,4 +12,3 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error, disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods, too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation, too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-boolean-expressions,

View File

@@ -96,7 +96,7 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference # inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--peft_model_dir="./lora-out" --lora_model_dir="./lora-out"
``` ```
## Installation ## Installation
@@ -297,24 +297,25 @@ Have dataset(s) in one of the following format (JSONL recommended):
#### How to add custom prompts #### How to add custom prompts
For a dataset that is preprocessed for instruction purposes: Using yaml. Example:
```json
{"instruction": "...", "output": "..."}
```
You can use this example in your YAML config:
```yaml ```yaml
datasets: datasets:
- path: repo - path: repo
type: type:
system_prompt: "" system_prompt: ""
field_system: system no_input_format: |-
format: "[INST] {instruction} [/INST]" User: {instruction}<|end_of_turn|>
no_input_format: "[INST] {instruction} [/INST]" Assistant:
format: |-
User: {instruction}
{input}<|end_of_turn|>
Assistant:
``` ```
Using file:
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
#### How to use your custom pretokenized dataset #### How to use your custom pretokenized dataset
- Do not pass a `type:` - Do not pass a `type:`
@@ -384,10 +385,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- lora - lora
```yaml ```yaml
adapter: lora # qlora or leave blank for full finetune adapter: lora # qlora or leave blank for full finetune
peft_r: 8 lora_r: 8
peft_alpha: 16 lora_alpha: 16
peft_dropout: 0.05 lora_dropout: 0.05
peft_target_modules: lora_target_modules:
- q_proj - q_proj
- v_proj - v_proj
``` ```
@@ -531,15 +532,15 @@ total_num_tokens:
adapter: lora adapter: lora
# If you already have a lora model trained that you want to load, put that here. # If you already have a lora model trained that you want to load, put that here.
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`. # This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
peft_model_dir: lora_model_dir:
# LoRA hyperparameters # LoRA hyperparameters
# For more details about the following options, see: # For more details about the following options, see:
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2 # https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
peft_r: 8 lora_r: 8
peft_alpha: 16 lora_alpha: 16
peft_dropout: 0.05 lora_dropout: 0.05
peft_target_modules: lora_target_modules:
- q_proj - q_proj
- v_proj - v_proj
# - k_proj # - k_proj
@@ -547,13 +548,13 @@ peft_target_modules:
# - gate_proj # - gate_proj
# - down_proj # - down_proj
# - up_proj # - up_proj
peft_target_linear: # if true, will target all linear layers lora_target_linear: # If true, will target all linear layers
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens. # If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models. # For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities. # `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994 # https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
peft_modules_to_save: lora_modules_to_save:
# - embed_tokens # - embed_tokens
# - lm_head # - lm_head
@@ -561,8 +562,7 @@ peft_modules_to_save:
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory. # If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# Make sure `lora_model_dir` points to this directory if you want to use the trained model. # Make sure `lora_model_dir` points to this directory if you want to use the trained model.
lora_out_dir: lora_out_dir:
peft_fan_in_fan_out: false lora_fan_in_fan_out: false
peft_feedforward_modules: # ffn modules for IA3, for llama down projection
# ReLoRA configuration # ReLoRA configuration
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
@@ -870,7 +870,7 @@ Pass the appropriate flag to the train command:
- Pretrained LORA: - Pretrained LORA:
```bash ```bash
python -m axolotl.cli.inference examples/your_config.yml --peft_model_dir="./lora-output-dir" python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
``` ```
- Full weights finetune: - Full weights finetune:
```bash ```bash
@@ -891,7 +891,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
Add below flag to train command above Add below flag to train command above
```bash ```bash
python3 -m axolotl.cli.merge_lora examples/your_config.yml --peft_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
``` ```
If you run out of CUDA memory, you can try to merge in system RAM with If you run out of CUDA memory, you can try to merge in system RAM with

View File

@@ -18,7 +18,7 @@ dataset_prepared_path: last_prepared_run
val_set_size: 0.01 val_set_size: 0.01
adapter: adapter:
peft_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
sample_packing: false sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: qlora adapter: qlora
peft_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: 2048 max_packed_sequence_len: 2048
lora_r: 16 lora_r: 16

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
adapter: lora adapter: lora
peft_model_dir: lora_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
peft_model_dir: lora_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
adapter: lora adapter: lora
peft_model_dir: lora_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
peft_model_dir: lora_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
adapter: lora adapter: lora
peft_model_dir: lora_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
peft_model_dir: lora_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -15,7 +15,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: lora adapter: lora
peft_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 16 lora_r: 16

View File

@@ -22,7 +22,7 @@ dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
# enable QLoRA # enable QLoRA
adapter: qlora adapter: qlora
peft_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:

View File

@@ -15,7 +15,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: adapter:
peft_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 64 lora_r: 64

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: qlora adapter: qlora
peft_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 8 lora_r: 8

View File

@@ -9,7 +9,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
peft_model_dir: lora_model_dir:
sequence_len: 512 sequence_len: 512
max_packed_sequence_len: max_packed_sequence_len:
lora_r: lora_r:

View File

@@ -18,7 +18,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: lora adapter: lora
peft_model_dir: lora_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: sample_packing:
lora_r: 8 lora_r: 8

View File

@@ -1,72 +0,0 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./ia3-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: ia3
peft_model_dir:
peft_target_modules:
- k_proj
- v_proj
- down_proj
peft_feedforward_modules:
- down_proj
peft_fan_in_fan_out: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 5
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens:
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
adapter: lora adapter: lora
peft_model_dir: lora_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
peft_model_dir: lora_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./relora-out output_dir: ./relora-out
adapter: qlora adapter: qlora
peft_model_dir: lora_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -20,7 +20,7 @@ sequence_len: 4096
sample_packing: true sample_packing: true
adapter: lora adapter: lora
peft_model_dir: lora_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -9,7 +9,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
peft_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 8 lora_r: 8

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
peft_model_dir: lora_model_dir:
sequence_len: 1024 sequence_len: 1024
sample_packing: true sample_packing: true
lora_r: lora_r:

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: lora adapter: lora
peft_model_dir: lora_model_dir:
sequence_len: 1024 sequence_len: 1024
sample_packing: true sample_packing: true
lora_r: 8 lora_r: 8

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: qlora adapter: qlora
peft_model_dir: lora_model_dir:
sequence_len: 1024 sequence_len: 1024
sample_packing: true sample_packing: true
lora_r: 8 lora_r: 8

View File

@@ -22,7 +22,7 @@ sample_packing: true
pad_to_sequence_len: pad_to_sequence_len:
adapter: adapter:
peft_model_dir: lora_model_dir:
lora_r: lora_r:
lora_alpha: lora_alpha:
lora_dropout: lora_dropout:

View File

@@ -22,7 +22,7 @@ sample_packing: false # not CURRENTLY compatible with LoRAs
pad_to_sequence_len: pad_to_sequence_len:
adapter: qlora adapter: qlora
peft_model_dir: lora_model_dir:
lora_r: 64 lora_r: 64
lora_alpha: 32 lora_alpha: 32
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -13,7 +13,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
adapter: adapter:
peft_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: 2048 max_packed_sequence_len: 2048
lora_r: 64 lora_r: 64

View File

@@ -7,7 +7,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
adapter: lora adapter: lora
peft_model_dir: lora_model_dir:
sequence_len: 512 sequence_len: 512
lora_r: 16 lora_r: 16
lora_alpha: 32 lora_alpha: 32

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
peft_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 8 lora_r: 8

View File

@@ -8,7 +8,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
adapter: lora adapter: lora
peft_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 8 lora_r: 8

View File

@@ -20,7 +20,7 @@ dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
# enable QLoRA # enable QLoRA
adapter: qlora adapter: qlora
peft_model_dir: lora_model_dir:
sequence_len: 8192 sequence_len: 8192
max_packed_sequence_len: max_packed_sequence_len:

Binary file not shown.

Before

Width:  |  Height:  |  Size: 370 KiB

View File

@@ -46,7 +46,7 @@ setup(
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"flash-attn": [ "flash-attn": [
"flash-attn>=2.3.0", "flash-attn>=2.2.1",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed", "deepspeed",

View File

@@ -42,11 +42,21 @@ def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False, packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False, cross_entropy: Optional[bool] = False,
rms_norm: Optional[bool] = False, rms_norm: Optional[bool] = False,
noisy_embeddings_alpha: Optional[int] = False,
): ):
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask _prepare_decoder_attention_mask
) )
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
if noisy_embeddings_alpha:
transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = partial(
llama_model_get_inputs_embeds, noisy_embeddings_alpha=noisy_embeddings_alpha
)
else:
transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = (
llama_model_get_inputs_embeds
)
if packed: if packed:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = ( transformers.models.llama.modeling_llama.LlamaModel.forward = (
@@ -116,8 +126,6 @@ def flashattn_forward(
attention_mask: [bsz, q_len] attention_mask: [bsz, q_len]
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
original_dtype = hidden_states.dtype
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"): if not hasattr(self, "pretraining_tp"):
@@ -153,13 +161,6 @@ def flashattn_forward(
key_states = self.k_proj(hidden_states) key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states) value_states = self.v_proj(hidden_states)
if query_states.dtype == torch.float32:
query_states = query_states.to(dtype=original_dtype)
if key_states.dtype == torch.float32:
key_states = key_states.to(dtype=original_dtype)
if value_states.dtype == torch.float32:
value_states = value_states.to(dtype=original_dtype)
query_states = query_states.view( query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2) ).transpose(1, 2)
@@ -318,10 +319,6 @@ def flashattn_forward(
else: else:
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
# handle conversion back for IA3
if attn_output.dtype == torch.float32:
attn_output = attn_output.to(dtype=original_dtype)
return attn_output, None, past_key_value return attn_output, None, past_key_value
@@ -424,6 +421,28 @@ def generate_qkv(
) )
def llama_model_get_inputs_embeds(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
noisy_embeddings_alpha: Optional[int] = None,
):
inputs_embeds = self.embed_tokens(input_ids)
if noisy_embeddings_alpha:
input_mask = attention_mask.to(inputs_embeds) # B x L
input_lengths = torch.sum(input_mask, 1) # B
noise_ = torch.zeros_like(inputs_embeds).uniform_(-1, 1)
delta = noise_ * input_mask.unsqueeze(2)
dims = input_lengths * inputs_embeds.size(-1)
mag = noisy_embeddings_alpha / torch.sqrt(dims)
delta = (delta * mag.view(-1, 1, 1)).detach()
inputs_embeds += delta
return inputs_embeds
def llama_model_forward( def llama_model_forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
@@ -490,7 +509,8 @@ def llama_model_forward(
cu_seqlens = cu_seqlens.squeeze() cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.get_inputs_embeds(input_ids, attention_mask)
# embed positions # embed positions
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones( attention_mask = torch.ones(
@@ -515,7 +535,6 @@ def llama_model_forward(
) )
hidden_states = inputs_embeds hidden_states = inputs_embeds
original_dtype = hidden_states.dtype
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
@@ -573,10 +592,6 @@ def llama_model_forward(
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
# handle conversion back for IA3
if hidden_states.dtype == torch.float32:
hidden_states = hidden_states.to(dtype=original_dtype)
if use_cache: if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

View File

@@ -14,9 +14,6 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_func,
) )
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralAttention as OriginalMistralAttention,
)
from transformers.models.mistral.modeling_mistral import ( from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer, MistralDecoderLayer as OriginalMistralDecoderLayer,
) )
@@ -45,44 +42,6 @@ def replace_mistral_attn_with_flash_attn(
) )
@torch.jit.script
def _make_sliding_window_causal_mask(
bsz: int,
tgt_len: int,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: int = 4096,
):
"""
Make causal mask used for sliding window attention
"""
tensor = torch.full(
(tgt_len, tgt_len),
fill_value=1,
device=device,
)
mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
mask = torch.triu(mask, diagonal=-sliding_window + 1)
mask = torch.log(mask).to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention # Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask # requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask( def _prepare_decoder_attention_mask(
@@ -94,29 +53,11 @@ def _prepare_decoder_attention_mask(
sliding_window, sliding_window,
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
# [bsz, seq_len] # [bsz, seq_len]
if attention_mask is None:
return attention_mask
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
sliding_window_mask = _make_sliding_window_causal_mask(
bsz=input_shape[0],
tgt_len=input_shape[1],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
sliding_window=sliding_window,
)
attention_mask = attention_mask + sliding_window_mask
else:
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
return attention_mask return attention_mask
def flashattn_forward( def flashattn_forward(
self: OriginalMistralAttention, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
@@ -150,41 +91,10 @@ def flashattn_forward(
query_states, key_states, cos, sin, position_ids query_states, key_states, cos, sin, position_ids
) )
use_sliding_windows = (
hasattr(self.config, "sliding_window") is not None
and kv_seq_len > self.config.sliding_window
)
if use_sliding_windows:
window_size = (self.config.sliding_window, self.config.sliding_window)
else:
window_size = (-1, -1)
if past_key_value is not None: if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute # reuse k, v, self_attention
if ( key_states = torch.cat([past_key_value[0], key_states], dim=2)
hasattr(self.config, "sliding_window") value_states = torch.cat([past_key_value[1], value_states], dim=2)
and kv_seq_len > self.config.sliding_window
):
slicing_tokens = kv_seq_len - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)
past_key_value = (past_key, past_value) if use_cache else None
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None past_key_value = (key_states, value_states) if use_cache else None
@@ -210,13 +120,7 @@ def flashattn_forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func( output = flash_attn_varlen_qkvpacked_func(
qkv, qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
cu_seqlens,
max_seqlen,
0.0,
softmax_scale=None,
causal=True,
window_size=window_size,
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape: elif query_states.shape == key_states.shape:
@@ -242,7 +146,6 @@ def flashattn_forward(
0.0, 0.0,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
output = output_pad_fn(output_unpad) output = output_pad_fn(output_unpad)
else: else:
@@ -254,7 +157,6 @@ def flashattn_forward(
query_states, query_states,
torch.stack([key_states, value_states], 2), torch.stack([key_states, value_states], 2),
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
else: else:
( # pylint: disable=unbalanced-tuple-unpacking ( # pylint: disable=unbalanced-tuple-unpacking
@@ -289,7 +191,6 @@ def flashattn_forward(
0.0, 0.0,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
output = output_pad_fn(output_unpad) output = output_pad_fn(output_unpad)

View File

@@ -1,6 +1,6 @@
"""Module for Alpaca prompt strategy classes""" """Module containing the AlpacaQAPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional, Tuple from typing import Tuple
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy, AlpacaPromptTokenizingStrategy,
@@ -9,13 +9,9 @@ from axolotl.prompt_tokenizers import (
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): def load(tokenizer, cfg):
prompt_style = PromptStyle.CHAT.value
if ds_cfg and "conversation" in ds_cfg:
prompt_style = ds_cfg["conversation"]
return AlpacaPromptTokenizingStrategy( return AlpacaPromptTokenizingStrategy(
AlpacaPrompter(prompt_style), AlpacaPrompter(PromptStyle.CHAT.value),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,

View File

@@ -121,18 +121,6 @@ def normalize_config(cfg):
log_gpu_memory_usage(LOG, "baseline", cfg.device) log_gpu_memory_usage(LOG, "baseline", cfg.device)
if cfg.adapter is not None:
for key in list(cfg.keys()):
if key.startswith("lora_"):
new_key = key.replace("lora_", "peft_")
LOG.warning(
PendingDeprecationWarning(
f"{key} soon to be deprecated. please use {new_key}"
)
)
cfg[new_key] = cfg[key]
del cfg[key]
def validate_config(cfg): def validate_config(cfg):
if is_torch_bf16_gpu_available(): if is_torch_bf16_gpu_available():
@@ -202,10 +190,7 @@ def validate_config(cfg):
raise ValueError("Require cfg.load_in_4bit to be True for qlora") raise ValueError("Require cfg.load_in_4bit to be True for qlora")
if not cfg.load_in_8bit and cfg.adapter == "lora": if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LoRA finetuning") LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if not cfg.load_in_8bit and cfg.adapter == "ia3":
LOG.warning("We recommend setting `load_in_8bit: true` for IA3 finetuning")
if cfg.relora_steps: if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"): if cfg.adapter not in ("lora", "qlora"):

View File

@@ -158,7 +158,7 @@ def load_tokenized_prepared_datasets(
token=use_auth_token, token=use_auth_token,
) )
ds_from_hub = True ds_from_hub = True
except (FileNotFoundError, ConnectionError): except FileNotFoundError:
pass pass
# prefer local dataset, even if hub exists # prefer local dataset, even if hub exists

View File

@@ -136,7 +136,11 @@ def load_model(
replace_stablelm_attn_with_flash_attn(cfg.base_model) replace_stablelm_attn_with_flash_attn(cfg.base_model)
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing: if (
cfg.is_llama_derived_model
and cfg.flash_attention
and (cfg.noisy_embeddings_alpha or cfg.sample_packing)
):
if cfg.device not in ["mps", "cpu"] and not inference: if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import ( from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn, replace_llama_attn_with_flash_attn,
@@ -147,6 +151,7 @@ def load_model(
packed=cfg.sample_packing, packed=cfg.sample_packing,
cross_entropy=cfg.flash_attn_cross_entropy, cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm, rms_norm=cfg.flash_attn_rms_norm,
noisy_embeddings_alpha=cfg.noisy_embeddings_alpha,
) )
elif cfg.is_llama_derived_model and cfg.xformers_attention: elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import ( from axolotl.monkeypatch.llama_attn_hijack_xformers import (
@@ -180,16 +185,16 @@ def load_model(
LOG.info("patching with flash attention") LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing) replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha: # if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.llama_embeddings_hijack import ( # from axolotl.monkeypatch.llama_embeddings_hijack import (
replace_llama_embeddings_with_uniform_distribution, # replace_llama_embeddings_with_uniform_distribution,
) # )
#
LOG.info("patching with noisy embeddings") # LOG.info("patching with noisy embeddings")
replace_llama_embeddings_with_uniform_distribution( # replace_llama_embeddings_with_uniform_distribution(
noise_alpha=cfg.noisy_embedding_alpha # noise_alpha=cfg.noisy_embedding_alpha
) # )
#
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha: if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.mistral_embeddings_hijack import ( from axolotl.monkeypatch.mistral_embeddings_hijack import (
replace_mistral_embeddings_with_uniform_distribution, replace_mistral_embeddings_with_uniform_distribution,
@@ -406,21 +411,21 @@ def load_model(
if hasattr(module, "weight"): if hasattr(module, "weight"):
module.to(torch.float32) module.to(torch.float32)
require_peft: bool = False needs_fa2_dtype = cfg.adapter or cfg.fsdp
if cfg.adapter in ["lora", "qlora", "ia3"]: if (cfg.adapter == "lora" and load_in_8bit) or (
require_peft = True cfg.adapter == "qlora" and cfg.load_in_4bit
):
if require_peft:
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing: if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training( model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing model, use_gradient_checkpointing=cfg.gradient_checkpointing
) )
needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility. # convert them back to fp16/bf16 for flash-attn compatibility.
if require_peft or cfg.fsdp or (cfg.flash_attention and cfg.is_llama_derived_model): if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype) LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules(): for name, module in model.named_modules():
if "norm" in name: if "norm" in name:
@@ -429,7 +434,7 @@ def load_model(
if hasattr(module, "weight"): if hasattr(module, "weight"):
module.to(cfg.torch_dtype) module.to(cfg.torch_dtype)
model, peft_config = load_adapter(model, cfg, cfg.adapter) model, lora_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit: if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}") model.to(f"cuda:{cfg.local_rank}")
@@ -460,7 +465,7 @@ def load_model(
log_gpu_memory_usage(LOG, "after adapters", model.device) log_gpu_memory_usage(LOG, "after adapters", model.device)
# TODO resume_from_checkpoint handling # TODO resume_from_checkpoint handling
return model, peft_config return model, lora_config
def load_adapter(model, cfg, adapter, inference=False): def load_adapter(model, cfg, adapter, inference=False):
@@ -470,8 +475,6 @@ def load_adapter(model, cfg, adapter, inference=False):
return model, None return model, None
if hasattr(model, "enable_input_require_grads"): if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads() model.enable_input_require_grads()
if adapter == "ia3":
return load_ia3(model, cfg, inference=inference)
if adapter in ["lora", "qlora"]: if adapter in ["lora", "qlora"]:
return load_lora(model, cfg, inference=inference) return load_lora(model, cfg, inference=inference)
if adapter == "llama-adapter": if adapter == "llama-adapter":
@@ -490,11 +493,11 @@ def load_llama_adapter(model, cfg):
task_type="CAUSAL_LM", task_type="CAUSAL_LM",
) )
if cfg.peft_model_dir: if cfg.lora_model_dir:
LOG.debug("Loading pretained PEFT - llama_adapter") LOG.debug("Loading pretained PEFT - llama_adapter")
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
cfg.peft_model_dir, cfg.lora_model_dir,
torch_dtype=torch.float16, torch_dtype=torch.float16,
) )
else: else:
@@ -507,20 +510,16 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model): def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear) cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
peft_module_names = set() lora_module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
if ( if isinstance(module, cls) or "Linear" in module.__class__.__name__:
isinstance(module, cls)
or "Linear" in module.__class__.__name__
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
):
names = name.split(".") names = name.split(".")
peft_module_names.add(names[0] if len(names) == 1 else names[-1]) lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if "lm_head" in peft_module_names: # needed for 16-bit if "lm_head" in lora_module_names: # needed for 16-bit
peft_module_names.remove("lm_head") lora_module_names.remove("lm_head")
return list(peft_module_names) return list(lora_module_names)
def load_lora(model, cfg, inference=False): def load_lora(model, cfg, inference=False):
@@ -528,68 +527,34 @@ def load_lora(model, cfg, inference=False):
from peft import LoraConfig, PeftModel, get_peft_model from peft import LoraConfig, PeftModel, get_peft_model
peft_target_modules = list(cfg.peft_target_modules or []) lora_target_modules = list(cfg.lora_target_modules or [])
if cfg.peft_target_linear: if cfg.lora_target_linear:
linear_names = find_all_linear_names(model) linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(linear_names)}") LOG.info(f"found linear modules: {repr(linear_names)}")
peft_target_modules = list(set(peft_target_modules + linear_names)) lora_target_modules = list(set(lora_target_modules + linear_names))
peft_config = LoraConfig( lora_config = LoraConfig(
r=cfg.peft_r, r=cfg.lora_r,
lora_alpha=cfg.peft_alpha, lora_alpha=cfg.lora_alpha,
target_modules=peft_target_modules, target_modules=lora_target_modules,
lora_dropout=cfg.peft_dropout, lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.peft_fan_in_fan_out, fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.peft_modules_to_save if cfg.peft_modules_to_save else None, modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
bias="none", bias="none",
task_type="CAUSAL_LM", task_type="CAUSAL_LM",
) )
if cfg.peft_model_dir: if cfg.lora_model_dir:
LOG.debug("Loading pretained PEFT - LoRA") LOG.debug("Loading pretained PEFT - LoRA")
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
cfg.peft_model_dir, cfg.lora_model_dir,
is_trainable=(not inference), is_trainable=(not inference),
) )
else: else:
model = get_peft_model(model, peft_config) model = get_peft_model(model, lora_config)
model.print_trainable_parameters() model.print_trainable_parameters()
return model, peft_config return model, lora_config
def load_ia3(model, cfg, inference=False):
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import IA3Config, PeftModel, get_peft_model
peft_config_kwargs = {}
if cfg.peft_init_ia3_weights is not None:
peft_config_kwargs["init_ia3_weights"] = cfg.peft_init_ia3_weights
if cfg.peft_fan_in_fan_out is not None:
peft_config_kwargs["fan_in_fan_out"] = cfg.peft_fan_in_fan_out
peft_config = IA3Config(
target_modules=cfg.peft_target_modules,
feedforward_modules=cfg.peft_feedforward_modules,
modules_to_save=cfg.peft_modules_to_save,
task_type="CAUSAL_LM",
**peft_config_kwargs,
)
if cfg.peft_model_dir:
LOG.debug("Loading pretained PEFT - IA3")
model = PeftModel.from_pretrained(
model,
cfg.peft_model_dir,
is_trainable=(not inference),
)
else:
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
return model, peft_config

View File

@@ -423,9 +423,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
) )
# Phi doesn't want the attention_mask feature when training # Phi doesn't want the attention_mask feature when training
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or ( if "CodeGenTokenizer" in tokenizer.__class__.__name__:
cfg.is_mistral_derived_model and cfg.flash_attention
):
train_dataset = train_dataset.remove_columns("attention_mask") train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask") eval_dataset = eval_dataset.remove_columns("attention_mask")

View File

@@ -24,10 +24,6 @@ class TestLoraLlama(unittest.TestCase):
""" """
def test_lora(self): def test_lora(self):
"""
support for legacy lora_ configs
:return:
"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp() output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
@@ -70,101 +66,6 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists() assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_peft(self):
"""
support for legacy lora_ configs
:return:
"""
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"peft_r": 32,
"peft_alpha": 64,
"peft_dropout": 0.05,
"peft_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_ia3_peft(self):
"""
support for IA3 peft
:return:
"""
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "ia3",
"peft_r": 32,
"peft_alpha": 64,
"peft_dropout": 0.05,
"peft_target_modules": ["k_proj", "v_proj", "down_proj"],
"peft_feedforward_modules": ["down_proj"],
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_packing(self): def test_lora_packing(self):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp() output_dir = tempfile.mkdtemp()

View File

@@ -1,48 +0,0 @@
"""Module for testing the validation module"""
import logging
import unittest
from typing import Optional
import pytest
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
class NormalizationTest(unittest.TestCase):
"""
Test the cfg normalization module
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
def test_lora_to_peft(self):
base_cfg = DictDefault(
{
"gradient_accumulation_steps": 1,
"micro_batch_size": 1,
"base_model": "NousResearch/Llama-2-7b-hf",
"base_model_config": "NousResearch/Llama-2-7b-hf",
}
)
cfg = base_cfg | DictDefault(
{
"adapter": "lora",
"lora_r": 128,
"lora_alpha": 64,
}
)
with self._caplog.at_level(logging.WARNING):
normalize_config(cfg)
assert any(
"soon to be deprecated. please use peft_" in record.message
for record in self._caplog.records
)
assert cfg.peft_r == 128
assert cfg.peft_alpha == 64