fix for accelerate env var for auto bf16, add new base image and expand torch_cuda_arch_list support (#1413)

This commit is contained in:
Wing Lian
2024-03-26 13:46:19 -07:00
committed by GitHub
parent e07347b188
commit da265dd796
2 changed files with 14 additions and 3 deletions

View File

@@ -11,6 +11,7 @@ import torch.cuda
from accelerate.logging import get_logger
from datasets import set_caching_enabled
from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
@@ -324,6 +325,11 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "kto_pair"]: