Merge branch 'main' into cj_tokenizer_default_prompt_template

This commit is contained in:
Chirag Jain
2024-08-06 21:23:14 +05:30
committed by GitHub
20 changed files with 204 additions and 145 deletions

View File

@@ -8,6 +8,8 @@ repos:
- id: check-yaml - id: check-yaml
- id: end-of-file-fixer - id: end-of-file-fixer
- id: trailing-whitespace - id: trailing-whitespace
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 23.3.0 rev: 23.3.0
hooks: hooks:

View File

@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets" ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub" ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub" ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1" ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets" ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub" ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub" ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1" ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -54,6 +54,14 @@ conversations where `from` is `prompter` `assistant` instead of default sharegpt
{"conversations": [{"from": "...", "value": "..."}]} {"conversations": [{"from": "...", "value": "..."}]}
``` ```
## sharegpt.load_ultrachat
conversations where the turns field is `messages`, the human role is `user`, and the gpt role is `assistant`.
```{.json filename="data.jsonl"}
{"messages": [{"user": "...", "assistant": "..."}]}
```
## sharegpt_jokes ## sharegpt_jokes
creates a chat where bot is asked to tell a joke, then explain why the joke is funny creates a chat where bot is asked to tell a joke, then explain why the joke is funny

View File

@@ -43,7 +43,6 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"!pip install torch==\"2.1.2\"\n",
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n", "!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.5.0\"\n", "!pip install flash-attn==\"2.5.0\"\n",
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\"" "!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""

View File

@@ -74,3 +74,5 @@ deepspeed:
weight_decay: 0.0 weight_decay: 0.0
fsdp: fsdp:
fsdp_config: fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3.1-405B base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
load_in_4bit: true load_in_4bit: true
@@ -10,10 +10,11 @@ datasets:
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b output_dir: ./outputs/out/qlora-llama3_1-405b
save_safetensors: true
adapter: qlora adapter: qlora
sequence_len: 1024 sequence_len: 2048
sample_packing: true sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
@@ -25,7 +26,7 @@ lora_target_linear: true
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 2
optimizer: adamw_torch optimizer: adamw_torch
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.00001 learning_rate: 0.00001

View File

@@ -1,9 +1,9 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.11.1 peft==0.12.0
transformers==4.43.3 transformers==4.43.4
tokenizers==0.19.1 tokenizers==0.19.1
bitsandbytes==0.43.1 bitsandbytes==0.43.3
accelerate==0.32.0 accelerate==0.32.0
deepspeed==0.14.4 deepspeed==0.14.4
pydantic==2.6.3 pydantic==2.6.3

View File

@@ -40,7 +40,7 @@ from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_tokenizer from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -382,6 +382,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
prepare_optim_env(cfg) prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg) normalize_config(cfg)
normalize_cfg_datasets(cfg) normalize_cfg_datasets(cfg)

View File

@@ -11,4 +11,5 @@ MOE_ARCH_BLOCK = {
], ],
"mixtral": "MixtralSparseMoeBlock", "mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock", "qwen2_moe": "Qwen2MoeSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
} }

View File

