This commit is contained in:
Wing Lian
2023-11-01 18:50:18 -04:00
parent 75e4fc2825
commit c4664ba8ee
2 changed files with 15 additions and 3 deletions

View File

@@ -14,6 +14,7 @@ from functools import partial
from pathlib import Path
from typing import Optional, Union
import tensor_parallel as tp
import torch
import transformers
from datasets import Dataset
@@ -33,6 +34,7 @@ from axolotl.utils.callbacks import (
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.distributed import is_distributed
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
try:
@@ -102,6 +104,9 @@ class AxolotlTrainingArguments(TrainingArguments):
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
tensor_parallel: bool = field(
default=False, metadata={"help": "Use tensor parallelism to train"}
)
class AxolotlTrainer(Trainer):
@@ -246,6 +251,14 @@ class AxolotlTrainer(Trainer):
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.tensor_parallel:
model = tp.tensor_parallel(model, distributed=is_distributed())
model.hf_device_map = tp.infer_sharded_device_map(model)
else:
model = super()._wrap_model(model, training=training, dataloader=dataloader)
return model
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
@@ -618,6 +631,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.micro_batch_size
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
training_arguments_kwargs["tensor_parallel"] = self.cfg.tensor_parallel is True
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)

View File

@@ -30,7 +30,6 @@ from transformers import ( # noqa: F401
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_distributed
LOG = logging.getLogger("axolotl")
@@ -334,8 +333,6 @@ def load_model(
low_cpu_mem_usage=True,
offload_state_dict=True,
)
model = tp.tensor_parallel(model, distributed=is_distributed())
model.hf_device_map = tp.infer_sharded_device_map(model)
else:
config = AutoConfig.from_pretrained(
base_model,