Phi2 rewrite (#1058)
* restore to current phi modeling code from phi-2 * enable gradient checkpointing * don't cast everything to float32 all the time * gradient checkpointing for phi2 ParallelBlock module too * fix enabling flash attn for phi2 * add comment about import * fix phi2 example * fix model type check for tokenizer * revert float32 -> bf16 casting changes * support fused dense flash attn * fix the repo for flash-attn * add package name for subdir pkg * fix the data collator when not using sample packing * install packaging for pytests in ci * also fix setup to not install flash attn fused dense subdir if not extras * split out the fused-dense-lib in extra requires * don't train w group_by_length for phi * update integration test to use phi2 * set max steps and save steps for phi e2e tests * try to workaround ssave issue in ci * skip phi2 e2e test for now
This commit is contained in:
73
examples/phi/phi2-ft.yml
Normal file
73
examples/phi/phi2-ft.yml
Normal file
@@ -0,0 +1,73 @@
|
||||
base_model: microsoft/phi-2
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: garage-bAInd/Open-Platypus
|
||||
type: alpaca
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./phi-sft-out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false # currently unsupported
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.1
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_modules_to_save:
|
||||
- embd
|
||||
- lm_head
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_8bit
|
||||
adam_beta2: 0.95
|
||||
adam_epsilon: 0.00001
|
||||
max_grad_norm: 1.0
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 1e-5
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
resize_token_embeddings_to_32x: true
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
@@ -12,6 +12,7 @@ fire
|
||||
PyYAML>=6.0
|
||||
datasets>=2.15.0
|
||||
flash-attn==2.3.3
|
||||
fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
|
||||
4
setup.py
4
setup.py
@@ -17,6 +17,7 @@ def parse_requirements():
|
||||
_dependency_links.append(url)
|
||||
elif (
|
||||
"flash-attn" not in line
|
||||
and "flash-attention" not in line
|
||||
and "deepspeed" not in line
|
||||
and line
|
||||
and line[0] != "#"
|
||||
@@ -51,6 +52,9 @@ setup(
|
||||
"flash-attn": [
|
||||
"flash-attn==2.3.3",
|
||||
],
|
||||
"fused-dense-lib": [
|
||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed",
|
||||
],
|
||||
|
||||
@@ -34,6 +34,7 @@ from axolotl.utils.callbacks import (
|
||||
)
|
||||
from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
MambaDataCollator,
|
||||
)
|
||||
from axolotl.utils.samplers import MultipackBatchSampler
|
||||
@@ -843,7 +844,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||
|
||||
return BatchSamplerDataCollatorForSeq2Seq(
|
||||
if training_args.sample_packing:
|
||||
return BatchSamplerDataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**kwargs,
|
||||
|
||||
@@ -9,27 +9,32 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from .configuration_phi import PhiConfig
|
||||
|
||||
try:
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
||||
from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
|
||||
from flash_attn.ops.fused_dense import FusedDense
|
||||
except: # noqa: E722
|
||||
except ImportError:
|
||||
pad_input, unpad_input = None, None
|
||||
FlashRotaryEmbedding = None
|
||||
FlashSelfAttention, FlashCrossAttention = None, None
|
||||
|
||||
# this is in a seperate try/except block since sometimes fused_dense isn't available
|
||||
# and it shouldn't completely disable flash attn when it isn't
|
||||
try:
|
||||
from flash_attn.ops.fused_dense import FusedDense
|
||||
except ImportError:
|
||||
FusedDense = None
|
||||
|
||||
|
||||
@@ -224,7 +229,9 @@ class RotaryEmbedding(nn.Module):
|
||||
|
||||
# Initialize cached attributes since ONNX can't rely on dynamic initialization
|
||||
self._update_cos_sin_cache(
|
||||
max_position_embeddings, device=device, dtype=torch.float32
|
||||
max_position_embeddings,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
|
||||
@@ -281,34 +288,32 @@ class RotaryEmbedding(nn.Module):
|
||||
seqlen_offset: int = 0,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
seq_start = seqlen_offset
|
||||
seq_end = seq_start + qkv.shape[1]
|
||||
|
||||
if (
|
||||
self._cos_cached.device != qkv.device
|
||||
self._seq_len_cached < qkv.shape[1] + seqlen_offset
|
||||
or self._cos_cached.device != qkv.device
|
||||
or self._cos_cached.dtype != qkv.dtype
|
||||
or (self.training and self._cos_cached.is_inference())
|
||||
):
|
||||
self._update_cos_sin_cache(
|
||||
self.max_position_embeddings, device=qkv.device, dtype=qkv.dtype
|
||||
qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype
|
||||
)
|
||||
|
||||
if kv is None:
|
||||
return _apply_rotary_emb_qkv(
|
||||
qkv,
|
||||
self._cos_cached[seq_start:seq_end],
|
||||
self._sin_cached[seq_start:seq_end],
|
||||
self._cos_cached[seqlen_offset:],
|
||||
self._sin_cached[seqlen_offset:],
|
||||
)
|
||||
else:
|
||||
q = _apply_rotary_emb(
|
||||
qkv,
|
||||
self._cos_cached[seq_start:seq_end],
|
||||
self._sin_cached[seq_start:seq_end],
|
||||
self._cos_cached[seqlen_offset:],
|
||||
self._sin_cached[seqlen_offset:],
|
||||
)
|
||||
kv = _apply_rotary_emb_kv(
|
||||
kv,
|
||||
self._cos_cached[seq_start:seq_end],
|
||||
self._sin_cached[seq_start:seq_end],
|
||||
self._cos_cached[seqlen_offset:],
|
||||
self._sin_cached[seqlen_offset:],
|
||||
)
|
||||
|
||||
return q, kv
|
||||
@@ -511,7 +516,7 @@ def _update_kv_cache(
|
||||
num_heads, head_dim = kv.shape[-2:]
|
||||
|
||||
if layer_idx not in inference_params.key_value_memory_dict:
|
||||
kv_cache = torch.empty(
|
||||
inference_params.key_value_memory_dict[layer_idx] = torch.empty(
|
||||
inference_params.max_batch_size,
|
||||
inference_params.max_seqlen,
|
||||
2,
|
||||
@@ -520,9 +525,6 @@ def _update_kv_cache(
|
||||
dtype=kv.dtype,
|
||||
device=kv.device,
|
||||
)
|
||||
inference_params.key_value_memory_dict[layer_idx] = kv_cache
|
||||
else:
|
||||
kv_cache = inference_params.key_value_memory_dict[layer_idx]
|
||||
|
||||
batch_start = inference_params.batch_size_offset
|
||||
batch_end = batch_start + kv.shape[0]
|
||||
@@ -530,8 +532,19 @@ def _update_kv_cache(
|
||||
sequence_start = inference_params.seqlen_offset
|
||||
sequence_end = sequence_start + kv.shape[1]
|
||||
|
||||
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
||||
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
|
||||
# When the current sequence length is equal to or larger than the maximum sequence length,
|
||||
# we need to concatenate the current `kv` with the cached `kv` to expand its length
|
||||
if sequence_end >= inference_params.max_seqlen:
|
||||
inference_params.key_value_memory_dict[layer_idx] = torch.concatenate(
|
||||
(inference_params.key_value_memory_dict[layer_idx], kv), dim=1
|
||||
)
|
||||
|
||||
inference_params.key_value_memory_dict[layer_idx][
|
||||
batch_start:batch_end, sequence_start:sequence_end, ...
|
||||
] = kv
|
||||
kv = inference_params.key_value_memory_dict[layer_idx][
|
||||
batch_start:batch_end, :sequence_end, ...
|
||||
]
|
||||
|
||||
return kv
|
||||
|
||||
@@ -624,13 +637,10 @@ class MHA(nn.Module):
|
||||
self.layer_idx = layer_idx
|
||||
self.return_residual = return_residual
|
||||
self.checkpointing = checkpointing
|
||||
self._gradient_checkpointing_func = None
|
||||
|
||||
def _forward_self_attn(
|
||||
self,
|
||||
x: torch.FloatTensor,
|
||||
key_padding_mask: Optional[torch.BoolTensor],
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
|
||||
) -> torch.FloatTensor:
|
||||
qkv = self.Wqkv(x)
|
||||
qkv = rearrange(
|
||||
@@ -643,20 +653,21 @@ class MHA(nn.Module):
|
||||
if self.flash_attn:
|
||||
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
||||
|
||||
if (
|
||||
key_padding_mask is not None
|
||||
and cu_seqlens is None
|
||||
and max_seqlen is None
|
||||
):
|
||||
cu_seqlens, max_seqlen = None, None
|
||||
if key_padding_mask is not None:
|
||||
# If `key_padding_mask` is supplied, we need to unpad the input and retrieve
|
||||
# the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
|
||||
qkv, indices, cu_seqlens, max_seqlen = unpad_input(
|
||||
qkv, key_padding_mask
|
||||
)
|
||||
|
||||
if self.checkpointing:
|
||||
attn_output = torch.utils.checkpoint.checkpoint(
|
||||
self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
|
||||
if self.checkpointing and self.training:
|
||||
attn_output = self._gradient_checkpointing_func(
|
||||
self.inner_attn,
|
||||
qkv,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
attn_output = self.inner_attn(
|
||||
@@ -670,9 +681,12 @@ class MHA(nn.Module):
|
||||
else attn_output
|
||||
)
|
||||
|
||||
if self.checkpointing:
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
self.inner_attn, qkv, key_padding_mask=key_padding_mask
|
||||
if self.checkpointing and self.training:
|
||||
return self._gradient_checkpointing_func(
|
||||
self.inner_attn,
|
||||
qkv,
|
||||
key_padding_mask=key_padding_mask,
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
|
||||
@@ -725,8 +739,8 @@ class MHA(nn.Module):
|
||||
q, key_padding_mask
|
||||
)
|
||||
|
||||
if self.checkpointing:
|
||||
attn_output = torch.utils.checkpoint.checkpoint(
|
||||
if self.checkpointing and self.training:
|
||||
attn_output = self._gradient_checkpointing_func(
|
||||
self.inner_cross_attn,
|
||||
q,
|
||||
kv,
|
||||
@@ -735,6 +749,7 @@ class MHA(nn.Module):
|
||||
max_seqlen=max_seqlen_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
attn_output = self.inner_cross_attn(
|
||||
@@ -753,13 +768,14 @@ class MHA(nn.Module):
|
||||
else attn_output
|
||||
)
|
||||
|
||||
if self.checkpointing:
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
if self.checkpointing and self.training:
|
||||
return self._gradient_checkpointing_func(
|
||||
self.inner_cross_attn,
|
||||
q,
|
||||
kv,
|
||||
key_padding_mask=key_padding_mask,
|
||||
causal=causal,
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
return self.inner_cross_attn(
|
||||
@@ -771,11 +787,8 @@ class MHA(nn.Module):
|
||||
x: torch.FloatTensor,
|
||||
past_key_values: Optional[InferenceParams] = None,
|
||||
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
# TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool())
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.bool()
|
||||
else:
|
||||
@@ -785,18 +798,12 @@ class MHA(nn.Module):
|
||||
if self.n_head == self.n_head_kv:
|
||||
if past_key_values is None:
|
||||
# If `past_key_values` are not supplied, we run self-attention
|
||||
attn_output = self._forward_self_attn(
|
||||
x, attention_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
|
||||
)
|
||||
attn_output = self._forward_self_attn(x, attention_mask)
|
||||
else:
|
||||
# If `past_key_values` are supplied, it means that we might have cached values and
|
||||
# could take advantage of cross-attention
|
||||
attn_output = self._forward_cross_attn(
|
||||
x,
|
||||
past_key_values,
|
||||
attention_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
x, past_key_values, attention_mask
|
||||
)
|
||||
# MQA / GQA
|
||||
else:
|
||||
@@ -830,6 +837,8 @@ class ParallelBlock(nn.Module):
|
||||
|
||||
self.mixer = MHA(config, layer_idx=block_idx)
|
||||
self.mlp = MLP(config)
|
||||
self.checkpointing = False
|
||||
self._gradient_checkpointing_func = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -838,23 +847,52 @@ class ParallelBlock(nn.Module):
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln(hidden_states)
|
||||
|
||||
attn_outputs = self.mixer(
|
||||
def _forward(
|
||||
mixer,
|
||||
resid_dropout,
|
||||
mlp,
|
||||
ln,
|
||||
hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values,
|
||||
attention_mask,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = ln(hidden_states)
|
||||
|
||||
attn_outputs = mixer(
|
||||
hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
if isinstance(attn_outputs, tuple):
|
||||
attn_outputs = attn_outputs[0]
|
||||
|
||||
attn_outputs = resid_dropout(attn_outputs)
|
||||
feed_forward_hidden_states = resid_dropout(mlp(hidden_states))
|
||||
|
||||
return attn_outputs + feed_forward_hidden_states + residual
|
||||
|
||||
if self.training and self.checkpointing:
|
||||
return self._gradient_checkpointing_func(
|
||||
_forward,
|
||||
self.mixer,
|
||||
self.resid_dropout,
|
||||
self.mlp,
|
||||
self.ln,
|
||||
hidden_states,
|
||||
past_key_values,
|
||||
attention_mask,
|
||||
)
|
||||
|
||||
return _forward(
|
||||
self.mixer,
|
||||
self.resid_dropout,
|
||||
self.mlp,
|
||||
self.ln,
|
||||
hidden_states,
|
||||
past_key_values,
|
||||
attention_mask,
|
||||
)
|
||||
if isinstance(attn_outputs, tuple):
|
||||
attn_outputs = attn_outputs[0]
|
||||
|
||||
attn_outputs = self.resid_dropout(attn_outputs)
|
||||
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
|
||||
|
||||
hidden_states = attn_outputs + feed_forward_hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CausalLMHead(nn.Module):
|
||||
@@ -911,7 +949,7 @@ class PhiPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = PhiConfig
|
||||
base_model_prefix = "transformer"
|
||||
supports_gradient_checkpointing = False
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ParallelBlock"]
|
||||
|
||||
def __init__(self, *inputs, **kwargs) -> None:
|
||||
@@ -931,6 +969,14 @@ class PhiPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(
|
||||
self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint
|
||||
):
|
||||
for module in self.modules():
|
||||
if hasattr(module, "checkpointing"):
|
||||
module._gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.checkpointing = enable
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
@@ -951,7 +997,7 @@ class PhiPreTrainedModel(PreTrainedModel):
|
||||
)
|
||||
else:
|
||||
# Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
|
||||
past_key_values.seqlen_offset = len(input_ids[0]) - 1
|
||||
past_key_values.seqlen_offset = input_ids.shape[1] - 1
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
return {
|
||||
@@ -988,8 +1034,6 @@ class PhiModel(PhiPreTrainedModel):
|
||||
input_ids: torch.LongTensor,
|
||||
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states = self.embd(input_ids)
|
||||
|
||||
@@ -998,8 +1042,6 @@ class PhiModel(PhiPreTrainedModel):
|
||||
hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
@@ -1034,23 +1076,10 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
||||
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> CausalLMOutputWithPast:
|
||||
cu_seqlens: Optional[torch.LongTensor] = None
|
||||
max_seqlen: Optional[int] = None
|
||||
if position_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
hidden_states = self.transformer(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
input_ids, past_key_values=past_key_values, attention_mask=attention_mask
|
||||
)
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
|
||||
@@ -55,6 +55,8 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
||||
|
||||
def load_model_config(cfg):
|
||||
model_config_name = cfg.base_model_config or cfg.base_model
|
||||
if not model_config_name and cfg.tokenizer_config:
|
||||
model_config_name = cfg.tokenizer_config
|
||||
trust_remote_code = cfg.trust_remote_code is True
|
||||
|
||||
try:
|
||||
@@ -80,6 +82,7 @@ def load_model_config(cfg):
|
||||
|
||||
|
||||
def load_tokenizer(cfg):
|
||||
model_config = load_model_config(cfg)
|
||||
tokenizer_kwargs = {}
|
||||
use_fast = True # this is the default
|
||||
|
||||
@@ -139,6 +142,7 @@ def load_tokenizer(cfg):
|
||||
for k, val in cfg.special_tokens.items():
|
||||
# check if new special token is not already in tokenizer and
|
||||
# is adapter training to make sure lora_modules_to_save is set
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
if (
|
||||
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
|
||||
and cfg.adapter
|
||||
@@ -149,6 +153,7 @@ def load_tokenizer(cfg):
|
||||
for x in ["embed_tokens", "lm_head"]
|
||||
)
|
||||
)
|
||||
and (model_config.model_type in ["llama", "mistral", "mixtral"])
|
||||
):
|
||||
raise ValueError(
|
||||
"Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
|
||||
@@ -386,6 +391,10 @@ def load_model(
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"eager"
|
||||
)
|
||||
if model_config.model_type == "phi-msft":
|
||||
model_config.flash_attn = True
|
||||
model_config.flash_rotary = True
|
||||
model_config.fused_dense = True
|
||||
|
||||
try:
|
||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||
@@ -438,11 +447,12 @@ def load_model(
|
||||
# device=cfg.device,
|
||||
# )
|
||||
# model.train() # sets to train instead of eval mode
|
||||
elif model_type == "PhiForCausalLM":
|
||||
elif model_type == "PhiForCausalLM" or model_config.model_type == "phi-msft":
|
||||
from axolotl.models.phi import PhiForCausalLM
|
||||
|
||||
model = PhiForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
**model_kwargs,
|
||||
|
||||
@@ -7,6 +7,8 @@ import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
@@ -21,17 +23,18 @@ os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
class TestPhi(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using LoRA
|
||||
Test case for Phi2 models
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason="fixme later")
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
def test_phi2_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
"base_model": "microsoft/phi-2",
|
||||
"trust_remote_code": True,
|
||||
"model_type": "PhiForCausalLM",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 512,
|
||||
"sample_packing": False,
|
||||
@@ -39,9 +42,6 @@ class TestPhi(unittest.TestCase):
|
||||
"adapter": None,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<|endoftext|>",
|
||||
"bos_token": "<|endoftext|>",
|
||||
"eos_token": "<|endoftext|>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
@@ -57,9 +57,14 @@ class TestPhi(unittest.TestCase):
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"optimizer": "paged_adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"bf16": True,
|
||||
"flash_attention": True,
|
||||
"max_steps": 10,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
@@ -69,12 +74,13 @@ class TestPhi(unittest.TestCase):
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
@pytest.mark.skip(reason="multipack no longer supported atm")
|
||||
@with_temp_dir
|
||||
def test_ft_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
"base_model": "microsoft/phi-2",
|
||||
"trust_remote_code": True,
|
||||
"model_type": "PhiForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
|
||||
Reference in New Issue
Block a user