diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 6c5f20589..9f2ceac56 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,6 +8,8 @@ repos:
       - id: check-yaml
       - id: end-of-file-fixer
       - id: trailing-whitespace
+      - id: no-commit-to-branch
+        args: ['--branch', 'main']
   - repo: https://github.com/psf/black
     rev: 23.3.0
     hooks:
diff --git a/docker/Dockerfile-cloud b/docker/Dockerfile-cloud
index 69ce143bb..c0bb266d2 100644
--- a/docker/Dockerfile-cloud
+++ b/docker/Dockerfile-cloud
@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
 
 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
 ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
-ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
 ENV HF_HOME="/workspace/data/huggingface-cache/hub"
 ENV HF_HUB_ENABLE_HF_TRANSFER="1"
 
diff --git a/docker/Dockerfile-cloud-no-tmux b/docker/Dockerfile-cloud-no-tmux
index efeffef8e..3e59d4119 100644
--- a/docker/Dockerfile-cloud-no-tmux
+++ b/docker/Dockerfile-cloud-no-tmux
@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
 
 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
 ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
-ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
 ENV HF_HOME="/workspace/data/huggingface-cache/hub"
 ENV HF_HUB_ENABLE_HF_TRANSFER="1"
 
diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd
index f7d0cac82..28d13c987 100644
--- a/docs/dataset-formats/conversation.qmd
+++ b/docs/dataset-formats/conversation.qmd
@@ -54,6 +54,14 @@ conversations where `from` is `prompter` `assistant` instead of default sharegpt
 {"conversations": [{"from": "...", "value": "..."}]}
 ```
 
+## sharegpt.load_ultrachat
+
+conversations where the turns field is 'messages', human is 'user' and gpt is 'assistant'.
+
+```{.json filename="data.jsonl"}
+{"messages": [{"user": "...", "assistant": "..."}]}
+```
+
 ## sharegpt_jokes
 
 creates a chat where bot is asked to tell a joke, then explain why the joke is funny
diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb
index 3fcc4d2a9..3a6981ee0 100644
--- a/examples/colab-notebooks/colab-axolotl-example.ipynb
+++ b/examples/colab-notebooks/colab-axolotl-example.ipynb
@@ -43,7 +43,6 @@
    },
    "outputs": [],
    "source": [
-    "!pip install torch==\"2.1.2\"\n",
     "!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
     "!pip install flash-attn==\"2.5.0\"\n",
     "!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml
index 21d32604c..4acad5999 100644
--- a/examples/llama-3/instruct-lora-8b.yml
+++ b/examples/llama-3/instruct-lora-8b.yml
@@ -74,3 +74,5 @@ deepspeed:
 weight_decay: 0.0
 fsdp:
 fsdp_config:
+special_tokens:
+  pad_token: <|end_of_text|>
diff --git a/examples/llama-3/qlora-fsdp-405b.yaml b/examples/llama-3/qlora-fsdp-405b.yaml
index 385b7f91d..6eeec01c9 100644
--- a/examples/llama-3/qlora-fsdp-405b.yaml
+++ b/examples/llama-3/qlora-fsdp-405b.yaml
@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3.1-405B
+base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16
 tokenizer_type: AutoTokenizer
 
 load_in_4bit: true
@@ -10,10 +10,11 @@ datasets:
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
 output_dir: ./outputs/out/qlora-llama3_1-405b
+save_safetensors: true
 
 adapter: qlora
 
-sequence_len: 1024
+sequence_len: 2048
 sample_packing: true
 pad_to_sequence_len: true
 
@@ -25,7 +26,7 @@ lora_target_linear: true
 
 gradient_accumulation_steps: 4
 micro_batch_size: 1
-num_epochs: 4
+num_epochs: 2
 optimizer: adamw_torch
 lr_scheduler: cosine
 learning_rate: 0.00001
diff --git a/requirements.txt b/requirements.txt
index 5825ee190..fdcae107c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,9 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
-peft==0.11.1
-transformers==4.43.3
+peft==0.12.0
+transformers==4.43.4
 tokenizers==0.19.1
-bitsandbytes==0.43.1
+bitsandbytes==0.43.3
 accelerate==0.32.0
 deepspeed==0.14.4
 pydantic==2.6.3
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 5966d5931..a05ee84e9 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -40,7 +40,7 @@ from axolotl.utils.distributed import is_main_process
 from axolotl.utils.mlflow_ import setup_mlflow_env_vars
 from axolotl.utils.models import load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.trainer import prepare_optim_env
+from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
 from axolotl.utils.wandb_ import setup_wandb_env_vars
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -382,6 +382,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
 
     prepare_optim_env(cfg)
 
+    prepare_opinionated_env(cfg)
+
     normalize_config(cfg)
 
     normalize_cfg_datasets(cfg)
diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py
index 7610b335a..827a63c07 100644
--- a/src/axolotl/common/architectures.py
+++ b/src/axolotl/common/architectures.py
@@ -11,4 +11,5 @@ MOE_ARCH_BLOCK = {
     ],
     "mixtral": "MixtralSparseMoeBlock",
     "qwen2_moe": "Qwen2MoeSparseMoeBlock",
+    "deepseek_v2": "DeepseekV2MoE",
 }
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index ff4804b10..4e8b36905 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -242,6 +242,12 @@ class AxolotlTrainingMixins:
             "help": "workaround to pass an alternate optimizer to the HF trainer"
         },
     )
+    alternate_lr_scheduler_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "workaround to pass an alternate lr scheduler to the HF trainer"
+        },
+    )
 
 
 @dataclass
@@ -318,7 +324,23 @@ class SchedulerMixin(Trainer):
         # fmt: off
         if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
             # fmt: on
-            if use_cosine_quadratic:
+            if self.args.alternate_lr_scheduler_type == "one_cycle":
+                num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
+                pct_start = num_warmup_steps / num_training_steps
+                extra_lr_kwargs = {}
+                if "pct_start" not in self.args.lr_scheduler_kwargs:
+                    extra_lr_kwargs["pct_start"] = pct_start
+                if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
+                    extra_lr_kwargs["anneal_strategy"] = "cos"
+
+                self.lr_scheduler = OneCycleLR(
+                    optimizer,
+                    max_lr=self.args.learning_rate,
+                    total_steps=num_training_steps,
+                    **extra_lr_kwargs,
+                    **self.args.lr_scheduler_kwargs,
+                )
+            elif use_cosine_quadratic:
                 if use_cosine_min_lr:
                     LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
 
@@ -876,37 +898,6 @@ class AxolotlMambaTrainer(AxolotlTrainer):
         return lm_loss
 
 
-class OneCycleLRSchedulerTrainer(AxolotlTrainer):
-    """
-    Trainer subclass that uses the OneCycleLR scheduler
-    """
-
-    tag_names = ["axolotl", "onecycle"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.lr_scheduler = None
-
-    def create_scheduler(
-        self,
-        num_training_steps: int,
-        optimizer: Optional[torch.optim.Optimizer] = None,
-    ):
-        optimizer = self.optimizer if optimizer is None else optimizer
-        num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
-        pct_start = num_warmup_steps / num_training_steps
-
-        self.lr_scheduler = OneCycleLR(
-            optimizer,
-            max_lr=self.args.learning_rate,
-            total_steps=num_training_steps,
-            pct_start=pct_start,
-            div_factor=6,
-        )
-
-        return self.lr_scheduler
-
-
 class ReLoRATrainer(AxolotlTrainer):
     """
     Trainer subclass that uses the OneCycleLR scheduler
@@ -1190,10 +1181,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         return callbacks
 
     def _get_trainer_cls(self):
-        if self.cfg.lr_scheduler == "one_cycle" and (
-            self.cfg.fsdp or self.cfg.adapter == "qlora"
-        ):
-            return OneCycleLRSchedulerTrainer
         if self.cfg.relora_steps:
             return ReLoRATrainer
         if self.cfg.model_config_type == "mamba":
@@ -1243,7 +1230,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
             if self.cfg.fsdp_config:
-                training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
+                training_arguments_kwargs["fsdp_config"] = {
+                    k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
+                }
 
         if self.cfg.adapter == "qlora":
             training_arguments_kwargs["qlora"] = True
@@ -1441,12 +1430,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "loraplus_lr_embedding"
             ] = self.cfg.loraplus_lr_embedding
-        training_arguments_kwargs["lr_scheduler_type"] = (
-            self.cfg.lr_scheduler
-            if self.cfg.lr_scheduler
-            and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
-            else "cosine"
-        )
+        if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
+            training_arguments_kwargs["lr_scheduler_type"] = "cosine"
+            training_arguments_kwargs[
+                "alternate_lr_scheduler_type"
+            ] = self.cfg.lr_scheduler
+        else:
+            training_arguments_kwargs["lr_scheduler_type"] = (
+                self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
+            )
         training_arguments_kwargs["lr_scheduler_kwargs"] = (
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index a2ce0e64f..904352010 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -25,12 +25,12 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
 ]
 
 
-def patch_for_multipack(model_type, model_name=None):
+def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
     if model_type == "gemmoe":
         patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
     elif model_type == "deepseek_v2":
         patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
-    elif hasattr(transformers, "modeling_flash_attention_utils"):
+    elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
         transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
         )
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index c79adbd5e..3363bcfc2 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -10,8 +10,8 @@ from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
 from axolotl.utils.chat_templates import get_chat_template_from_config
 
 # Configure the logger
-logging.basicConfig(level=logging.DEBUG)
 LOG = logging.getLogger("axolotl")
+LOG.setLevel(logging.INFO)
 
 
 class ChatTemplatePrompter(Prompter):
@@ -359,7 +359,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     strategy_params = {
         "train_on_inputs": cfg.train_on_inputs,
         "sequence_len": cfg.sequence_len,
-        "roles_to_train": ds_cfg.get("roles_to_train"),
+        "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
         "train_on_eos": ds_cfg.get("train_on_eos", "last"),
     }
 
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index ac00a82ab..42521eca9 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -23,6 +23,7 @@ _TEMPLATES = {
     "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
     "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
     "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+    "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}",
 }
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 26c02c7e2..bc4c79a6c 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -39,6 +39,7 @@ class ChatTemplate(str, Enum):
     cohere = "cohere"  # pylint: disable=invalid-name
     llama3 = "llama3"  # pylint: disable=invalid-name
     phi_3 = "phi_3"  # pylint: disable=invalid-name
+    deepseek_v2 = "deepseek_v2"  # pylint: disable=invalid-name
     mistral = "mistral"  # pylint: disable=invalid-name
     tokenizer_default = "tokenizer_default"  # pylint: disable=invalid-name
 
@@ -258,6 +259,12 @@ class LoraConfig(BaseModel):
     peft_use_rslora: Optional[bool] = None
     peft_layer_replication: Optional[List[Tuple[int, int]]] = None
 
+    qlora_sharded_model_loading: Optional[bool] = Field(
+        default=False,
+        metadata={
+            "help": "load qlora model in sharded format for FSDP using answer.ai technique."
+        },
+    )
     lora_on_cpu: Optional[bool] = None
     gptq: Optional[bool] = None
     bnb_config_kwargs: Optional[Dict[str, Any]] = None
@@ -395,7 +402,7 @@ class HyperparametersConfig(BaseModel):
         },
     )
     torchdistx_path: Optional[str] = None
-    lr_scheduler: Optional[SchedulerType] = "cosine"
+    lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
     lr_quadratic_warmup: Optional[bool] = None
     cosine_min_lr_ratio: Optional[float] = None
@@ -978,6 +985,8 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_eval_packing(cls, data):
+        # TODO also should check test_datasets and val_set_size as we can skip
+        # if there are no eval datasets/splits
         if (
             data.get("sample_packing")
             and data.get("eval_table_size")
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 2e923057d..1b6df1cde 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -160,8 +160,12 @@ def load_tokenized_prepared_datasets(
     use_auth_token = cfg.hf_use_auth_token
     try:
         if cfg.push_dataset_to_hub:
+            LOG.info(
+                f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
+            )
             dataset = load_dataset(
-                f"{cfg.push_dataset_to_hub}/{ds_hash}",
+                cfg.push_dataset_to_hub,
+                ds_hash,
                 token=use_auth_token,
             )
             dataset = dataset[split]
@@ -181,7 +185,14 @@ def load_tokenized_prepared_datasets(
         dataset = load_from_disk(str(prepared_ds_path))
         LOG.info("Prepared dataset loaded from disk...")
     else:
-        LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
+        if cfg.push_dataset_to_hub:
+            LOG.info("Unable to find prepared dataset in Huggingface hub")
+        if cfg.is_preprocess:
+            LOG.info(
+                f"Skipping prepared dataset in {prepared_ds_path} for pre-processing..."
+            )
+        else:
+            LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
         LOG.info("Loading raw datasets...")
         if not cfg.is_preprocess:
             LOG.warning(
@@ -433,10 +444,12 @@ def load_tokenized_prepared_datasets(
         dataset.save_to_disk(str(prepared_ds_path))
         if cfg.push_dataset_to_hub:
             LOG.info(
-                f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
+                f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
             )
             dataset.push_to_hub(
-                f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
+                cfg.push_dataset_to_hub,
+                ds_hash,
+                private=True,
             )
 
     return dataset, prompters
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 4444a20c9..3a559f5f5 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -153,11 +153,11 @@ def compute_and_broadcast(fn):  # pylint: disable=invalid-name
     if is_main_process():
         value_scalar = fn()
         value_tensor = torch.tensor(
-            value_scalar, device=torch.cuda.current_device()
-        ).float()
+            value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
+        )
     else:
         value_tensor = torch.tensor(
-            0.0, device=torch.cuda.current_device()
+            0.0, device=torch.cuda.current_device(), dtype=torch.float32
         )  # Placeholder tensor
 
     # Broadcast the tensor to all processes.
diff --git a/src/axolotl/utils/model_shard_quant.py b/src/axolotl/utils/model_shard_quant.py
index 65f23b9e0..9ed7ae471 100644
--- a/src/axolotl/utils/model_shard_quant.py
+++ b/src/axolotl/utils/model_shard_quant.py
@@ -13,6 +13,7 @@ from fastcore.parallel import parallel
 from torch import Tensor, nn
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
+from transformers.quantizers import AutoHfQuantizer
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
 
 
@@ -173,6 +174,7 @@ def load_sharded_model_quant(
     low_memory=True,
     verbose=False,
     loading_workers=2,
+    quantization_config=None,
 ):
     with init_empty_weights():
         model = AutoModelForCausalLM.from_config(
@@ -186,15 +188,26 @@ def load_sharded_model_quant(
             compute_dtype=compute_dtype,
             quant_type="nf4",
             quant_storage=quant_storage,
+            compress_statistics=True,  # bnb_4bit_use_double_quant
+            skip_modules=[
+                "lm_head",
+                "embed_out",
+            ],
         )
     else:
         # this is the more common case with HF transformers
+        # TODO can we detect the model arch and dynamically set skip_modules
         model.model = _replace_linear(
             model.model,
             Linear4bit,
             compute_dtype=compute_dtype,
             quant_type="nf4",
             quant_storage=quant_storage,
+            compress_statistics=True,  # bnb_4bit_use_double_quant
+            skip_modules=[
+                "lm_head",
+                "embed_out",
+            ],
         )
 
     model.is_loaded_in_4bit = True
@@ -251,6 +264,11 @@ def load_sharded_model_quant(
         quant_method=quant_method,
     )
 
+    # these attributes are needed to inform transformers/peft of the quantization
+    model.is_quantized = True
+    model.quantization_method = "bitsandbytes"
+    model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
+
     if cfg.local_rank == 0 and verbose:
         print(f"Loaded model weights in {time.time()-start:.3f} seconds")
     # cleanup any extra memory usage from parallel loading
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 52a85c2ac..9015bdd97 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -351,7 +351,11 @@ def load_model(
         and cfg.flash_attention
         and cfg.sample_packing
     ):
-        patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
+        patch_for_multipack(
+            cfg.model_config_type,
+            model_name=cfg.base_model,
+            is_remote_code=cfg.trust_remote_code,
+        )
 
     if cfg.is_llama_derived_model:
         from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -627,14 +631,21 @@ def load_model(
     elif (
         qlora_fsdp
        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-        and cfg.model_config_type == "dbrx"
+        and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading)
     ):
         quant_storage = cfg.torch_dtype
+        quantization_config = hasattr(
+            model_config, "quantization_config"
+        ) and getattr(model_config, "quantization_config")
+        quantization_config = (
+            quantization_config or model_kwargs["quantization_config"]
+        )
         model = load_sharded_model_quant(
             base_model,
             model_config,
             cfg,
             quant_storage=quant_storage,
+            quantization_config=quantization_config,
         )
         skip_move_to_device = True
     elif (
@@ -1014,7 +1025,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
 
     if cfg.lora_target_linear:
         linear_names = find_all_linear_names(model)
-        LOG.info(f"found linear modules: {repr(linear_names)}")
+        LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
         lora_target_modules = list(set(lora_target_modules + linear_names))
 
     lora_config_kwargs = {}
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index bb9624051..02234d8b7 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader, RandomSampler
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
-from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
+from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 
 LOG = get_logger("axolotl")
@@ -183,88 +183,88 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
         sequence_len=cfg.sequence_len,
         min_sequence_len=cfg.min_sample_len or 2,
     )
-    with zero_first(is_main_process()):
-        if cfg.is_preprocess:
-            min_input_len = np.min(get_dataset_lengths(train_dataset))
-            LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
-            max_input_len = np.max(get_dataset_lengths(train_dataset))
-            LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
-        if cfg.model_config_type == "mamba":
-            LOG.info("dropping attention_mask column")
-            train_dataset = train_dataset.remove_columns("attention_mask")
-            if eval_dataset:
-                eval_dataset = eval_dataset.remove_columns("attention_mask")
+    if cfg.is_preprocess:
+        min_input_len = np.min(get_dataset_lengths(train_dataset))
+        LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
+        max_input_len = np.max(get_dataset_lengths(train_dataset))
+        LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
 
-        if cfg.model_config_type == "falcon":
-            LOG.info("dropping token_type_ids column if it exists")
-            if "token_type_ids" in train_dataset.column_names:
-                train_dataset = train_dataset.remove_columns("token_type_ids")
-            if eval_dataset and "token_type_ids" in eval_dataset.column_names:
-                eval_dataset = eval_dataset.remove_columns("token_type_ids")
+    if cfg.model_config_type == "mamba":
+        LOG.info("dropping attention_mask column")
+        train_dataset = train_dataset.remove_columns("attention_mask")
+        if eval_dataset:
+            eval_dataset = eval_dataset.remove_columns("attention_mask")
 
-        train_dataset = train_dataset.filter(
+    if cfg.model_config_type == "falcon":
+        LOG.info("dropping token_type_ids column if it exists")
+        if "token_type_ids" in train_dataset.column_names:
+            train_dataset = train_dataset.remove_columns("token_type_ids")
+        if eval_dataset and "token_type_ids" in eval_dataset.column_names:
+            eval_dataset = eval_dataset.remove_columns("token_type_ids")
+
+    train_dataset = train_dataset.filter(
+        drop_long,
+        num_proc=cfg.dataset_processes,
+        load_from_cache_file=not cfg.is_preprocess,
+        desc="Dropping Long Sequences",
+    )
+    if eval_dataset:
+        eval_dataset = eval_dataset.filter(
             drop_long,
             num_proc=cfg.dataset_processes,
             load_from_cache_file=not cfg.is_preprocess,
             desc="Dropping Long Sequences",
         )
-        if eval_dataset:
-            eval_dataset = eval_dataset.filter(
-                drop_long,
-                num_proc=cfg.dataset_processes,
-                load_from_cache_file=not cfg.is_preprocess,
-                desc="Dropping Long Sequences",
-            )
-        if cfg.group_by_length:
-            train_dataset = train_dataset.map(
-                add_length,
-                num_proc=cfg.dataset_processes,
-                load_from_cache_file=not cfg.is_preprocess,
-                desc="Group By Length",
-            )
+    if cfg.group_by_length:
+        train_dataset = train_dataset.map(
+            add_length,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+            desc="Group By Length",
+        )
 
-        if cfg.use_pose:
-            pose_kwargs = {}
-            if cfg.pose_num_chunks is not None:
-                pose_kwargs["chunks"] = cfg.pose_num_chunks
-            pose_fn = partial(
-                add_pose_position_ids,
-                max_context_len=cfg.pose_max_context_len,
-                split_on_token_ids=cfg.pose_split_on_token_ids,
-                **pose_kwargs,
-            )
-            train_dataset = train_dataset.map(
-                pose_fn,
-                num_proc=cfg.dataset_processes,
-                load_from_cache_file=not cfg.is_preprocess,
-                desc="Add position_id column (PoSE)",
-            )
-            train_dataset = train_dataset.sort("sequence_len")
-            if cfg.eval_sample_packing is not False:
-                if eval_dataset:
-                    eval_dataset = eval_dataset.map(
-                        pose_fn,
-                        num_proc=cfg.dataset_processes,
-                        load_from_cache_file=not cfg.is_preprocess,
-                        desc="Add position_id column (PoSE)",
-                    )
-        elif cfg.sample_packing:
-            train_dataset = train_dataset.map(
-                add_position_ids,
-                num_proc=cfg.dataset_processes,
-                load_from_cache_file=not cfg.is_preprocess,
-                desc="Add position_id column (Sample Packing)",
-            )
-            if cfg.eval_sample_packing is not False:
-                if eval_dataset:
-                    eval_dataset = eval_dataset.map(
-                        add_position_ids,
-                        num_proc=cfg.dataset_processes,
-                        load_from_cache_file=not cfg.is_preprocess,
-                        desc="Add position_id column (Sample Packing)",
-                    )
+    if cfg.use_pose:
+        pose_kwargs = {}
+        if cfg.pose_num_chunks is not None:
+            pose_kwargs["chunks"] = cfg.pose_num_chunks
+        pose_fn = partial(
+            add_pose_position_ids,
+            max_context_len=cfg.pose_max_context_len,
+            split_on_token_ids=cfg.pose_split_on_token_ids,
+            **pose_kwargs,
+        )
+        train_dataset = train_dataset.map(
+            pose_fn,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+            desc="Add position_id column (PoSE)",
+        )
+        train_dataset = train_dataset.sort("sequence_len")
+        if cfg.eval_sample_packing is not False:
+            if eval_dataset:
+                eval_dataset = eval_dataset.map(
+                    pose_fn,
+                    num_proc=cfg.dataset_processes,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Add position_id column (PoSE)",
+                )
+    elif cfg.sample_packing:
+        train_dataset = train_dataset.map(
+            add_position_ids,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+            desc="Add position_id column (Sample Packing)",
+        )
+        if cfg.eval_sample_packing is not False:
+            if eval_dataset:
+                eval_dataset = eval_dataset.map(
+                    add_position_ids,
+                    num_proc=cfg.dataset_processes,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Add position_id column (Sample Packing)",
+                )
 
     return train_dataset, eval_dataset
@@ -393,10 +393,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
 
 def setup_deepspeed_env(cfg, stage=None):
     os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
     os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
-    if cfg.bf16:
-        os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
-    elif cfg.fp16:
-        os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
     if stage:
         os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
         if stage == 3:
@@ -444,6 +440,12 @@ def prepare_optim_env(cfg):
         os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
 
 
+def prepare_opinionated_env(cfg):
+    if cfg.qlora_sharded_model_loading:
+        # model loading is forked after the tokenizer
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
     if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
         trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
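A minimal, hypothetical YAML sketch (not part of the patch) of how the options introduced above could be combined in a single config; the base model, learning rate, and pad token values are taken from the example files touched by this diff and are illustrative rather than a recommended recipe.

```yaml
# Hypothetical config sketch; every key below is introduced or touched by this patch.
base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16
load_in_4bit: true
adapter: qlora

# New option: load the already-quantized model shard-by-shard for FSDP (answer.ai technique)
qlora_sharded_model_loading: true

# one_cycle is now routed through alternate_lr_scheduler_type to torch's OneCycleLR
lr_scheduler: one_cycle
learning_rate: 0.00001

special_tokens:
  pad_token: <|end_of_text|>
```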