@@ -242,6 +242,12 @@ class AxolotlTrainingMixins:
"help": "workaround to pass an alternate optimizer to the HF trainer" "help": "workaround to pass an alternate optimizer to the HF trainer"
}, },
) )
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
@dataclass @dataclass
@@ -318,7 +324,23 @@ class SchedulerMixin(Trainer):
# fmt: off # fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on # fmt: on
if use_cosine_quadratic: if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
if "pct_start" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["pct_start"] = pct_start
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["anneal_strategy"] = "cos"
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif use_cosine_quadratic:
if use_cosine_min_lr: if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.") LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
@@ -876,37 +898,6 @@ class AxolotlMambaTrainer(AxolotlTrainer):
return lm_loss return lm_loss
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
    """
    AxolotlTrainer variant that drives the learning rate with OneCycleLR
    """

    tag_names = ["axolotl", "onecycle"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # reset the scheduler slot; create_scheduler() is what populates it
        self.lr_scheduler = None

    def create_scheduler(
        self,
        num_training_steps: int,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ):
        """
        Build a OneCycleLR scheduler, store it on the trainer and return it.

        :param num_training_steps: total number of steps, passed to the
            scheduler as ``total_steps``
        :param optimizer: optimizer to schedule; ``self.optimizer`` is used
            when this is None
        :return: the newly created scheduler (also kept in ``self.lr_scheduler``)
        """
        if optimizer is None:
            optimizer = self.optimizer

        # the warmup portion of the run becomes the rising phase of the cycle
        warmup_steps = self.args.get_warmup_steps(num_training_steps)
        warmup_fraction = warmup_steps / num_training_steps

        scheduler = OneCycleLR(
            optimizer,
            max_lr=self.args.learning_rate,
            total_steps=num_training_steps,
            pct_start=warmup_fraction,
            div_factor=6,
        )
        self.lr_scheduler = scheduler
        return scheduler
class ReLoRATrainer(AxolotlTrainer): class ReLoRATrainer(AxolotlTrainer):
""" """
Trainer subclass that uses the OneCycleLR scheduler Trainer subclass that uses the OneCycleLR scheduler
@@ -1190,10 +1181,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks return callbacks
def _get_trainer_cls(self): def _get_trainer_cls(self):
if self.cfg.lr_scheduler == "one_cycle" and (
self.cfg.fsdp or self.cfg.adapter == "qlora"
):
return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps: if self.cfg.relora_steps:
return ReLoRATrainer return ReLoRATrainer
if self.cfg.model_config_type == "mamba": if self.cfg.model_config_type == "mamba":
@@ -1243,7 +1230,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.fsdp: if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config: if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config) training_arguments_kwargs["fsdp_config"] = {
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
}
if self.cfg.adapter == "qlora": if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True training_arguments_kwargs["qlora"] = True
@@ -1441,12 +1430,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[ training_arguments_kwargs[
"loraplus_lr_embedding" "loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding ] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["lr_scheduler_type"] = ( if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
self.cfg.lr_scheduler training_arguments_kwargs["lr_scheduler_type"] = "cosine"
if self.cfg.lr_scheduler training_arguments_kwargs[
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep") "alternate_lr_scheduler_type"
else "cosine" ] = self.cfg.lr_scheduler
) else:
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
)
training_arguments_kwargs["lr_scheduler_kwargs"] = ( training_arguments_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
) )

View File

@@ -25,12 +25,12 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
] ]
def patch_for_multipack(model_type, model_name=None): def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
if model_type == "gemmoe": if model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "deepseek_v2": elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek") patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
elif hasattr(transformers, "modeling_flash_attention_utils"): elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data get_unpad_data
) )

View File

@@ -10,8 +10,8 @@ from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger # Configure the logger
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
LOG.setLevel(logging.INFO)
class ChatTemplatePrompter(Prompter): class ChatTemplatePrompter(Prompter):
@@ -359,7 +359,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strategy_params = { strategy_params = {
"train_on_inputs": cfg.train_on_inputs, "train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len, "sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train"), "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "last"), "train_on_eos": ds_cfg.get("train_on_eos", "last"),
} }

View File

@@ -23,6 +23,7 @@ _TEMPLATES = {
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
"deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<User>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<Assistant>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<Assistant>' }}{% endif %}",
} }

View File

