Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
05f7034288 use deterministic seed for random LISA layers 2024-04-04 18:16:55 -07:00
6 changed files with 15 additions and 29 deletions

View File

@@ -38,8 +38,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
parsed_cfg.load_in_4bit = False
parsed_cfg.load_in_8bit = False
parsed_cfg.flash_attention = False
parsed_cfg.deepspeed = None
parsed_cfg.fsdp = None
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)

View File

@@ -54,23 +54,33 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
for param in layer.parameters():
param.requires_grad = False
def on_train_begin(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
self.switch_active_layers(state)
def on_step_begin(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
# Check if it's time to switch active layers, including at step 0
if state.global_step % self.step_interval == 0 or state.global_step == 1:
self.switch_active_layers()
if state.global_step % self.step_interval == 0:
self.switch_active_layers(state)
def switch_active_layers(self):
def switch_active_layers(self, state):
# First, disable gradients for all layers
self.freeze_all_layers()
deterministic_seed = state.global_step
np.random.seed(deterministic_seed)
# Randomly select n_layers to activate
layers = reduce(
getattr, self.layers_attribute.split("."), self.trainer.model
)
self.active_layers_indices = np.random.choice(
range(self.total_layers), self.n_layers, replace=False
range(self.total_layers),
self.n_layers,
replace=False,
)
LOG.info(
f"Activating layers at indices: {self.active_layers_indices} for the next steps."

View File

@@ -23,7 +23,6 @@ def chat_templates(user_choice: str):
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
}
if user_choice in templates:

View File

@@ -1,7 +1,6 @@
"""
Module for pydantic models for configuration
"""
# pylint: disable=too-many-lines
import logging
@@ -141,7 +140,6 @@ class ChatTemplate(str, Enum):
chatml = "chatml" # pylint: disable=invalid-name
inst = "inst" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
@@ -656,20 +654,6 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_wo_flash(cls, data):
if (
data.get("sample_packing")
and not data.get("flash_attention")
and not data.get("sdp_attention")
):
raise ValueError(
"sample_packing requires flash_attention or sdp_attention to be set to true"
)
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_rl(cls, data):

View File

@@ -459,7 +459,7 @@ def load_model(
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16,
}
if not cfg.deepspeed and cfg.model_config_type in ("jamba", "qwen2_moe"):
if not cfg.deepspeed:
# for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32

View File

@@ -600,7 +600,6 @@ class TestValidation(BaseValidation):
{
"sample_packing": True,
"pad_to_sequence_len": None,
"flash_attention": True,
}
)
| minimal_cfg
@@ -902,7 +901,6 @@ class TestValidation(BaseValidation):
{
"sample_packing": True,
"eval_table_size": 100,
"flash_attention": True,
}
)
| minimal_cfg
@@ -918,7 +916,6 @@ class TestValidation(BaseValidation):
{
"sample_packing": True,
"eval_sample_packing": False,
"flash_attention": True,
}
)
| minimal_cfg
@@ -931,7 +928,6 @@ class TestValidation(BaseValidation):
{
"sample_packing": False,
"eval_table_size": 100,
"flash_attention": True,
}
)
| minimal_cfg
@@ -945,7 +941,6 @@ class TestValidation(BaseValidation):
"sample_packing": True,
"eval_table_size": 100,
"eval_sample_packing": False,
"flash_attention": True,
}
)
| minimal_cfg