Compare commits
1 Commits
pytest-ski
...
20240404-l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05f7034288 |
@@ -38,8 +38,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
parsed_cfg.load_in_4bit = False
|
parsed_cfg.load_in_4bit = False
|
||||||
parsed_cfg.load_in_8bit = False
|
parsed_cfg.load_in_8bit = False
|
||||||
parsed_cfg.flash_attention = False
|
parsed_cfg.flash_attention = False
|
||||||
parsed_cfg.deepspeed = None
|
|
||||||
parsed_cfg.fsdp = None
|
|
||||||
|
|
||||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from torch.optim.lr_scheduler import OneCycleLR
|
|||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import (
|
from transformers import (
|
||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
PreTrainedModel,
|
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
@@ -803,15 +802,6 @@ class AxolotlDPOTrainer(DPOTrainer):
|
|||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
def tokenize_row(
|
|
||||||
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
|
|
||||||
) -> Dict:
|
|
||||||
res = super().tokenize_row(feature, model=model)
|
|
||||||
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
|
||||||
for key in res.keys():
|
|
||||||
res[key] = res[key][1:]
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -54,23 +54,33 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
|||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def on_train_begin(
|
||||||
|
self, args, state, control, **kwargs
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
self.switch_active_layers(state)
|
||||||
|
|
||||||
def on_step_begin(
|
def on_step_begin(
|
||||||
self, args, state, control, **kwargs
|
self, args, state, control, **kwargs
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
# Check if it's time to switch active layers, including at step 0
|
# Check if it's time to switch active layers, including at step 0
|
||||||
if state.global_step % self.step_interval == 0 or state.global_step == 1:
|
if state.global_step % self.step_interval == 0:
|
||||||
self.switch_active_layers()
|
self.switch_active_layers(state)
|
||||||
|
|
||||||
def switch_active_layers(self):
|
def switch_active_layers(self, state):
|
||||||
# First, disable gradients for all layers
|
# First, disable gradients for all layers
|
||||||
self.freeze_all_layers()
|
self.freeze_all_layers()
|
||||||
|
|
||||||
|
deterministic_seed = state.global_step
|
||||||
|
np.random.seed(deterministic_seed)
|
||||||
|
|
||||||
# Randomly select n_layers to activate
|
# Randomly select n_layers to activate
|
||||||
layers = reduce(
|
layers = reduce(
|
||||||
getattr, self.layers_attribute.split("."), self.trainer.model
|
getattr, self.layers_attribute.split("."), self.trainer.model
|
||||||
)
|
)
|
||||||
self.active_layers_indices = np.random.choice(
|
self.active_layers_indices = np.random.choice(
|
||||||
range(self.total_layers), self.n_layers, replace=False
|
range(self.total_layers),
|
||||||
|
self.n_layers,
|
||||||
|
replace=False,
|
||||||
)
|
)
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Activating layers at indices: {self.active_layers_indices} for the next steps."
|
f"Activating layers at indices: {self.active_layers_indices} for the next steps."
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ def chat_templates(user_choice: str):
|
|||||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||||
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if user_choice in templates:
|
if user_choice in templates:
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Module for pydantic models for configuration
|
Module for pydantic models for configuration
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -141,7 +140,6 @@ class ChatTemplate(str, Enum):
|
|||||||
chatml = "chatml" # pylint: disable=invalid-name
|
chatml = "chatml" # pylint: disable=invalid-name
|
||||||
inst = "inst" # pylint: disable=invalid-name
|
inst = "inst" # pylint: disable=invalid-name
|
||||||
gemma = "gemma" # pylint: disable=invalid-name
|
gemma = "gemma" # pylint: disable=invalid-name
|
||||||
cohere = "cohere" # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
|
|
||||||
class LoftQConfig(BaseModel):
|
class LoftQConfig(BaseModel):
|
||||||
@@ -242,6 +240,17 @@ class LoraConfig(BaseModel):
|
|||||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_quantized_dora(cls, data):
|
||||||
|
if data.get("peft_use_dora") and (
|
||||||
|
data.get("load_in_8bit") or data.get("load_in_4bit")
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"`peft_use_dora` is not currently compatible with quantized weights."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class ReLoRAConfig(BaseModel):
|
class ReLoRAConfig(BaseModel):
|
||||||
"""ReLoRA configuration subset"""
|
"""ReLoRA configuration subset"""
|
||||||
@@ -645,20 +654,6 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_sample_packing_wo_flash(cls, data):
|
|
||||||
if (
|
|
||||||
data.get("sample_packing")
|
|
||||||
and not data.get("flash_attention")
|
|
||||||
and not data.get("sdp_attention")
|
|
||||||
):
|
|
||||||
LOG.warning(
|
|
||||||
"sample_packing without flash_attention or sdp_attention does not handle cross-attention."
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_sample_packing_w_rl(cls, data):
|
def check_sample_packing_w_rl(cls, data):
|
||||||
|
|||||||
@@ -902,12 +902,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|||||||
model = get_peft_model(model, lora_config)
|
model = get_peft_model(model, lora_config)
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
try:
|
model.print_trainable_parameters()
|
||||||
model.print_trainable_parameters()
|
|
||||||
except AttributeError as exc:
|
|
||||||
LOG.warning(
|
|
||||||
"Exception caught during model.print_trainable_parameters(): %s", exc
|
|
||||||
)
|
|
||||||
elif cfg.fsdp and cfg.adapter == "qlora":
|
elif cfg.fsdp and cfg.adapter == "qlora":
|
||||||
setup_quantized_peft_meta_for_training(model)
|
setup_quantized_peft_meta_for_training(model)
|
||||||
|
|
||||||
|
|||||||
@@ -7,8 +7,6 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
from axolotl.cli import load_datasets
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
@@ -21,7 +19,6 @@ LOG = logging.getLogger("axolotl.tests.e2e")
|
|||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("Skipping test due to timeout.")
|
|
||||||
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Test case for Llama models using S2 Attn
|
Test case for Llama models using S2 Attn
|
||||||
|
|||||||
@@ -600,7 +600,6 @@ class TestValidation(BaseValidation):
|
|||||||
{
|
{
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"pad_to_sequence_len": None,
|
"pad_to_sequence_len": None,
|
||||||
"flash_attention": True,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -902,7 +901,6 @@ class TestValidation(BaseValidation):
|
|||||||
{
|
{
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"eval_table_size": 100,
|
"eval_table_size": 100,
|
||||||
"flash_attention": True,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -918,7 +916,6 @@ class TestValidation(BaseValidation):
|
|||||||
{
|
{
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"eval_sample_packing": False,
|
"eval_sample_packing": False,
|
||||||
"flash_attention": True,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -931,7 +928,6 @@ class TestValidation(BaseValidation):
|
|||||||
{
|
{
|
||||||
"sample_packing": False,
|
"sample_packing": False,
|
||||||
"eval_table_size": 100,
|
"eval_table_size": 100,
|
||||||
"flash_attention": True,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -945,7 +941,6 @@ class TestValidation(BaseValidation):
|
|||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"eval_table_size": 100,
|
"eval_table_size": 100,
|
||||||
"eval_sample_packing": False,
|
"eval_sample_packing": False,
|
||||||
"flash_attention": True,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
|
|||||||
Reference in New Issue
Block a user