From 70541145f169401540185dca5ddac94f94640fc6 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 20 Dec 2024 21:43:33 -0500 Subject: [PATCH 1/8] adding test_datasets compat with pretraining_dataset (streaming) (#2206) [skip ci] --- src/axolotl/utils/data/sft.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 286e5f2d7..e2cb8f9f6 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -85,6 +85,7 @@ def prepare_dataset(cfg, tokenizer, processor=None): processor=processor, ) else: + # Load streaming dataset if pretraining_dataset is given path = cfg.pretraining_dataset split = "train" name = None @@ -116,7 +117,18 @@ def prepare_dataset(cfg, tokenizer, processor=None): ) # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 train_dataset = train_dataset.with_format("torch") + + # Load eval dataset (non-streaming) if specified eval_dataset = None + if cfg.test_datasets: + _, eval_dataset, _ = load_prepare_datasets( + tokenizer, + cfg, + DEFAULT_DATASET_PREPARED_PATH, + split="test", + processor=processor, + ) + if cfg.dataset_exact_deduplication: LOG.info("Deduplication not available for pretrained datasets") From 307cf7c685eafe7c84f17ed871650755f589a884 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 20 Dec 2024 21:43:52 -0500 Subject: [PATCH 2/8] move the dataset loading from remote/disk to a shared function so we can re-use for RL (#2204) --- src/axolotl/utils/data/sft.py | 215 +----------------------------- src/axolotl/utils/data/shared.py | 222 +++++++++++++++++++++++++++++++ 2 files changed, 227 insertions(+), 210 deletions(-) create mode 100644 src/axolotl/utils/data/shared.py diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index e2cb8f9f6..3e784ca3e 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -3,7 +3,7 @@ import functools import logging from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union from datasets import ( Dataset, @@ -12,8 +12,6 @@ from datasets import ( load_dataset, load_from_disk, ) -from huggingface_hub import hf_hub_download -from huggingface_hub.utils import HFValidationError from transformers import PreTrainedTokenizerBase from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH @@ -42,6 +40,7 @@ from axolotl.prompters import ( UnsupportedPrompter, ) from axolotl.utils.data.pretraining import wrap_pretraining_dataset +from axolotl.utils.data.shared import load_dataset_w_config from axolotl.utils.data.utils import ( deduplicate_and_log_datasets, md5, @@ -255,195 +254,9 @@ def load_tokenized_prepared_datasets( # pylint: disable=invalid-name for config_dataset in for_d_in_datasets(cfg_datasets): - ds: Optional[Union[Dataset, DatasetDict]] = None - ds_from_hub = False - ds_trust_remote_code = config_dataset.trust_remote_code - try: - # this is just a basic check to see if the path is a - # valid HF dataset that's loadable - load_dataset( - config_dataset.path, - name=config_dataset.name, - streaming=True, - token=use_auth_token, - revision=config_dataset.revision, - trust_remote_code=ds_trust_remote_code, - ) - ds_from_hub = True - except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): - pass - - ds_from_cloud = False - storage_options = {} - remote_file_system = None - if config_dataset.path.startswith("s3://"): - try: - import aiobotocore.session # type: ignore - import s3fs # type: ignore - except ImportError as exc: - raise ImportError( - "s3:// paths require aiobotocore and s3fs to be installed" - ) from exc - - # Takes credentials from ~/.aws/credentials for default profile - s3_session = aiobotocore.session.AioSession(profile="default") - storage_options = {"session": s3_session} - remote_file_system = s3fs.S3FileSystem(**storage_options) - elif config_dataset.path.startswith( - "gs://" - ) or config_dataset.path.startswith("gcs://"): - try: - import gcsfs # type: ignore - except ImportError as exc: - raise ImportError( - "gs:// or gcs:// paths require gcsfs to be installed" - ) from exc - - # gcsfs will use default credentials from the environment else anon - # https://gcsfs.readthedocs.io/en/latest/#credentials - storage_options = {"token": None} - remote_file_system = gcsfs.GCSFileSystem(**storage_options) - # TODO: Figure out how to get auth creds passed - # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"): - # try: - # import adlfs - # except ImportError as exc: - # raise ImportError( - # "adl:// or abfs:// paths require adlfs to be installed" - # ) from exc - - # # Gen 1 - # storage_options = { - # "tenant_id": TENANT_ID, - # "client_id": CLIENT_ID, - # "client_secret": CLIENT_SECRET, - # } - # # Gen 2 - # storage_options = { - # "account_name": ACCOUNT_NAME, - # "account_key": ACCOUNT_KEY, - # } - - # remote_file_system = adlfs.AzureBlobFileSystem(**storage_options) - try: - if remote_file_system and remote_file_system.exists( - config_dataset.path - ): - ds_from_cloud = True - except (FileNotFoundError, ConnectionError): - pass - - # prefer local dataset, even if hub exists - local_path = Path(config_dataset.path) - if local_path.exists(): - if local_path.is_dir(): - if config_dataset.data_files: - ds_type = get_ds_type(config_dataset) - ds = load_dataset( - ds_type, - name=config_dataset.name, - data_files=config_dataset.data_files, - streaming=False, - split=None, - ) - else: - try: - ds = load_from_disk(config_dataset.path) - except FileNotFoundError: - ds = load_dataset( - config_dataset.path, - name=config_dataset.name, - streaming=False, - split=None, - ) - elif local_path.is_file(): - ds_type = get_ds_type(config_dataset) - - ds = load_dataset( - ds_type, - name=config_dataset.name, - data_files=config_dataset.path, - streaming=False, - split=None, - ) - else: - raise ValueError( - "unhandled dataset load: local path exists, but is neither a directory or a file" - ) - elif ds_from_hub: - load_ds_kwargs = {} - if config_dataset.split: - load_ds_kwargs["split"] = config_dataset.split - ds = load_dataset( - config_dataset.path, - name=config_dataset.name, - streaming=False, - data_files=config_dataset.data_files, - token=use_auth_token, - revision=config_dataset.revision, - trust_remote_code=config_dataset.trust_remote_code, - **load_ds_kwargs, - ) - elif ds_from_cloud and remote_file_system: - if remote_file_system.isdir(config_dataset.path): - ds = load_from_disk( - config_dataset.path, - storage_options=storage_options, - ) - elif remote_file_system.isfile(config_dataset.path): - ds_type = get_ds_type(config_dataset) - ds = load_dataset( - ds_type, - name=config_dataset.name, - data_files=config_dataset.path, - streaming=False, - split=None, - storage_options=storage_options, - trust_remote_code=config_dataset.trust_remote_code, - ) - elif config_dataset.path.startswith("https://"): - ds_type = get_ds_type(config_dataset) - ds = load_dataset( - ds_type, - name=config_dataset.name, - data_files=config_dataset.path, - streaming=False, - split=None, - storage_options=storage_options, - trust_remote_code=config_dataset.trust_remote_code, - ) - else: - if isinstance(config_dataset.data_files, str): - fp = hf_hub_download( - repo_id=config_dataset.path, - repo_type="dataset", - filename=config_dataset.data_files, - revision=config_dataset.revision, - ) - elif isinstance(config_dataset.data_files, list): - fp = [] - for file in config_dataset.data_files: - fp.append( - hf_hub_download( - repo_id=config_dataset.path, - repo_type="dataset", - filename=file, - revision=config_dataset.revision, - ) - ) - else: - raise ValueError( - "data_files must be either a string or list of strings" - ) - ds = load_dataset( - "json", - name=config_dataset.name, - data_files=fp, - streaming=False, - split=None, - ) - if not ds: - raise ValueError("unhandled dataset load") + ds: Union[Dataset, DatasetDict] = load_dataset_w_config( + config_dataset, use_auth_token + ) d_base_type = d_prompt_style = None d_type = config_dataset.type @@ -513,24 +326,6 @@ def load_tokenized_prepared_datasets( return dataset, prompters -def get_ds_type(config_dataset: DictDefault): - """ - Get the dataset type from the path if it's not specified - """ - ds_type = "json" - if config_dataset.ds_type: - ds_type = config_dataset.ds_type - elif ".parquet" in config_dataset.path: - ds_type = "parquet" - elif ".arrow" in config_dataset.path: - ds_type = "arrow" - elif ".csv" in config_dataset.path: - ds_type = "csv" - elif ".txt" in config_dataset.path: - ds_type = "text" - return ds_type - - def load_prepare_datasets( tokenizer: PreTrainedTokenizerBase, cfg, diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py new file mode 100644 index 000000000..d14496d96 --- /dev/null +++ b/src/axolotl/utils/data/shared.py @@ -0,0 +1,222 @@ +""" +dataset loading shared utils +""" +from pathlib import Path +from typing import Optional, Union + +from datasets import Dataset, DatasetDict, load_dataset, load_from_disk +from huggingface_hub import hf_hub_download +from huggingface_hub.errors import HFValidationError + +from axolotl.utils.dict import DictDefault + + +def get_ds_type(config_dataset: DictDefault): + """ + Get the dataset type from the path if it's not specified + """ + ds_type = "json" + if config_dataset.ds_type: + ds_type = config_dataset.ds_type + elif ".parquet" in config_dataset.path: + ds_type = "parquet" + elif ".arrow" in config_dataset.path: + ds_type = "arrow" + elif ".csv" in config_dataset.path: + ds_type = "csv" + elif ".txt" in config_dataset.path: + ds_type = "text" + return ds_type + + +def load_dataset_w_config(config_dataset, auth_token): + # pylint: disable=invalid-name + ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name + ds_from_hub = False + ds_trust_remote_code = config_dataset.trust_remote_code + try: + # this is just a basic check to see if the path is a + # valid HF dataset that's loadable + load_dataset( + config_dataset.path, + name=config_dataset.name, + streaming=True, + token=auth_token, + revision=config_dataset.revision, + trust_remote_code=ds_trust_remote_code, + ) + ds_from_hub = True + except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): + pass + + ds_from_cloud = False + storage_options = {} + remote_file_system = None + if config_dataset.path.startswith("s3://"): + try: + import aiobotocore.session # type: ignore + import s3fs # type: ignore + except ImportError as exc: + raise ImportError( + "s3:// paths require aiobotocore and s3fs to be installed" + ) from exc + + # Takes credentials from ~/.aws/credentials for default profile + s3_session = aiobotocore.session.AioSession(profile="default") + storage_options = {"session": s3_session} + remote_file_system = s3fs.S3FileSystem(**storage_options) + elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith( + "gcs://" + ): + try: + import gcsfs # type: ignore + except ImportError as exc: + raise ImportError( + "gs:// or gcs:// paths require gcsfs to be installed" + ) from exc + + # gcsfs will use default credentials from the environment else anon + # https://gcsfs.readthedocs.io/en/latest/#credentials + storage_options = {"token": None} + remote_file_system = gcsfs.GCSFileSystem(**storage_options) + # TODO: Figure out how to get auth creds passed + # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"): + # try: + # import adlfs + # except ImportError as exc: + # raise ImportError( + # "adl:// or abfs:// paths require adlfs to be installed" + # ) from exc + + # # Gen 1 + # storage_options = { + # "tenant_id": TENANT_ID, + # "client_id": CLIENT_ID, + # "client_secret": CLIENT_SECRET, + # } + # # Gen 2 + # storage_options = { + # "account_name": ACCOUNT_NAME, + # "account_key": ACCOUNT_KEY, + # } + + # remote_file_system = adlfs.AzureBlobFileSystem(**storage_options) + try: + if remote_file_system and remote_file_system.exists(config_dataset.path): + ds_from_cloud = True + except (FileNotFoundError, ConnectionError): + pass + + # prefer local dataset, even if hub exists + local_path = Path(config_dataset.path) + if local_path.exists(): + if local_path.is_dir(): + if config_dataset.data_files: + ds_type = get_ds_type(config_dataset) + ds = load_dataset( # pylint: disable=invalid-name + ds_type, + name=config_dataset.name, + data_files=config_dataset.data_files, + streaming=False, + split=None, + ) + else: + try: + ds = load_from_disk( + config_dataset.path + ) # pylint: disable=invalid-name + except FileNotFoundError: + ds = load_dataset( + config_dataset.path, + name=config_dataset.name, + streaming=False, + split=None, + ) + elif local_path.is_file(): + ds_type = get_ds_type(config_dataset) + + ds = load_dataset( # pylint: disable=invalid-name + ds_type, + name=config_dataset.name, + data_files=config_dataset.path, + streaming=False, + split=None, + ) + else: + raise ValueError( + "unhandled dataset load: local path exists, but is neither a directory or a file" + ) + elif ds_from_hub: + load_ds_kwargs = {} + if config_dataset.split: + load_ds_kwargs["split"] = config_dataset.split + ds = load_dataset( + config_dataset.path, + name=config_dataset.name, + streaming=False, + data_files=config_dataset.data_files, + token=auth_token, + revision=config_dataset.revision, + trust_remote_code=config_dataset.trust_remote_code, + **load_ds_kwargs, + ) + elif ds_from_cloud and remote_file_system: + if remote_file_system.isdir(config_dataset.path): + ds = load_from_disk( + config_dataset.path, + storage_options=storage_options, + ) + elif remote_file_system.isfile(config_dataset.path): + ds_type = get_ds_type(config_dataset) + ds = load_dataset( + ds_type, + name=config_dataset.name, + data_files=config_dataset.path, + streaming=False, + split=None, + storage_options=storage_options, + trust_remote_code=config_dataset.trust_remote_code, + ) + elif config_dataset.path.startswith("https://"): + ds_type = get_ds_type(config_dataset) + ds = load_dataset( + ds_type, + name=config_dataset.name, + data_files=config_dataset.path, + streaming=False, + split=None, + storage_options=storage_options, + trust_remote_code=config_dataset.trust_remote_code, + ) + else: + if isinstance(config_dataset.data_files, str): + fp = hf_hub_download( + repo_id=config_dataset.path, + repo_type="dataset", + filename=config_dataset.data_files, + revision=config_dataset.revision, + ) + elif isinstance(config_dataset.data_files, list): + fp = [] + for file in config_dataset.data_files: + fp.append( + hf_hub_download( + repo_id=config_dataset.path, + repo_type="dataset", + filename=file, + revision=config_dataset.revision, + ) + ) + else: + raise ValueError("data_files must be either a string or list of strings") + ds = load_dataset( + "json", + name=config_dataset.name, + data_files=fp, + streaming=False, + split=None, + ) + if not ds: + raise ValueError("unhandled dataset load") + + return ds From 2312caaa9870c54d436aa8bc91802005f102203d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 21 Dec 2024 17:38:33 -0500 Subject: [PATCH 3/8] GC every n steps (#2209) --- src/axolotl/core/trainer_builder.py | 3 +++ src/axolotl/utils/callbacks/__init__.py | 15 +++++++++++++++ .../utils/config/models/input/v0_4_1/__init__.py | 2 ++ 3 files changed, 20 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 54ee19536..fffddac81 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -56,6 +56,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.callbacks import ( EvalFirstStepCallback, + GCCallback, GPUStatsCallback, LossWatchDogCallback, SaveAxolotlConfigtoWandBCallback, @@ -1452,6 +1453,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.loss_watchdog_threshold is not None: callbacks.append(LossWatchDogCallback(self.cfg)) + if self.cfg.gc_steps: + callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps)) callbacks.append(SaveModelCallback()) return callbacks diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 641c9b162..f1b459b6b 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +import gc import logging import math import os @@ -842,3 +843,17 @@ class SaveModelCallback(TrainerCallback): ): control.should_save = True return control + + +class GCCallback(TrainerCallback): + """Callback to garbage collect torch cache""" + + def __init__(self, gc_steps=None): + self.gc_steps = gc_steps + + def on_step_end( + self, args, state, control, **kwargs # pylint: disable=unused-argument + ): + if state.global_step % self.gc_steps == 0: + torch.cuda.empty_cache() + gc.collect() diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 5ddf04811..c704be800 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -666,6 +666,8 @@ class AxolotlInputConfig( loss_watchdog_threshold: Optional[float] = None loss_watchdog_patience: Optional[int] = None + gc_steps: Optional[int] = None + bf16: Optional[Union[Literal["auto"], bool]] = "auto" fp16: Optional[bool] = None bfloat16: Optional[bool] = None # for non-AMP cases From 3742deb1ded554d5b5b61db97c979c12af2cb354 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 22 Dec 2024 12:11:39 -0500 Subject: [PATCH 4/8] add deepspeed example with torch compile enabled (#2212) [skip ci] --- deepspeed_configs/zero1_torch_compile.json | 27 ++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 deepspeed_configs/zero1_torch_compile.json diff --git a/deepspeed_configs/zero1_torch_compile.json b/deepspeed_configs/zero1_torch_compile.json new file mode 100644 index 000000000..b88451392 --- /dev/null +++ b/deepspeed_configs/zero1_torch_compile.json @@ -0,0 +1,27 @@ +{ + "zero_optimization": { + "stage": 1, + "overlap_comm": true + }, + "bf16": { + "enabled": "auto" + }, + "fp16": { + "enabled": "auto", + "auto_cast": false, + "loss_scale": 0, + "initial_scale_power": 32, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "compile": { + "disable": false, + "backend": "inductor" + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} From d852d7af7a73fd43f0727dcda43f4c00d2f99e29 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 23 Dec 2024 07:48:41 -0500 Subject: [PATCH 5/8] inference - don't default w accelerate, fix base model (#2216) [skip ci] --- src/axolotl/cli/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 6f883a2ac..14803e43b 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -93,7 +93,7 @@ def evaluate(config: str, accelerate: bool, **kwargs): @click.argument("config", type=click.Path(exists=True, path_type=str)) @click.option( "--accelerate/--no-accelerate", - default=True, + default=False, help="Use accelerate launch for multi-GPU inference", ) @click.option( @@ -124,7 +124,7 @@ def inference( if lora_model_dir: kwargs["lora_model_dir"] = lora_model_dir if base_model: - kwargs["output_dir"] = base_model + kwargs["base_model"] = base_model if accelerate: base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"] From e0a2eb2ebd45f3cf73415350f5d7a8ec0b860b7c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 23 Dec 2024 09:08:28 -0500 Subject: [PATCH 6/8] fix untrained tokens if specified explicitly from a list (#2210) --- requirements.txt | 2 +- src/axolotl/train.py | 16 +++++++++++++++- .../utils/config/models/input/v0_4_1/__init__.py | 2 +- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 61e1a9f90..283b5cc2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2 torchao==0.7.0 schedulefree==1.3.0 -axolotl-contribs-lgpl==0.0.1b2 +axolotl-contribs-lgpl==0.0.2 diff --git a/src/axolotl/train.py b/src/axolotl/train.py index dc7289b09..a74ecc2ec 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -1,5 +1,6 @@ """Prepare and train a model on a dataset. Can also infer from a model or merge lora""" +import inspect import os import signal import sys @@ -126,7 +127,20 @@ def train( ) if cfg.fix_untrained_tokens: - fix_untrained_tokens(model, tokenizer, train_dataset) + # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args + sig = inspect.signature(fix_untrained_tokens) + # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list + if "token_ids_to_fix" in sig.parameters and isinstance( + cfg.fix_untrained_tokens, list + ): + fix_untrained_tokens( + model, + tokenizer, + train_dataset, + token_ids_to_fix=cfg.fix_untrained_tokens, + ) + else: + fix_untrained_tokens(model, tokenizer, train_dataset) if cfg.local_rank == 0: model.save_pretrained( str(Path(cfg.output_dir)), safe_serialization=safe_serialization diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index c704be800..0781c6798 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -794,7 +794,7 @@ class AxolotlInputConfig( chat_template_jinja: Optional[str] = None default_system_message: Optional[str] = None - fix_untrained_tokens: Optional[bool] = None + fix_untrained_tokens: Optional[Union[int, List[int]]] = None # INTERNALS - document for now, generally not set externally is_preprocess: Optional[bool] = None From 7a38dbe67444fc1224713b68084a04c72b2f6fc9 Mon Sep 17 00:00:00 2001 From: NJordan72 Date: Tue, 24 Dec 2024 16:18:50 -0500 Subject: [PATCH 7/8] fix: allow trainer builder to use custom jinja chat template (#2219) * fix: allow trainer builder to use custom jinja chat template * chore: use get_chat_template_from_config Co-authored-by: Chirag Jain * fix: swap imports --------- Co-authored-by: Chirag Jain --- src/axolotl/core/trainer_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index fffddac81..e81740399 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -68,7 +68,7 @@ from axolotl.utils.callbacks import ( ) from axolotl.utils.callbacks.lisa import lisa_callback_factory from axolotl.utils.callbacks.profiler import PytorchProfilerCallback -from axolotl.utils.chat_templates import get_chat_template +from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, @@ -1834,8 +1834,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) if self.cfg.chat_template: - training_arguments_kwargs["chat_template"] = get_chat_template( - self.cfg.chat_template, + training_arguments_kwargs["chat_template"] = get_chat_template_from_config( + cfg=self.cfg, tokenizer=self.tokenizer, ) From 3915abee4cba364a83f18cc4b8c2bb571cce3bac Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 31 Dec 2024 15:22:18 -0500 Subject: [PATCH 8/8] make sure padding is labeled as -100 for pretraining (#2227) --- src/axolotl/utils/data/pretraining.py | 44 +++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py index 16f38218c..f493db70e 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/pretraining.py @@ -28,8 +28,10 @@ def encode_pretraining( ) # Convert to PyTorch tensors input_ids = [torch.tensor(seq) for seq in res["input_ids"]] + targets = [torch.tensor(seq) for seq in res["input_ids"]] attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]] new_input_ids = [] + new_labels = [] new_attention_mask = [] # Append EOS and PAD tokens to input_ids, and correct attention_mask for i, _ in enumerate(input_ids): @@ -40,22 +42,34 @@ def encode_pretraining( ), dim=0, ) + targets[i] = torch.cat( + ( + targets[i], + torch.tensor([tokenizer.eos_token_id, -100]), + ), + dim=0, + ) attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0) # Concatenate tokens so that their lengths are less than max_tokens buffer_input_ids = torch.tensor([], dtype=torch.long) + buffer_labels = torch.tensor([], dtype=torch.long) buffer_attention_mask = torch.tensor([], dtype=torch.long) - for ids, mask in zip(input_ids, attention_mask): + for ids, labels, mask in zip(input_ids, targets, attention_mask): if buffer_input_ids.numel() == max_tokens: new_input_ids.append(buffer_input_ids) + new_labels.append(buffer_labels) new_attention_mask.append(buffer_attention_mask) buffer_input_ids = torch.tensor([], dtype=torch.long) + buffer_labels = torch.tensor([], dtype=torch.long) buffer_attention_mask = torch.tensor([], dtype=torch.long) buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) + buffer_labels = torch.cat((buffer_labels, labels), dim=0) buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) elif buffer_input_ids.numel() + ids.numel() <= max_tokens: buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) + buffer_labels = torch.cat((buffer_labels, labels), dim=0) buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) else: buffer_input_ids = torch.cat( @@ -69,6 +83,17 @@ def encode_pretraining( ), dim=0, ) + buffer_labels = torch.cat( + ( + buffer_labels, + torch.full( + (max_tokens - buffer_labels.numel(),), + -100, + dtype=torch.long, + ), + ), + dim=0, + ) buffer_attention_mask = torch.cat( ( buffer_attention_mask, @@ -81,11 +106,14 @@ def encode_pretraining( dim=0, ) new_input_ids.append(buffer_input_ids) + new_labels.append(buffer_labels) new_attention_mask.append(buffer_attention_mask) buffer_input_ids = torch.tensor([], dtype=torch.long) + buffer_labels = torch.tensor([], dtype=torch.long) buffer_attention_mask = torch.tensor([], dtype=torch.long) buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) + buffer_labels = torch.cat((buffer_labels, labels), dim=0) buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) if buffer_input_ids.numel() > 0: # for any leftover tokens @@ -101,6 +129,17 @@ def encode_pretraining( ), dim=0, ) + buffer_labels = torch.cat( + ( + buffer_labels, + torch.full( + (max_tokens - buffer_labels.numel(),), + -100, + dtype=torch.long, + ), + ), + dim=0, + ) buffer_attention_mask = torch.cat( ( buffer_attention_mask, @@ -113,11 +152,12 @@ def encode_pretraining( dim=0, ) new_input_ids.append(buffer_input_ids) + new_labels.append(buffer_labels) new_attention_mask.append(buffer_attention_mask) ret = { "input_ids": [seq.tolist() for seq in new_input_ids], - "labels": [seq.tolist() for seq in new_input_ids], + "labels": [seq.tolist() for seq in new_labels], "attention_mask": [seq.tolist() for seq in new_attention_mask], }