tp fixes
@@ -14,6 +14,7 @@ from functools import partial
 from pathlib import Path
 from typing import Optional, Union
 
+import tensor_parallel as tp
 import torch
 import transformers
 from datasets import Dataset
@@ -33,6 +34,7 @@ from axolotl.utils.callbacks import (
 )
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
+from axolotl.utils.distributed import is_distributed
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
 
 try:
@@ -102,6 +104,9 @@ class AxolotlTrainingArguments(TrainingArguments):
     bench_source_max_len: int = field(
         default=2048, metadata={"help": "Maximum source sequence length for bench."}
     )
+    tensor_parallel: bool = field(
+        default=False, metadata={"help": "Use tensor parallelism to train"}
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -246,6 +251,14 @@ class AxolotlTrainer(Trainer):
         # return (loss, outputs) if return_outputs else loss
         return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
+    def _wrap_model(self, model, training=True, dataloader=None):
+        if self.args.tensor_parallel:
+            model = tp.tensor_parallel(model, distributed=is_distributed())
+            model.hf_device_map = tp.infer_sharded_device_map(model)
+        else:
+            model = super()._wrap_model(model, training=training, dataloader=dataloader)
+        return model
+
 
 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
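For reference, a minimal standalone sketch of what the new _wrap_model branch does with the tensor_parallel library. The model name and single-process mode below are illustrative assumptions, not part of this commit; only tp.tensor_parallel and tp.infer_sharded_device_map are taken from the change itself:

    import tensor_parallel as tp
    from transformers import AutoModelForCausalLM

    # Load any HF causal LM; "gpt2" is only a placeholder for illustration.
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Shard the module's weights across the visible GPUs. distributed=False keeps
    # everything in one process, i.e. the case where is_distributed() returns False.
    model = tp.tensor_parallel(model, distributed=False)

    # Record where each shard landed, the same way _wrap_model stores the result
    # on model.hf_device_map for downstream code to consult.
    model.hf_device_map = tp.infer_sharded_device_map(model)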
@@ -618,6 +631,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ] = self.cfg.micro_batch_size
         training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
         training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
+        training_arguments_kwargs["tensor_parallel"] = self.cfg.tensor_parallel is True
+
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
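With the flag plumbed through HFCausalTrainerBuilder, enabling it should only require the boolean in the axolotl YAML config (hypothetical excerpt; all other keys omitted):

    tensor_parallel: true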
@@ -30,7 +30,6 @@ from transformers import ( # noqa: F401
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_distributed
 
 LOG = logging.getLogger("axolotl")
 
@@ -334,8 +333,6 @@ def load_model(
             low_cpu_mem_usage=True,
             offload_state_dict=True,
         )
-        model = tp.tensor_parallel(model, distributed=is_distributed())
-        model.hf_device_map = tp.infer_sharded_device_map(model)
     else:
         config = AutoConfig.from_pretrained(
             base_model,