Compare commits: unsloth_mo… → mixtral_sw…

13 commits:
- a58a9e5f6c
- 279a1401b5
- 083beb6425
- 2ac1a72e4b
- 23103ac5ac
- 86487c2e96
- 35f9b0f149
- 68b227a7d8
- 03c6318ba3
- 40a6362c92
- d339beb9d9
- fde091cb12
- 06ae39200b

.github/workflows/tests.yml (vendored, 2 changed lines)

@@ -73,7 +73,7 @@ jobs:
       run: |
         pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
         pip3 uninstall -y transformers accelerate
-        pip3 install -U -e .[flash-attn]
+        pip3 install -U -e .[flash-attn,mamba-ssm]
         pip3 install -r requirements-tests.txt

       - name: Run e2e tests

mypy configuration (3 added lines)

@@ -8,6 +8,9 @@ ignore_missing_imports = True
 [mypy-axolotl.monkeypatch.*]
 ignore_errors = True

+[mypy-axolotl.models.mixtral.*]
+ignore_errors = True
+
 [mypy-axolotl.models.phi.*]
 ignore_errors = True


README.md (30 changed lines)

@@ -65,19 +65,21 @@ Features:

 ## Axolotl supports

 | | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
-|----------|:----------|:-----|-------|------|-------------------|------------|--------------|
+|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
 | llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
 | Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
 | cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
 | btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
 | mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
 | falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
 | gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
 | XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
 | phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
 | RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
 | Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |


 ## Quickstart ⚡

@@ -245,7 +247,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
   ```json
   {"instruction": "...", "input": "...", "output": "..."}
   ```
-- `sharegpt`: conversations where `from` is `human`/`gpt`
+- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
   ```json
   {"conversations": [{"from": "...", "value": "..."}]}
   ```

Dockerfile for the RunPod image (1 added line)

@@ -4,6 +4,7 @@ FROM winglian/axolotl:$BASE_TAG
 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
 ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
 ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
+ENV HF_HOME="/workspace/data/huggingface-cache/hub"

 COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh


examples/mamba/config.yml (new file, 61 lines)

base_model: state-spaces/mamba-2.8b
model_type: MambaLMHeadModel
tokenizer_type: AutoTokenizer
tokenizer_config: EleutherAI/gpt-neox-20b

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./out

sequence_len: 2048
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 5e-5

train_on_inputs: false
group_by_length: true

bf16: true
fp16: false
tf32: true

gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:

warmup_steps: 10
eval_steps:
eval_table_size:
eval_table_max_new_tokens: 128
save_steps: 0.25
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
tokens:
save_safetensors: False
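
This config does a full fine-tune of the 2.8B Mamba checkpoint; `sample_packing`, `pad_to_sequence_len`, and `flash_attention` are left off because the Mamba path drops the `attention_mask` column entirely (see the trainer utilities change below). It should launch the same way as the other example configs, e.g. `accelerate launch -m axolotl.cli.train examples/mamba/config.yml`.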

examples/mistral/mixtral.yml (new file, 79 lines)

base_model: DiscoResearch/mixtral-7b-8expert
model_type: MixtralForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
#lora_target_modules:
#  - gate
#  - q_proj
#  - k_proj
#  - v_proj
#  - o_proj
#  - w1
#  - w2
#  - w3

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
eval_steps:
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed: deepspeed/zero2.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
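
The Mixtral recipe is a 4-bit QLoRA run with sample packing and DeepSpeed ZeRO-2. With `lora_target_linear: true` the adapters are attached to every linear layer automatically; the commented `lora_target_modules` list shows the per-projection alternative (router `gate`, attention `q/k/v/o_proj`, and expert `w1`/`w2`/`w3`) if you prefer to pin the targets explicitly.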

requirements file (1 changed line)

@@ -2,7 +2,7 @@
 auto-gptq==0.5.1
 packaging
 peft==0.6.0
-transformers==4.35.2
+transformers @ git+https://github.com/huggingface/transformers.git@df5c5c62ae253055336f5bb0828ca8e3e15ab6bd
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.24.1

setup.py (5 changed lines)

@@ -46,10 +46,13 @@ setup(
     dependency_links=dependency_links,
     extras_require={
         "flash-attn": [
-            "flash-attn>=2.3.0",
+            "flash-attn==2.3.3",
         ],
         "deepspeed": [
             "deepspeed",
         ],
+        "mamba-ssm": [
+            "mamba-ssm==1.0.1",
+        ],
     },
 )
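
The new `mamba-ssm` extra pins `mamba-ssm==1.0.1`, so the Mamba kernels can be pulled in alongside flash-attn with the same editable install the CI job above now uses: `pip3 install -U -e .[flash-attn,mamba-ssm]`.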

trainer builder (AxolotlTrainingArguments, AxolotlTrainer, HFCausalTrainerBuilder)

@@ -31,7 +31,10 @@ from axolotl.utils.callbacks import (
     bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
-from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.collators import (
+    BatchSamplerDataCollatorForSeq2Seq,
+    MambaDataCollator,
+)
 from axolotl.utils.samplers import MultipackBatchSampler
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

@@ -49,6 +52,9 @@ class AxolotlTrainingArguments(TrainingArguments):
     Extend the base TrainingArguments for axolotl helpers
     """

+    model_type: Optional[str] = field(
+        default=None, metadata={"help": "HF model configuration model_type."}
+    )
     lr_quadratic_warmup: bool = field(
         default=False,
         metadata={"help": "Use quadratic warmup for cosine scheduling."},
@@ -285,6 +291,32 @@ class AxolotlTrainer(Trainer):
         return super().compute_loss(model, inputs, return_outputs=return_outputs)


+class AxolotlMambaTrainer(AxolotlTrainer):
+    """
+    Mamba specific trainer to handle loss calculation
+    """
+
+    def compute_loss(
+        self,
+        model,
+        inputs,
+        return_outputs=False,  # pylint: disable=unused-argument
+    ):
+        input_ids = inputs.pop("input_ids")
+        lm_logits = model(input_ids).logits
+
+        labels = input_ids.to(lm_logits.device)
+        shift_logits = lm_logits[:, :-1, :].contiguous()
+        labels = labels[:, 1:].contiguous()
+
+        loss_fct = torch.nn.CrossEntropyLoss()
+        lm_loss = loss_fct(
+            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
+        )
+
+        return lm_loss
+
+
 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
     Trainer subclass that uses the OneCycleLR scheduler
@@ -462,6 +494,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             return OneCycleLRSchedulerTrainer
         if self.cfg.relora_steps:
             return ReLoRATrainer
+        if self.cfg.model_config_type == "mamba":
+            return AxolotlMambaTrainer
         return AxolotlTrainer

     def build(self, total_num_steps):
@@ -529,7 +563,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.hub_strategy:
             training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy

-        if self.cfg.save_safetensors:
+        if self.cfg.save_safetensors is not None:
             training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors

         if self.cfg.sample_packing_eff_est:
@@ -677,6 +711,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
+        training_arguments_kwargs["model_type"] = self.cfg.model_config_type
         training_args = (
             AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
                 **training_arguments_kwargs,
@@ -731,11 +766,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
-            data_collator=BatchSamplerDataCollatorForSeq2Seq(
-                self.tokenizer,
-                return_tensors="pt",
-                **data_collator_kwargs,
-            ),
+            data_collator=self.build_collator(**data_collator_kwargs),
             bench_data_collator=transformers.DataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
@@ -755,3 +786,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ] = self.cfg.micro_batch_size

         return trainer
+
+    def build_collator(self, **kwargs):
+        if self.cfg.model_config_type == "mamba":
+            return MambaDataCollator(tokenizer=self.tokenizer)
+
+        return BatchSamplerDataCollatorForSeq2Seq(
+            self.tokenizer,
+            return_tensors="pt",
+            **kwargs,
+        )
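
Because the Mamba model's forward pass returns only logits (see modeling_mamba.py below), `AxolotlMambaTrainer.compute_loss` builds the causal-LM loss itself by shifting the inputs one position. A self-contained sketch of that shift-by-one cross entropy, using made-up toy shapes (batch 2, sequence 5, vocabulary 11):

```python
import torch

# Toy tensors standing in for a model's logits and the batch's input_ids.
lm_logits = torch.randn(2, 5, 11)         # (batch, seq_len, vocab)
input_ids = torch.randint(0, 11, (2, 5))  # (batch, seq_len)

# Same scheme as AxolotlMambaTrainer.compute_loss: position t predicts token t+1,
# so drop the last logit and the first label before flattening.
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = input_ids[:, 1:].contiguous()

loss = torch.nn.CrossEntropyLoss()(
    shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
print(loss)
```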

src/axolotl/models/mamba/__init__.py (new file, 12 lines)

"""
Modeling module for Mamba models
"""


def fix_mamba_attn_for_loss():
    from mamba_ssm.models import mixer_seq_simple

    from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed

    mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
    return mixer_seq_simple.MambaLMHeadModel  # pylint: disable=invalid-name

src/axolotl/models/mamba/configuration_mamba.py (new file, 42 lines)

"""
HF Transformers MambaConfig
"""
from transformers import PretrainedConfig


class MambaConfig(PretrainedConfig):
    """
    modeling configuration for state space model/mamba
    """

    model_type = "mamba"

    def __init__(
        self,
        vocab_size=50280,
        d_model=2560,
        n_layer=64,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        pad_vocab_size_multiple=8,
        pad_token_id=50277,
        bos_token_id=0,
        eos_token_id=0,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layer = n_layer
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
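
A quick sketch of instantiating the config; the defaults above appear to describe the 2.8B checkpoint, so the smaller sizes here are only to show the call shape:

```python
from axolotl.models.mamba.configuration_mamba import MambaConfig

# Smaller-than-default sizes purely for illustration.
config = MambaConfig(d_model=768, n_layer=24, vocab_size=50280)
print(config.model_type, config.d_model, config.pad_vocab_size_multiple)  # mamba 768 8
```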

src/axolotl/models/mamba/modeling_mamba.py (new file, 128 lines)

# pylint: skip-file
import os
from collections import namedtuple
from functools import partial
from typing import Optional, Union

import torch
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
from torch import nn
from torch.nn import CrossEntropyLoss

from axolotl.models.mamba.configuration_mamba import MambaConfig


class MambaLMHeadModel(nn.Module, GenerationMixin):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        vocab_size: int,
        initializer_cfg=None,
        pad_vocab_size_multiple: int = 1,
        device=None,
        dtype=None,
        **backbone_kwargs,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.config = MambaConfig(
            vocab_size=vocab_size,
            d_model=d_model,
            n_layer=n_layer,
            pad_vocab_size_multiple=pad_vocab_size_multiple,
        )
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            initializer_cfg=initializer_cfg,
            **backbone_kwargs,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        labels=None,
        **kwargs,
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(input_ids, inference_params=inference_params)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)

        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

        loss = None
        if labels is not None:
            logits = lm_logits
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
            print(loss)
            return CausalLMOutput(logits=lm_logits, loss=loss)

        else:
            CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
            return CausalLMOutput(logits=lm_logits)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        state_dict: Optional[dict] = None,
        safe_serialization: Optional[bool] = None,  # pylint: disable=unused-argument
    ):
        if state_dict is None:
            state_dict = self.state_dict()
        torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config = load_config_hf(pretrained_model_name)
        model = cls(**config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(
            load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
        )
        return model
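
Note that `forward` returns the logits-only `CausalLMOutput` before the label branch is reached, so the loss code below that return never runs; in practice the training loss for Mamba comes from `AxolotlMambaTrainer.compute_loss` in the trainer builder above.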

src/axolotl/models/mixtral/__init__.py (new file, 9 lines)

"""
Custom modeling code for mixtral
"""

from .configuration_moe_mistral import MixtralConfig  # noqa
from .modeling_moe_mistral import (  # noqa
    MixtralForCausalLM,
    replace_mixtral_mlp_with_swiglu,
)

src/axolotl/models/mixtral/configuration_moe_mistral.py (new file, 154 lines)

# coding=utf-8
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mistral model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)

MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
    "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
}


class MixtralConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an
    Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.

    [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
    [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`MistralModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 14336):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
            The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
            allows sequence of up to 4096*32 tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention window size. If not specified, will default to `4096`.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import MistralModel, MistralConfig

    >>> # Initializing a Mistral 7B style configuration
    >>> configuration = MixtralConfig()

    >>> # Initializing a model from the Mistral 7B style configuration
    >>> model = MixtralModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "mistral"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        attention_dropout=0.0,
        num_experts_per_token=2,
        num_experts=8,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.num_experts = num_experts
        self.num_experts_per_token = num_experts_per_token

        # pylint: disable=duplicate-code
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
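
The MoE-specific knobs are `num_experts` (8) and `num_experts_per_token` (2) on top of the stock Mistral fields. A minimal sketch of building the config, shrinking the dense sizes only so the object is cheap to construct:

```python
from axolotl.models.mixtral.configuration_moe_mistral import MixtralConfig

# Illustrative sizes; the real DiscoResearch/mixtral-7b-8expert config is far larger.
config = MixtralConfig(hidden_size=512, intermediate_size=1024, num_hidden_layers=4)
print(config.num_experts, config.num_experts_per_token)  # 8 2
```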

src/axolotl/models/mixtral/modeling_moe_mistral.py (new file, 1561 lines)

File diff suppressed because it is too large.

Deleted file (168 lines): the Unsloth-adapted Triton cross-entropy kernel (`fast_cross_entropy_loss`)

# Adapted from Unsloth
# https://github.com/unslothai/unsloth/blob/4b97a810b509c93f44be4c037c7aa18fb8922884/unsloth/kernels/cross_entropy_loss.py

import triton
import triton.language as tl
import torch

MAX_FUSED_SIZE = 65536

def calculate_settings(n):
    BLOCK_SIZE = triton.next_power_of_2(n)
    # CUDA only supports 65536 - 2^16 threads per block
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
    num_warps = 4
    if BLOCK_SIZE >= 32768: num_warps = 32
    elif BLOCK_SIZE >= 8192: num_warps = 16
    elif BLOCK_SIZE >= 2048: num_warps = 8
    return BLOCK_SIZE, num_warps
pass

@triton.jit
def _cross_entropy_forward(logits_ptr, logits_row_stride,
                           loss_ptr,
                           lse_ptr,
                           labels_ptr,
                           n_cols,
                           BLOCK_SIZE: tl.constexpr,):
    """
    Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
    Pi = exp(xi) / sum(exp(xi))
    CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
         = -y [ x - log[sum(exp(x))] ]
         = y * (log[sum(exp(x))] - x)
    If y == 0: CE_i = 0
    If y == 1: CE_i = logsumexp - x
    """
    row_idx = tl.program_id(0)
    logits_ptr += row_idx * logits_row_stride
    loss_ptr += row_idx
    lse_ptr += row_idx
    labels_ptr += row_idx

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # TODO: Fixup int32 locations to int64
    label_idx = tl.load(labels_ptr).to(tl.int32)
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
    max_logits = tl.max(logits, 0)
    # Maximum stops overflow
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    tl.store(lse_ptr, lse)

    if label_idx != -100:
        logits_label = tl.load(logits_ptr + label_idx).to(tl.float32)
        loss = lse - logits_label
    else:
        loss = 0.0
    tl.store(loss_ptr, loss)
pass


@triton.jit
def _cross_entropy_backward(logits_ptr, logits_row_stride,
                            dloss_ptr, dloss_row_stride,
                            lse_ptr,
                            labels_ptr,
                            n_cols,
                            BLOCK_SIZE: tl.constexpr,):
    """
    CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
    dC/dx = d/dx (y * log[sum(exp(x))] - x * y)

    From https://en.wikipedia.org/wiki/LogSumExp
    d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)

    dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
    dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
    dC/dx = y * exp[x - logsumexp] - d/dx (x * y)

    If y == 0: dC/dx = 0
    If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
    If y == 1 and x != label: dC/dx = exp[x - logsumexp]
    """
    row_idx = tl.program_id(0)
    logits_ptr += row_idx * logits_row_stride
    dloss_ptr += row_idx * dloss_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # TODO: Fixup int32 locations to int64
    label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)

    if label_idx != -100:
        dloss = tl.load(dloss_ptr)
    else:
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = 0).to(tl.float32)
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)

    probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(logits_ptr + col_offsets, dloss * probs, mask = mask)



class CrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels):
        n_rows, n_cols = logits.shape
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
        logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")

        _cross_entropy_forward[(n_rows,)](
            logits, logits.stride(0),
            losses,
            logsumexp,
            labels,
            n_cols,
            BLOCK_SIZE = BLOCK_SIZE,
            num_warps = num_warps,
        )

        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.save_for_backward(logits, logsumexp, labels)
        return losses
    pass

    @staticmethod
    def backward(ctx, dlosses):
        logits, logsumexp, labels = ctx.saved_tensors
        n_rows, n_cols = logits.shape

        _cross_entropy_backward[(n_rows,)](
            logits, logits.stride(0),
            dlosses, dlosses.stride(0),
            logsumexp,
            labels,
            n_cols,
            BLOCK_SIZE = ctx.BLOCK_SIZE,
            num_warps = ctx.num_warps,
        )
        return logits, None, None,
    pass
pass


def fast_cross_entropy_loss(logits, labels):
    """
    Arguments:
        logits: (batch, seq_len, vocab_size)
        labels: (batch, seq_len,)
    Returns:
        losses: float
    """
    batch, seq_len, d = logits.shape
    assert(labels.shape == (batch, seq_len))

    loss = CrossEntropyLoss.apply(
        logits.view(batch*seq_len, d),
        labels.view(-1),
    )
    n_items = torch.count_nonzero(labels != -100)
    return loss.sum() / n_items
pass

Mistral flash-attention monkeypatch (axolotl.monkeypatch.mistral)

@@ -13,20 +13,16 @@ from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
     flash_attn_varlen_kvpacked_func,
     flash_attn_varlen_qkvpacked_func,
 )
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.mistral.modeling_mistral import (
     MistralAttention as OriginalMistralAttention,
 )
 from transformers.models.mistral.modeling_mistral import (
     MistralDecoderLayer as OriginalMistralDecoderLayer,
 )
-from transformers.models.mistral.modeling_mistral import (
-    MistralForCausalLM as OriginalMistralForCausalLM,
-)
 from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv

 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
-from axolotl.monkeypatch.cross_entropy import fast_cross_entropy_loss

 LOG = logging.getLogger("axolotl.monkeypatch.mistral")

@@ -40,9 +36,6 @@ def replace_mistral_attn_with_flash_attn(
     transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
         flashattn_forward
     )
-    transformers.models.mistral.modeling_mistral.MistralForCausalLM.forward = (
-        mistral_causallm_forward
-    )
     if packed:
         transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
             MistralDecoderLayer
@@ -648,71 +641,3 @@ class MistralDecoderLayer(OriginalMistralDecoderLayer):
             outputs += (present_key_value,)

         return outputs

The `mistral_causallm_forward` override that followed (68 lines) is removed in full:

def mistral_causallm_forward(
    self: OriginalMistralForCausalLM,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    *args, **kwargs
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        shift_logits = logits
        if not hasattr(self, "extra_ignored_labels"):
            self.extra_ignored_labels = torch.full((self.model.config.max_position_embeddings, 1), -100, device=shift_logits.device)

        shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
        shift_labels = shift_labels.to(shift_logits.device)

        # FAST CROSS ENTROPY
        loss = fast_cross_entropy_loss(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
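
With this removal the Mistral patch only swaps in the flash-attention forward (and the packed decoder layer); the causal-LM head and its cross-entropy loss fall back to the stock `transformers` implementation instead of the deleted Triton kernel.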

chatml conversation template registration (1 changed line)

@@ -13,7 +13,7 @@ register_conv_template(
         system_message="You are a helpful assistant.",
         roles=["<|im_start|>user", "<|im_start|>assistant"],
         sep_style=SeparatorStyle.CHATML,
-        sep="<|im_end|>\n",
+        sep="<|im_end|>",
     )
 )


train() entrypoint

@@ -82,7 +82,8 @@ def train(
         cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
     )

-    model.config.use_cache = False
+    if hasattr(model, "config"):
+        model.config.use_cache = False

     # go ahead and presave, so we have the adapter config available to inspect
     if peft_config:
@@ -92,7 +93,8 @@ def train(
         if not Path(cfg.output_dir).is_dir():
             os.makedirs(cfg.output_dir, exist_ok=True)
         tokenizer.save_pretrained(str(Path(cfg.output_dir)))
-        model.config.save_pretrained(str(Path(cfg.output_dir)))
+        if hasattr(model, "config"):
+            model.config.save_pretrained(str(Path(cfg.output_dir)))

         # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
         if cfg.local_rank == 0:

axolotl.utils.collators

@@ -2,12 +2,16 @@
 DataCollator for axolotl to pad labels and position_ids for packed sequences
 """
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Dict, Optional, Sequence, Union

 import numpy as np
+import torch
+import transformers
 from transformers import PreTrainedTokenizerBase
 from transformers.utils import PaddingStrategy

+IGNORE_INDEX = -100
+

 @dataclass
 class DataCollatorForSeq2Seq:
@@ -146,3 +150,31 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 chunked_data[feature] = np.concatenate(arrays)
             features = [chunked_data]
         return super().__call__(features, return_tensors=return_tensors)
+
+
+@dataclass
+class MambaDataCollator:
+    """
+    Collator for State Space Models (Mamba)
+    """
+
+    tokenizer: transformers.PreTrainedTokenizer
+
+    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+        input_ids, labels = tuple(
+            [torch.LongTensor(instance[key]) for instance in instances]
+            for key in ("input_ids", "labels")
+        )
+        input_ids = torch.nn.utils.rnn.pad_sequence(
+            input_ids,
+            batch_first=True,
+            padding_value=self.tokenizer.pad_token_id,
+        )
+        labels = torch.nn.utils.rnn.pad_sequence(
+            labels, batch_first=True, padding_value=IGNORE_INDEX
+        )
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+        }
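
A rough usage sketch of the new collator; the stand-in tokenizer object and the token ids below are made up, since the collator only reads `pad_token_id`:

```python
from types import SimpleNamespace

from axolotl.utils.collators import MambaDataCollator

tokenizer = SimpleNamespace(pad_token_id=0)  # stand-in for a real tokenizer
collator = MambaDataCollator(tokenizer=tokenizer)

batch = collator(
    [
        {"input_ids": [5, 6, 7], "labels": [5, 6, 7]},
        {"input_ids": [8, 9], "labels": [8, -100]},
    ]
)
# input_ids are right-padded with pad_token_id, labels with -100 (IGNORE_INDEX),
# so both tensors come out with shape (2, 3).
print(batch["input_ids"].shape, batch["labels"].shape)
```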

model loading utilities (load_model_config / load_tokenizer / load_model)

@@ -4,6 +4,7 @@ import math
 import os
 from typing import Optional, Tuple  # noqa: F401

+import addict
 import bitsandbytes as bnb
 import torch
 import transformers
@@ -21,6 +22,7 @@ from transformers import (  # noqa: F401
     PreTrainedTokenizerBase,
 )

+from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.dict import DictDefault
@@ -52,9 +54,26 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
 def load_model_config(cfg):
     model_config_name = cfg.base_model_config or cfg.base_model
     trust_remote_code = cfg.trust_remote_code is True
-    model_config = AutoConfig.from_pretrained(
-        model_config_name, trust_remote_code=trust_remote_code
-    )
+    model_type = cfg.model_type
+
+    if model_type == "MixtralForCausalLM":
+        from axolotl.models.mixtral.configuration_moe_mistral import MixtralConfig
+
+        model_config = MixtralConfig.from_pretrained(model_config_name)
+    else:
+        try:
+            model_config = AutoConfig.from_pretrained(
+                model_config_name, trust_remote_code=trust_remote_code
+            )
+        except ValueError as err:
+            if "mamba" in model_config_name:
+                return addict.Dict(
+                    {
+                        "model_type": "mamba",
+                    }
+                )
+            raise err
+
     if cfg.model_config:
         for key, val in cfg.model_config.items():
             setattr(model_config, key, val)
@@ -92,6 +111,7 @@ def load_tokenizer(cfg):
             "LlamaTokenizer",
             "LlamaTokenizerFast",
             "CodeLlamaTokenizer",
+            "CodeLlamaTokenizerFast",
         ]
         and hasattr(tokenizer, "pad_token")
         and not tokenizer.pad_token
@@ -124,6 +144,23 @@ def load_tokenizer(cfg):
             tokenizer.add_special_tokens(
                 {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
            )
+
+        # If we add bos_token and eos_token, we need to update the post processor to
+        # handle them correctly.
+        # https://github.com/huggingface/transformers/pull/24132
+        bos_or_eos_in_special_tokens = (
+            "bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
+        )
+        if (
+            tokenizer.__class__.__name__
+            in (
+                "LlamaTokenizerFast",
+                "CodeLlamaTokenizerFast",
+            )
+            and bos_or_eos_in_special_tokens
+        ):
+            tokenizer.update_post_processor()

     if cfg.tokens:
         tokenizer.add_tokens(
             [
@@ -271,7 +308,9 @@ def load_model(
         or cfg.is_falcon_derived_model
         or cfg.is_mistral_derived_model
     ):
-        model_kwargs["use_flash_attention_2"] = True
+        # TODO enable once properly supported in transformers
+        # model_kwargs["attn_implementation"] = "flash_attention_2"
+        model_kwargs["use_flash_attention_2"] = True  # legacy, to be deprecated

     try:
         if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
@@ -333,6 +372,37 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
             )
+        elif model_type == "MixtralForCausalLM":
+            from axolotl.models.mixtral import (
+                MixtralForCausalLM,
+                replace_mixtral_mlp_with_swiglu,
+            )
+
+            model = MixtralForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+                **model_kwargs,
+            )
+
+            if cfg.flash_attn_fuse_mlp:
+                LOG.info("Mixtral MoE: Replacing experts with SwiGLU")
+                replace_mixtral_mlp_with_swiglu(model)
+
+        elif model_type == "MambaLMHeadModel":
+            # FIXME this is janky at best and hacked together to make it work
+            MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
+
+            model_kwargs["dtype"] = model_kwargs["torch_dtype"]
+            model_kwargs["device"] = torch.cuda.current_device()
+            del model_kwargs["torch_dtype"]
+            del model_kwargs["device_map"]
+            del model_kwargs["max_memory"]
+
+            model = MambaLMHeadModel.from_pretrained(
+                base_model,
+                **model_kwargs,
+            )
         elif model_type and not cfg.trust_remote_code:
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
@@ -392,13 +462,17 @@ def load_model(
             if cfg.resize_token_embeddings_to_32x
             else len(tokenizer)
         )
-        if model.get_input_embeddings().num_embeddings < embeddings_len:
+        if (
+            hasattr(model, "get_input_embeddings")
+            and model.get_input_embeddings().num_embeddings < embeddings_len
+        ):
             model.resize_token_embeddings(embeddings_len)
         else:
             model.tie_weights()

         if (
-            hasattr(model.config, "max_position_embeddings")
+            hasattr(model, "config")
+            and hasattr(model.config, "max_position_embeddings")
             and model.config.max_position_embeddings
             and cfg.sequence_len > model.config.max_position_embeddings
         ):
@@ -408,20 +482,22 @@ def load_model(
             model.config.max_position_embeddings = cfg.sequence_len

         if (
-            hasattr(model.config, "bos_token_id")
+            hasattr(model, "config")
+            and hasattr(model.config, "bos_token_id")
             and model.config.bos_token_id
             and model.config.bos_token_id != tokenizer.bos_token_id
         ):
             model.config.bos_token_id = tokenizer.bos_token_id

         if (
-            hasattr(model.config, "eos_token_id")
+            hasattr(model, "config")
+            and hasattr(model.config, "eos_token_id")
             and model.config.eos_token_id
             and model.config.eos_token_id != tokenizer.eos_token_id
         ):
             model.config.eos_token_id = tokenizer.eos_token_id

-        if model.device.type == "cuda":
+        if hasattr(model, "device") and model.device.type == "cuda":
             log_gpu_memory_usage(LOG, "after model load", model.device)

         # make sure these are fp32 per Ramesh et al. (2021)
@@ -480,7 +556,8 @@ def load_model(
                 requires_grad.append(f"{name}: {param.requires_grad}")
         if len(requires_grad) == 0:
             LOG.warning("there are no parameters that require gradient updates")
-        model.config.use_cache = False
+        if hasattr(model, "config"):
+            model.config.use_cache = False

         if cfg.flash_optimum:
             model = BetterTransformer.transform(model)
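
Most of the new `hasattr(model, ...)` guards exist because the patched `MambaLMHeadModel` is a plain `torch.nn.Module` rather than a `PreTrainedModel`, so helpers the loader previously assumed, such as `device` and `get_input_embeddings()`, may not exist; the `config` accesses get the same defensive treatment.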

trainer utilities (process_datasets_for_packing / calculate_total_num_steps)

@@ -131,8 +131,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
         )

     # Phi doesn't want the attention_mask feature when training
-    if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
-        cfg.is_mistral_derived_model and cfg.flash_attention
+    if (
+        "CodeGenTokenizer" in tokenizer.__class__.__name__
+        or (cfg.is_mistral_derived_model and cfg.flash_attention)
+        or cfg.model_config_type == "mamba"
     ):
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
@@ -153,7 +155,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
     if update:
         cfg.total_num_tokens = total_num_tokens

-    if not cfg.total_supervised_tokens:
+    skip_estimates = cfg.model_config_type == "mamba"
+
+    if not skip_estimates and not cfg.total_supervised_tokens:
         total_supervised_tokens = (
             train_dataset.data.column("labels")
             .to_pandas()
@@ -167,7 +171,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
     if update:
         cfg.total_supervised_tokens = total_supervised_tokens

-    if cfg.sample_packing:
+    if not skip_estimates and cfg.sample_packing:
         # we have to drop anything longer then sequence len otherwise
         # flash attention with position ids fails


tests/e2e/test_mamba.py (new file, 65 lines)

"""
E2E tests for lora llama
"""

import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestMistral(unittest.TestCase):
    """
    Test case for Llama models using LoRA
    """

    @with_temp_dir
    def test_fft(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "state-spaces/mamba-130m",
                "model_type": "MambaLMHeadModel",
                "tokenizer_type": "AutoTokenizer",
                "tokenizer_config": "EleutherAI/gpt-neox-20b",
                "flash_attention": False,
                "sequence_len": 1024,
                "load_in_8bit": False,
                "val_set_size": 0.0,
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "gradient_checkpointing": False,
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": None,
                "save_safetensors": False,
            }
        )
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "pytorch_model.bin").exists()