diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 195c5c426..386505b94 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -14,6 +14,7 @@ from functools import partial
 from pathlib import Path
 from typing import Optional, Union

+import tensor_parallel as tp
 import torch
 import transformers
 from datasets import Dataset
@@ -33,6 +34,7 @@ from axolotl.utils.callbacks import (
 )
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
+from axolotl.utils.distributed import is_distributed
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

 try:
@@ -102,6 +104,9 @@ class AxolotlTrainingArguments(TrainingArguments):
     bench_source_max_len: int = field(
         default=2048, metadata={"help": "Maximum source sequence length for bench."}
     )
+    tensor_parallel: bool = field(
+        default=False, metadata={"help": "Use tensor parallelism to train"}
+    )


 class AxolotlTrainer(Trainer):
@@ -246,6 +251,14 @@ class AxolotlTrainer(Trainer):
         # return (loss, outputs) if return_outputs else loss
         return super().compute_loss(model, inputs, return_outputs=return_outputs)

+    def _wrap_model(self, model, training=True, dataloader=None):
+        if self.args.tensor_parallel:
+            model = tp.tensor_parallel(model, distributed=is_distributed())
+            model.hf_device_map = tp.infer_sharded_device_map(model)
+        else:
+            model = super()._wrap_model(model, training=training, dataloader=dataloader)
+        return model
+

 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
@@ -618,6 +631,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ] = self.cfg.micro_batch_size
         training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
         training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
+        training_arguments_kwargs["tensor_parallel"] = self.cfg.tensor_parallel is True
+
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 8f696c12a..b7ecdf08a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -30,7 +30,6 @@ from transformers import ( # noqa: F401
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_distributed

 LOG = logging.getLogger("axolotl")

@@ -334,8 +333,6 @@ def load_model(
             low_cpu_mem_usage=True,
             offload_state_dict=True,
         )
-        model = tp.tensor_parallel(model, distributed=is_distributed())
-        model.hf_device_map = tp.infer_sharded_device_map(model)
     else:
         config = AutoConfig.from_pretrained(
             base_model,
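
Note (not part of the patch): a minimal sketch of what the new AxolotlTrainer._wrap_model path does, assuming the `tensor_parallel` package is installed. With this change, setting `tensor_parallel: true` in the axolotl config flows through `self.cfg.tensor_parallel` into `AxolotlTrainingArguments.tensor_parallel`, and the trainer shards the model instead of applying the default wrapping. The model name and dtype below are placeholders for illustration only.

    import tensor_parallel as tp
    import torch
    from transformers import AutoModelForCausalLM

    # Hypothetical base model, chosen only to make the sketch runnable.
    model = AutoModelForCausalLM.from_pretrained(
        "huggyllama/llama-7b",
        torch_dtype=torch.float16,
    )

    # Shard the module across the visible GPUs. The patch passes
    # distributed=is_distributed() so both single-process and torchrun launches
    # are handled; here we assume a single process.
    model = tp.tensor_parallel(model, distributed=False)

    # As in AxolotlTrainer._wrap_model, record the shard placement so the HF
    # Trainer treats the model as already dispatched across devices.
    model.hf_device_map = tp.infer_sharded_device_map(model)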