various batch of fixes (#1785)

* various batch of fixes

* more tweaks

* fix autoawq requirement for torch flexibility

* simplify conditionals

* multi-node fixes wip

* bump transformers and include 405b qlora+fsdp yaml
This commit is contained in:
Wing Lian
2024-07-28 07:25:54 -04:00
committed by GitHub
parent 22680913f3
commit 94ba93259f
11 changed files with 253 additions and 99 deletions

View File

@@ -42,7 +42,7 @@ from axolotl.prompters import (
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.utils import md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
@@ -54,7 +54,7 @@ LOG = logging.getLogger("axolotl")
def prepare_dataset(cfg, tokenizer):
prompters = []
if not cfg.pretraining_dataset:
with zero_first(is_main_process()):
with zero_first(is_local_main_process()):
if cfg.test_datasets:
train_dataset, _, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
@@ -170,6 +170,7 @@ def load_tokenized_prepared_datasets(
# pylint: disable=duplicate-code
if dataset:
# This is for the case where we already loaded a pretokenized dataset from the hub
...
elif (
cfg.dataset_prepared_path
@@ -198,6 +199,8 @@ def load_tokenized_prepared_datasets(
def for_d_in_datasets(dataset_configs):
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
# load_dataset doesn't properly handle multiple named configurations
# at the same time for a given dataset
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
else:
@@ -208,6 +211,8 @@ def load_tokenized_prepared_datasets(
ds: Optional[Union[Dataset, DatasetDict]] = None
ds_from_hub = False
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
load_dataset(
config_dataset.path,
name=config_dataset.name,