@@ -39,6 +39,7 @@ class ChatTemplate(str, Enum):
cohere = "cohere" # pylint: disable=invalid-name cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name llama3 = "llama3" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name phi_3 = "phi_3" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
mistral = "mistral" # pylint: disable=invalid-name mistral = "mistral" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
@@ -258,6 +259,12 @@ class LoraConfig(BaseModel):
peft_use_rslora: Optional[bool] = None peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None peft_layer_replication: Optional[List[Tuple[int, int]]] = None
qlora_sharded_model_loading: Optional[bool] = Field(
default=False,
metadata={
"help": "load qlora model in sharded format for FSDP using answer.ai technique."
},
)
lora_on_cpu: Optional[bool] = None lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None gptq: Optional[bool] = None
bnb_config_kwargs: Optional[Dict[str, Any]] = None bnb_config_kwargs: Optional[Dict[str, Any]] = None
@@ -395,7 +402,7 @@ class HyperparametersConfig(BaseModel):
}, },
) )
torchdistx_path: Optional[str] = None torchdistx_path: Optional[str] = None
lr_scheduler: Optional[SchedulerType] = "cosine" lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None lr_quadratic_warmup: Optional[bool] = None
cosine_min_lr_ratio: Optional[float] = None cosine_min_lr_ratio: Optional[float] = None
@@ -978,6 +985,8 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_eval_packing(cls, data): def check_eval_packing(cls, data):
# TODO also should check test_datasets and val_set_size as we can skip
# if there are no eval datasets/splits
if ( if (
data.get("sample_packing") data.get("sample_packing")
and data.get("eval_table_size") and data.get("eval_table_size")

View File

@@ -160,8 +160,12 @@ def load_tokenized_prepared_datasets(
use_auth_token = cfg.hf_use_auth_token use_auth_token = cfg.hf_use_auth_token
try: try:
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
LOG.info(
f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
)
dataset = load_dataset( dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}", cfg.push_dataset_to_hub,
ds_hash,
token=use_auth_token, token=use_auth_token,
) )
dataset = dataset[split] dataset = dataset[split]
@@ -181,7 +185,14 @@ def load_tokenized_prepared_datasets(
dataset = load_from_disk(str(prepared_ds_path)) dataset = load_from_disk(str(prepared_ds_path))
LOG.info("Prepared dataset loaded from disk...") LOG.info("Prepared dataset loaded from disk...")
else: else:
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") if cfg.push_dataset_to_hub:
LOG.info("Unable to find prepared dataset in Huggingface hub")
if cfg.is_preprocess:
LOG.info(
f"Skipping prepared dataset in {prepared_ds_path} for pre-processing..."
)
else:
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
LOG.info("Loading raw datasets...") LOG.info("Loading raw datasets...")
if not cfg.is_preprocess: if not cfg.is_preprocess:
LOG.warning( LOG.warning(
@@ -433,10 +444,12 @@ def load_tokenized_prepared_datasets(
dataset.save_to_disk(str(prepared_ds_path)) dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
LOG.info( LOG.info(
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
) )
dataset.push_to_hub( dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True cfg.push_dataset_to_hub,
ds_hash,
private=True,
) )
return dataset, prompters return dataset, prompters

View File

@@ -153,11 +153,11 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
if is_main_process(): if is_main_process():
value_scalar = fn() value_scalar = fn()
value_tensor = torch.tensor( value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device() value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
).float() )
else: else:
value_tensor = torch.tensor( value_tensor = torch.tensor(
0.0, device=torch.cuda.current_device() 0.0, device=torch.cuda.current_device(), dtype=torch.float32
) # Placeholder tensor ) # Placeholder tensor
# Broadcast the tensor to all processes. # Broadcast the tensor to all processes.

View File

@@ -13,6 +13,7 @@ from fastcore.parallel import parallel
from torch import Tensor, nn from torch import Tensor, nn
from tqdm import tqdm from tqdm import tqdm
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
from transformers.quantizers import AutoHfQuantizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
@@ -173,6 +174,7 @@ def load_sharded_model_quant(
low_memory=True, low_memory=True,
verbose=False, verbose=False,
loading_workers=2, loading_workers=2,
quantization_config=None,
): ):
with init_empty_weights(): with init_empty_weights():
model = AutoModelForCausalLM.from_config( model = AutoModelForCausalLM.from_config(
@@ -186,15 +188,26 @@ def load_sharded_model_quant(
compute_dtype=compute_dtype, compute_dtype=compute_dtype,
quant_type="nf4", quant_type="nf4",
quant_storage=quant_storage, quant_storage=quant_storage,
compress_statistics=True, # bnb_4bit_use_double_quant
skip_modules=[
"lm_head",
"embed_out",
],
) )
else: else:
# this is the more common case with HF transformers # this is the more common case with HF transformers
# TODO can we detect the model arch and dynamically set skip_modules
model.model = _replace_linear( model.model = _replace_linear(
model.model, model.model,
Linear4bit, Linear4bit,
compute_dtype=compute_dtype, compute_dtype=compute_dtype,
quant_type="nf4", quant_type="nf4",
quant_storage=quant_storage, quant_storage=quant_storage,
compress_statistics=True, # bnb_4bit_use_double_quant
skip_modules=[
"lm_head",
"embed_out",
],
) )
model.is_loaded_in_4bit = True model.is_loaded_in_4bit = True
@@ -251,6 +264,11 @@ def load_sharded_model_quant(
quant_method=quant_method, quant_method=quant_method,
) )
# these attributes are needed to inform transformers/peft of the quantization
model.is_quantized = True
model.quantization_method = "bitsandbytes"
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
if cfg.local_rank == 0 and verbose: if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds") print(f"Loaded model weights in {time.time()-start:.3f} seconds")
# cleanup any extra memory usage from parallel loading # cleanup any extra memory usage from parallel loading

View File

@@ -351,7 +351,11 @@ def load_model(
and cfg.flash_attention and cfg.flash_attention
and cfg.sample_packing and cfg.sample_packing
): ):
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model) patch_for_multipack(
cfg.model_config_type,
model_name=cfg.base_model,
is_remote_code=cfg.trust_remote_code,
)
if cfg.is_llama_derived_model: if cfg.is_llama_derived_model:
from axolotl.monkeypatch.llama_attn_hijack_flash import ( from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -627,14 +631,21 @@ def load_model(
elif ( elif (
qlora_fsdp qlora_fsdp
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.model_config_type == "dbrx" and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading)
): ):
quant_storage = cfg.torch_dtype quant_storage = cfg.torch_dtype
quantization_config = hasattr(
model_config, "quantization_config"
) and getattr(model_config, "quantization_config")
quantization_config = (
quantization_config or model_kwargs["quantization_config"]
)
model = load_sharded_model_quant( model = load_sharded_model_quant(
base_model, base_model,
model_config, model_config,
cfg, cfg,
quant_storage=quant_storage, quant_storage=quant_storage,
quantization_config=quantization_config,
) )
skip_move_to_device = True skip_move_to_device = True
elif ( elif (
@@ -1014,7 +1025,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
if cfg.lora_target_linear: if cfg.lora_target_linear:
linear_names = find_all_linear_names(model) linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(linear_names)}") LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
lora_target_modules = list(set(lora_target_modules + linear_names)) lora_target_modules = list(set(lora_target_modules + linear_names))
lora_config_kwargs = {} lora_config_kwargs = {}

View File

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger("axolotl") LOG = get_logger("axolotl")
@@ -183,88 +183,88 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
sequence_len=cfg.sequence_len, sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len or 2, min_sequence_len=cfg.min_sample_len or 2,
) )
with zero_first(is_main_process()):
if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if cfg.model_config_type == "mamba": if cfg.is_preprocess:
LOG.info("dropping attention_mask column") min_input_len = np.min(get_dataset_lengths(train_dataset))
train_dataset = train_dataset.remove_columns("attention_mask") LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
if eval_dataset: max_input_len = np.max(get_dataset_lengths(train_dataset))
eval_dataset = eval_dataset.remove_columns("attention_mask") LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if cfg.model_config_type == "falcon": if cfg.model_config_type == "mamba":
LOG.info("dropping token_type_ids column if it exists") LOG.info("dropping attention_mask column")
if "token_type_ids" in train_dataset.column_names: train_dataset = train_dataset.remove_columns("attention_mask")
train_dataset = train_dataset.remove_columns("token_type_ids") if eval_dataset:
if eval_dataset and "token_type_ids" in eval_dataset.column_names: eval_dataset = eval_dataset.remove_columns("attention_mask")
eval_dataset = eval_dataset.remove_columns("token_type_ids")
train_dataset = train_dataset.filter( if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids")
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_long, drop_long,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences", desc="Dropping Long Sequences",
) )
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
if cfg.group_by_length: if cfg.group_by_length:
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
add_length, add_length,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Group By Length", desc="Group By Length",
) )
if cfg.use_pose: if cfg.use_pose:
pose_kwargs = {} pose_kwargs = {}
if cfg.pose_num_chunks is not None: if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial( pose_fn = partial(
add_pose_position_ids, add_pose_position_ids,
max_context_len=cfg.pose_max_context_len, max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids, split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs, **pose_kwargs,
) )
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
pose_fn, pose_fn,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)", desc="Add position_id column (PoSE)",
) )
train_dataset = train_dataset.sort("sequence_len") train_dataset = train_dataset.sort("sequence_len")
if cfg.eval_sample_packing is not False: if cfg.eval_sample_packing is not False:
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.map( eval_dataset = eval_dataset.map(
pose_fn, pose_fn,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)", desc="Add position_id column (PoSE)",
) )
elif cfg.sample_packing: elif cfg.sample_packing:
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
add_position_ids, add_position_ids,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)", desc="Add position_id column (Sample Packing)",
) )
if cfg.eval_sample_packing is not False: if cfg.eval_sample_packing is not False:
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.map( eval_dataset = eval_dataset.map(
add_position_ids, add_position_ids,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)", desc="Add position_id column (Sample Packing)",
) )
return train_dataset, eval_dataset return train_dataset, eval_dataset
@@ -393,10 +393,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
def setup_deepspeed_env(cfg, stage=None): def setup_deepspeed_env(cfg, stage=None):
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if cfg.bf16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
if stage: if stage:
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage) os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3: if stage == 3:
@@ -444,6 +440,12 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16" os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
def prepare_opinionated_env(cfg):
    """
    Set environment variables that are forced for particular config options.

    When ``cfg.qlora_sharded_model_loading`` is truthy,
    ``TOKENIZERS_PARALLELISM`` is set to ``"false"`` in ``os.environ``;
    otherwise the environment is left untouched.

    :param cfg: config object exposing ``qlora_sharded_model_loading``
    """
    if not cfg.qlora_sharded_model_loading:
        return

    # model loading is forked after the tokenizer
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]: if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)