Compare commits
1 Commits
grouped_lr
...
djsaunde-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fae6b2df10 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,7 +1,6 @@
|
|||||||
**/axolotl.egg-info
|
**/axolotl.egg-info
|
||||||
configs
|
configs
|
||||||
last_run_prepared/
|
last_run_prepared/
|
||||||
outputs
|
|
||||||
.vscode
|
.vscode
|
||||||
_site/
|
_site/
|
||||||
|
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
{
|
|
||||||
"zero_optimization": {
|
|
||||||
"stage": 1,
|
|
||||||
"overlap_comm": true
|
|
||||||
},
|
|
||||||
"bf16": {
|
|
||||||
"enabled": "auto"
|
|
||||||
},
|
|
||||||
"fp16": {
|
|
||||||
"enabled": "auto",
|
|
||||||
"auto_cast": false,
|
|
||||||
"loss_scale": 0,
|
|
||||||
"initial_scale_power": 32,
|
|
||||||
"loss_scale_window": 1000,
|
|
||||||
"hysteresis": 2,
|
|
||||||
"min_loss_scale": 1
|
|
||||||
},
|
|
||||||
"compile": {
|
|
||||||
"disable": false,
|
|
||||||
"backend": "inductor"
|
|
||||||
},
|
|
||||||
"gradient_accumulation_steps": "auto",
|
|
||||||
"gradient_clipping": "auto",
|
|
||||||
"train_batch_size": "auto",
|
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
|
||||||
"wall_clock_breakdown": false
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
---
|
|
||||||
title: Learning Rate Groups
|
|
||||||
description: "Setting different learning rates by module name"
|
|
||||||
---
|
|
||||||
|
|
||||||
## Background
|
|
||||||
|
|
||||||
Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or groups of
|
|
||||||
modules in a model.
|
|
||||||
|
|
||||||
## Example
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
lr_groups:
|
|
||||||
- name: o_proj
|
|
||||||
modules:
|
|
||||||
- self_attn.o_proj.weight
|
|
||||||
lr: 1e-6
|
|
||||||
- name: q_proj
|
|
||||||
modules:
|
|
||||||
- model.layers.2.self_attn.q_proj.weight
|
|
||||||
lr: 1e-5
|
|
||||||
|
|
||||||
learning_rate: 2e-5
|
|
||||||
```
|
|
||||||
|
|
||||||
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
|
|
||||||
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's
|
|
||||||
self attention `q_proj` module.
|
|
||||||
@@ -56,7 +56,6 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
|||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
GCCallback,
|
|
||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
LossWatchDogCallback,
|
LossWatchDogCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
@@ -244,10 +243,6 @@ class AxolotlTrainingMixins:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||||
)
|
)
|
||||||
lr_groups: Optional[list[dict]] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Specify learning rate groups for with different LRs."},
|
|
||||||
)
|
|
||||||
embedding_lr: Optional[float] = field(
|
embedding_lr: Optional[float] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||||
@@ -466,96 +461,11 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
)
|
)
|
||||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||||
|
|
||||||
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
|
|
||||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
|
||||||
params = {
|
|
||||||
"to_weight_decay": {}, # LayerNorm and bias
|
|
||||||
"embeddings": {}, # lm_head, embed_tokens,
|
|
||||||
"no_weight_decay": {},
|
|
||||||
}
|
|
||||||
lr_groups_lookup = {}
|
|
||||||
lr_groups_learning_rates = {}
|
|
||||||
if self.args.lr_groups:
|
|
||||||
for lr_group in self.args.lr_groups:
|
|
||||||
group_name = lr_group["name"]
|
|
||||||
group_modules = lr_group["modules"]
|
|
||||||
for module in group_modules:
|
|
||||||
lr_groups_lookup[module] = group_name
|
|
||||||
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
|
||||||
params[f"to_weight_decay_{group_name}"] = {}
|
|
||||||
|
|
||||||
for name, param in opt_model.named_parameters():
|
|
||||||
if not param.requires_grad:
|
|
||||||
continue
|
|
||||||
if name.endswith("modules_to_save.default.weight") or any(
|
|
||||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
|
||||||
):
|
|
||||||
params["embeddings"][name] = param
|
|
||||||
elif name in decay_parameters:
|
|
||||||
if lr_groups_lookup and any(
|
|
||||||
group_modules in name for group_modules in lr_groups_lookup
|
|
||||||
):
|
|
||||||
lr_group_module = [
|
|
||||||
group_modules
|
|
||||||
for group_modules in lr_groups_lookup
|
|
||||||
if group_modules in name
|
|
||||||
][0]
|
|
||||||
group_name = lr_groups_lookup[lr_group_module]
|
|
||||||
params[f"to_weight_decay_{group_name}"][name] = param
|
|
||||||
else:
|
|
||||||
params["to_weight_decay"][name] = param
|
|
||||||
else:
|
|
||||||
params["no_weight_decay"][name] = param
|
|
||||||
optimizer_grouped_parameters = []
|
|
||||||
if params["to_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["to_weight_decay"].values()),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["embeddings"]:
|
|
||||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
|
||||||
if self.args.embedding_lr_scale:
|
|
||||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
|
||||||
elif self.args.embedding_lr:
|
|
||||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["embeddings"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["no_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["no_weight_decay"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for group_name, group_lr in lr_groups_learning_rates.items():
|
|
||||||
if params[f"to_weight_decay_{group_name}"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(
|
|
||||||
params[f"to_weight_decay_{group_name}"].values()
|
|
||||||
),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": group_lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return optimizer_grouped_parameters
|
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
if (
|
if (
|
||||||
self.args.loraplus_lr_ratio is None
|
self.args.loraplus_lr_ratio is None
|
||||||
and self.args.embedding_lr_scale is None
|
and self.args.embedding_lr_scale is None
|
||||||
and self.args.embedding_lr is None
|
and self.args.embedding_lr is None
|
||||||
and self.args.lr_groups is None
|
|
||||||
and self.args.alternate_optimizer
|
and self.args.alternate_optimizer
|
||||||
not in [
|
not in [
|
||||||
"optimi_adamw",
|
"optimi_adamw",
|
||||||
@@ -569,13 +479,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||||
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
|
params = {
|
||||||
|
"to_weight_decay": {}, # LayerNorm and bias
|
||||||
|
"embeddings": {}, # lm_head, embed_tokens,
|
||||||
|
"no_weight_decay": {},
|
||||||
|
}
|
||||||
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||||
self.args,
|
self.args,
|
||||||
opt_model,
|
opt_model,
|
||||||
)
|
)
|
||||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
|
||||||
opt_model, optimizer_kwargs
|
for name, param in opt_model.named_parameters():
|
||||||
)
|
if not param.requires_grad:
|
||||||
|
continue
|
||||||
|
if name.endswith("modules_to_save.default.weight") or any(
|
||||||
|
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||||
|
):
|
||||||
|
params["embeddings"][name] = param
|
||||||
|
elif name in decay_parameters:
|
||||||
|
params["to_weight_decay"][name] = param
|
||||||
|
else:
|
||||||
|
params["no_weight_decay"][name] = param
|
||||||
|
optimizer_grouped_parameters = []
|
||||||
|
if params["to_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["to_weight_decay"].values()),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["embeddings"]:
|
||||||
|
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||||
|
if self.args.embedding_lr_scale:
|
||||||
|
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||||
|
elif self.args.embedding_lr:
|
||||||
|
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["embeddings"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["no_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["no_weight_decay"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.loraplus_lr_ratio is not None:
|
if self.args.loraplus_lr_ratio is not None:
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
@@ -592,7 +548,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
elif (
|
elif (
|
||||||
self.args.embedding_lr_scale is not None
|
self.args.embedding_lr_scale is not None
|
||||||
or self.args.embedding_lr is not None
|
or self.args.embedding_lr is not None
|
||||||
or self.args.lr_groups is not None
|
|
||||||
):
|
):
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
@@ -1497,8 +1452,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.loss_watchdog_threshold is not None:
|
if self.cfg.loss_watchdog_threshold is not None:
|
||||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||||
|
|
||||||
if self.cfg.gc_steps:
|
|
||||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
|
||||||
callbacks.append(SaveModelCallback())
|
callbacks.append(SaveModelCallback())
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
@@ -1808,7 +1761,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
] = self.cfg.loraplus_lr_embedding
|
] = self.cfg.loraplus_lr_embedding
|
||||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
|
||||||
|
|
||||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
||||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import gc
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -843,17 +842,3 @@ class SaveModelCallback(TrainerCallback):
|
|||||||
):
|
):
|
||||||
control.should_save = True
|
control.should_save = True
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
class GCCallback(TrainerCallback):
|
|
||||||
"""Callback to garbage collect torch cache"""
|
|
||||||
|
|
||||||
def __init__(self, gc_steps=None):
|
|
||||||
self.gc_steps = gc_steps
|
|
||||||
|
|
||||||
def on_step_end(
|
|
||||||
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
if state.global_step % self.gc_steps == 0:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|||||||
@@ -145,14 +145,6 @@ class UserDefinedPrompterType(BaseModel):
|
|||||||
field: Optional[str] = None
|
field: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class LrGroup(BaseModel):
|
|
||||||
"""Custom learning rate group configuration"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
modules: List[str]
|
|
||||||
lr: float
|
|
||||||
|
|
||||||
|
|
||||||
class SFTDataset(BaseModel):
|
class SFTDataset(BaseModel):
|
||||||
"""SFT configuration subset"""
|
"""SFT configuration subset"""
|
||||||
|
|
||||||
@@ -474,7 +466,6 @@ class HyperparametersConfig(BaseModel):
|
|||||||
cosine_min_lr_ratio: Optional[float] = None
|
cosine_min_lr_ratio: Optional[float] = None
|
||||||
cosine_constant_lr_ratio: Optional[float] = None
|
cosine_constant_lr_ratio: Optional[float] = None
|
||||||
lr_div_factor: Optional[float] = None
|
lr_div_factor: Optional[float] = None
|
||||||
lr_groups: Optional[List[LrGroup]] = None
|
|
||||||
|
|
||||||
adam_epsilon: Optional[float] = None
|
adam_epsilon: Optional[float] = None
|
||||||
adam_beta1: Optional[float] = None
|
adam_beta1: Optional[float] = None
|
||||||
@@ -675,8 +666,6 @@ class AxolotlInputConfig(
|
|||||||
loss_watchdog_threshold: Optional[float] = None
|
loss_watchdog_threshold: Optional[float] = None
|
||||||
loss_watchdog_patience: Optional[int] = None
|
loss_watchdog_patience: Optional[int] = None
|
||||||
|
|
||||||
gc_steps: Optional[int] = None
|
|
||||||
|
|
||||||
bf16: Optional[Union[Literal["auto"], bool]] = "auto"
|
bf16: Optional[Union[Literal["auto"], bool]] = "auto"
|
||||||
fp16: Optional[bool] = None
|
fp16: Optional[bool] = None
|
||||||
bfloat16: Optional[bool] = None # for non-AMP cases
|
bfloat16: Optional[bool] = None # for non-AMP cases
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
from datasets import (
|
from datasets import (
|
||||||
Dataset,
|
Dataset,
|
||||||
@@ -12,6 +12,8 @@ from datasets import (
|
|||||||
load_dataset,
|
load_dataset,
|
||||||
load_from_disk,
|
load_from_disk,
|
||||||
)
|
)
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from huggingface_hub.utils import HFValidationError
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
@@ -40,7 +42,6 @@ from axolotl.prompters import (
|
|||||||
UnsupportedPrompter,
|
UnsupportedPrompter,
|
||||||
)
|
)
|
||||||
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||||
from axolotl.utils.data.shared import load_dataset_w_config
|
|
||||||
from axolotl.utils.data.utils import (
|
from axolotl.utils.data.utils import (
|
||||||
deduplicate_and_log_datasets,
|
deduplicate_and_log_datasets,
|
||||||
md5,
|
md5,
|
||||||
@@ -84,7 +85,6 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
|||||||
processor=processor,
|
processor=processor,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Load streaming dataset if pretraining_dataset is given
|
|
||||||
path = cfg.pretraining_dataset
|
path = cfg.pretraining_dataset
|
||||||
split = "train"
|
split = "train"
|
||||||
name = None
|
name = None
|
||||||
@@ -116,18 +116,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
|||||||
)
|
)
|
||||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||||
train_dataset = train_dataset.with_format("torch")
|
train_dataset = train_dataset.with_format("torch")
|
||||||
|
|
||||||
# Load eval dataset (non-streaming) if specified
|
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
if cfg.test_datasets:
|
|
||||||
_, eval_dataset, _ = load_prepare_datasets(
|
|
||||||
tokenizer,
|
|
||||||
cfg,
|
|
||||||
DEFAULT_DATASET_PREPARED_PATH,
|
|
||||||
split="test",
|
|
||||||
processor=processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.dataset_exact_deduplication:
|
if cfg.dataset_exact_deduplication:
|
||||||
LOG.info("Deduplication not available for pretrained datasets")
|
LOG.info("Deduplication not available for pretrained datasets")
|
||||||
|
|
||||||
@@ -254,9 +243,195 @@ def load_tokenized_prepared_datasets(
|
|||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
for config_dataset in for_d_in_datasets(cfg_datasets):
|
for config_dataset in for_d_in_datasets(cfg_datasets):
|
||||||
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
ds: Optional[Union[Dataset, DatasetDict]] = None
|
||||||
config_dataset, use_auth_token
|
ds_from_hub = False
|
||||||
)
|
ds_trust_remote_code = config_dataset.trust_remote_code
|
||||||
|
try:
|
||||||
|
# this is just a basic check to see if the path is a
|
||||||
|
# valid HF dataset that's loadable
|
||||||
|
load_dataset(
|
||||||
|
config_dataset.path,
|
||||||
|
name=config_dataset.name,
|
||||||
|
streaming=True,
|
||||||
|
token=use_auth_token,
|
||||||
|
revision=config_dataset.revision,
|
||||||
|
trust_remote_code=ds_trust_remote_code,
|
||||||
|
)
|
||||||
|
ds_from_hub = True
|
||||||
|
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
ds_from_cloud = False
|
||||||
|
storage_options = {}
|
||||||
|
remote_file_system = None
|
||||||
|
if config_dataset.path.startswith("s3://"):
|
||||||
|
try:
|
||||||
|
import aiobotocore.session # type: ignore
|
||||||
|
import s3fs # type: ignore
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"s3:// paths require aiobotocore and s3fs to be installed"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
# Takes credentials from ~/.aws/credentials for default profile
|
||||||
|
s3_session = aiobotocore.session.AioSession(profile="default")
|
||||||
|
storage_options = {"session": s3_session}
|
||||||
|
remote_file_system = s3fs.S3FileSystem(**storage_options)
|
||||||
|
elif config_dataset.path.startswith(
|
||||||
|
"gs://"
|
||||||
|
) or config_dataset.path.startswith("gcs://"):
|
||||||
|
try:
|
||||||
|
import gcsfs # type: ignore
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"gs:// or gcs:// paths require gcsfs to be installed"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
# gcsfs will use default credentials from the environment else anon
|
||||||
|
# https://gcsfs.readthedocs.io/en/latest/#credentials
|
||||||
|
storage_options = {"token": None}
|
||||||
|
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
|
||||||
|
# TODO: Figure out how to get auth creds passed
|
||||||
|
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
|
||||||
|
# try:
|
||||||
|
# import adlfs
|
||||||
|
# except ImportError as exc:
|
||||||
|
# raise ImportError(
|
||||||
|
# "adl:// or abfs:// paths require adlfs to be installed"
|
||||||
|
# ) from exc
|
||||||
|
|
||||||
|
# # Gen 1
|
||||||
|
# storage_options = {
|
||||||
|
# "tenant_id": TENANT_ID,
|
||||||
|
# "client_id": CLIENT_ID,
|
||||||
|
# "client_secret": CLIENT_SECRET,
|
||||||
|
# }
|
||||||
|
# # Gen 2
|
||||||
|
# storage_options = {
|
||||||
|
# "account_name": ACCOUNT_NAME,
|
||||||
|
# "account_key": ACCOUNT_KEY,
|
||||||
|
# }
|
||||||
|
|
||||||
|
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
|
||||||
|
try:
|
||||||
|
if remote_file_system and remote_file_system.exists(
|
||||||
|
config_dataset.path
|
||||||
|
):
|
||||||
|
ds_from_cloud = True
|
||||||
|
except (FileNotFoundError, ConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# prefer local dataset, even if hub exists
|
||||||
|
local_path = Path(config_dataset.path)
|
||||||
|
if local_path.exists():
|
||||||
|
if local_path.is_dir():
|
||||||
|
if config_dataset.data_files:
|
||||||
|
ds_type = get_ds_type(config_dataset)
|
||||||
|
ds = load_dataset(
|
||||||
|
ds_type,
|
||||||
|
name=config_dataset.name,
|
||||||
|
data_files=config_dataset.data_files,
|
||||||
|
streaming=False,
|
||||||
|
split=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
ds = load_from_disk(config_dataset.path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
ds = load_dataset(
|
||||||
|
config_dataset.path,
|
||||||
|
name=config_dataset.name,
|
||||||
|
streaming=False,
|
||||||
|
split=None,
|
||||||
|
)
|
||||||
|
elif local_path.is_file():
|
||||||
|
ds_type = get_ds_type(config_dataset)
|
||||||
|
|
||||||
|
ds = load_dataset(
|
||||||
|
ds_type,
|
||||||
|
name=config_dataset.name,
|
||||||
|
data_files=config_dataset.path,
|
||||||
|
streaming=False,
|
||||||
|
split=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"unhandled dataset load: local path exists, but is neither a directory or a file"
|
||||||
|
)
|
||||||
|
elif ds_from_hub:
|
||||||
|
load_ds_kwargs = {}
|
||||||
|
if config_dataset.split:
|
||||||
|
load_ds_kwargs["split"] = config_dataset.split
|
||||||
|
ds = load_dataset(
|
||||||
|
config_dataset.path,
|
||||||
|
name=config_dataset.name,
|
||||||
|
streaming=False,
|
||||||
|
data_files=config_dataset.data_files,
|
||||||
|
token=use_auth_token,
|
||||||
|
revision=config_dataset.revision,
|
||||||
|
trust_remote_code=config_dataset.trust_remote_code,
|
||||||
|
**load_ds_kwargs,
|
||||||
|
)
|
||||||
|
elif ds_from_cloud and remote_file_system:
|
||||||
|
if remote_file_system.isdir(config_dataset.path):
|
||||||
|
ds = load_from_disk(
|
||||||
|
config_dataset.path,
|
||||||
|
storage_options=storage_options,
|
||||||
|
)
|
||||||
|
elif remote_file_system.isfile(config_dataset.path):
|
||||||
|
ds_type = get_ds_type(config_dataset)
|
||||||
|
ds = load_dataset(
|
||||||
|
ds_type,
|
||||||
|
name=config_dataset.name,
|
||||||
|
data_files=config_dataset.path,
|
||||||
|
streaming=False,
|
||||||
|
split=None,
|
||||||
|
storage_options=storage_options,
|
||||||
|
trust_remote_code=config_dataset.trust_remote_code,
|
||||||
|
)
|
||||||
|
elif config_dataset.path.startswith("https://"):
|
||||||
|
ds_type = get_ds_type(config_dataset)
|
||||||
|
ds = load_dataset(
|
||||||
|
ds_type,
|
||||||
|
name=config_dataset.name,
|
||||||
|
data_files=config_dataset.path,
|
||||||
|
streaming=False,
|
||||||
|
split=None,
|
||||||
|
storage_options=storage_options,
|
||||||
|
trust_remote_code=config_dataset.trust_remote_code,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if isinstance(config_dataset.data_files, str):
|
||||||
|
fp = hf_hub_download(
|
||||||
|
repo_id=config_dataset.path,
|
||||||
|
repo_type="dataset",
|
||||||
|
filename=config_dataset.data_files,
|
||||||
|
revision=config_dataset.revision,
|
||||||
|
)
|
||||||
|
elif isinstance(config_dataset.data_files, list):
|
||||||
|
fp = []
|
||||||
|
for file in config_dataset.data_files:
|
||||||
|
fp.append(
|
||||||
|
hf_hub_download(
|
||||||
|
repo_id=config_dataset.path,
|
||||||
|
repo_type="dataset",
|
||||||
|
filename=file,
|
||||||
|
revision=config_dataset.revision,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"data_files must be either a string or list of strings"
|
||||||
|
)
|
||||||
|
ds = load_dataset(
|
||||||
|
"json",
|
||||||
|
name=config_dataset.name,
|
||||||
|
data_files=fp,
|
||||||
|
streaming=False,
|
||||||
|
split=None,
|
||||||
|
)
|
||||||
|
if not ds:
|
||||||
|
raise ValueError("unhandled dataset load")
|
||||||
|
|
||||||
d_base_type = d_prompt_style = None
|
d_base_type = d_prompt_style = None
|
||||||
d_type = config_dataset.type
|
d_type = config_dataset.type
|
||||||
@@ -326,6 +501,24 @@ def load_tokenized_prepared_datasets(
|
|||||||
return dataset, prompters
|
return dataset, prompters
|
||||||
|
|
||||||
|
|
||||||
|
def get_ds_type(config_dataset: DictDefault):
|
||||||
|
"""
|
||||||
|
Get the dataset type from the path if it's not specified
|
||||||
|
"""
|
||||||
|
ds_type = "json"
|
||||||
|
if config_dataset.ds_type:
|
||||||
|
ds_type = config_dataset.ds_type
|
||||||
|
elif ".parquet" in config_dataset.path:
|
||||||
|
ds_type = "parquet"
|
||||||
|
elif ".arrow" in config_dataset.path:
|
||||||
|
ds_type = "arrow"
|
||||||
|
elif ".csv" in config_dataset.path:
|
||||||
|
ds_type = "csv"
|
||||||
|
elif ".txt" in config_dataset.path:
|
||||||
|
ds_type = "text"
|
||||||
|
return ds_type
|
||||||
|
|
||||||
|
|
||||||
def load_prepare_datasets(
|
def load_prepare_datasets(
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
cfg,
|
cfg,
|
||||||
|
|||||||
@@ -1,222 +0,0 @@
|
|||||||
"""
|
|
||||||
dataset loading shared utils
|
|
||||||
"""
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
from huggingface_hub.errors import HFValidationError
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
|
|
||||||
def get_ds_type(config_dataset: DictDefault):
|
|
||||||
"""
|
|
||||||
Get the dataset type from the path if it's not specified
|
|
||||||
"""
|
|
||||||
ds_type = "json"
|
|
||||||
if config_dataset.ds_type:
|
|
||||||
ds_type = config_dataset.ds_type
|
|
||||||
elif ".parquet" in config_dataset.path:
|
|
||||||
ds_type = "parquet"
|
|
||||||
elif ".arrow" in config_dataset.path:
|
|
||||||
ds_type = "arrow"
|
|
||||||
elif ".csv" in config_dataset.path:
|
|
||||||
ds_type = "csv"
|
|
||||||
elif ".txt" in config_dataset.path:
|
|
||||||
ds_type = "text"
|
|
||||||
return ds_type
|
|
||||||
|
|
||||||
|
|
||||||
def load_dataset_w_config(config_dataset, auth_token):
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
|
||||||
ds_from_hub = False
|
|
||||||
ds_trust_remote_code = config_dataset.trust_remote_code
|
|
||||||
try:
|
|
||||||
# this is just a basic check to see if the path is a
|
|
||||||
# valid HF dataset that's loadable
|
|
||||||
load_dataset(
|
|
||||||
config_dataset.path,
|
|
||||||
name=config_dataset.name,
|
|
||||||
streaming=True,
|
|
||||||
token=auth_token,
|
|
||||||
revision=config_dataset.revision,
|
|
||||||
trust_remote_code=ds_trust_remote_code,
|
|
||||||
)
|
|
||||||
ds_from_hub = True
|
|
||||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
ds_from_cloud = False
|
|
||||||
storage_options = {}
|
|
||||||
remote_file_system = None
|
|
||||||
if config_dataset.path.startswith("s3://"):
|
|
||||||
try:
|
|
||||||
import aiobotocore.session # type: ignore
|
|
||||||
import s3fs # type: ignore
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"s3:// paths require aiobotocore and s3fs to be installed"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
# Takes credentials from ~/.aws/credentials for default profile
|
|
||||||
s3_session = aiobotocore.session.AioSession(profile="default")
|
|
||||||
storage_options = {"session": s3_session}
|
|
||||||
remote_file_system = s3fs.S3FileSystem(**storage_options)
|
|
||||||
elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith(
|
|
||||||
"gcs://"
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
import gcsfs # type: ignore
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"gs:// or gcs:// paths require gcsfs to be installed"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
# gcsfs will use default credentials from the environment else anon
|
|
||||||
# https://gcsfs.readthedocs.io/en/latest/#credentials
|
|
||||||
storage_options = {"token": None}
|
|
||||||
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
|
|
||||||
# TODO: Figure out how to get auth creds passed
|
|
||||||
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
|
|
||||||
# try:
|
|
||||||
# import adlfs
|
|
||||||
# except ImportError as exc:
|
|
||||||
# raise ImportError(
|
|
||||||
# "adl:// or abfs:// paths require adlfs to be installed"
|
|
||||||
# ) from exc
|
|
||||||
|
|
||||||
# # Gen 1
|
|
||||||
# storage_options = {
|
|
||||||
# "tenant_id": TENANT_ID,
|
|
||||||
# "client_id": CLIENT_ID,
|
|
||||||
# "client_secret": CLIENT_SECRET,
|
|
||||||
# }
|
|
||||||
# # Gen 2
|
|
||||||
# storage_options = {
|
|
||||||
# "account_name": ACCOUNT_NAME,
|
|
||||||
# "account_key": ACCOUNT_KEY,
|
|
||||||
# }
|
|
||||||
|
|
||||||
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
|
|
||||||
try:
|
|
||||||
if remote_file_system and remote_file_system.exists(config_dataset.path):
|
|
||||||
ds_from_cloud = True
|
|
||||||
except (FileNotFoundError, ConnectionError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# prefer local dataset, even if hub exists
|
|
||||||
local_path = Path(config_dataset.path)
|
|
||||||
if local_path.exists():
|
|
||||||
if local_path.is_dir():
|
|
||||||
if config_dataset.data_files:
|
|
||||||
ds_type = get_ds_type(config_dataset)
|
|
||||||
ds = load_dataset( # pylint: disable=invalid-name
|
|
||||||
ds_type,
|
|
||||||
name=config_dataset.name,
|
|
||||||
data_files=config_dataset.data_files,
|
|
||||||
streaming=False,
|
|
||||||
split=None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
ds = load_from_disk(
|
|
||||||
config_dataset.path
|
|
||||||
) # pylint: disable=invalid-name
|
|
||||||
except FileNotFoundError:
|
|
||||||
ds = load_dataset(
|
|
||||||
config_dataset.path,
|
|
||||||
name=config_dataset.name,
|
|
||||||
streaming=False,
|
|
||||||
split=None,
|
|
||||||
)
|
|
||||||
elif local_path.is_file():
|
|
||||||
ds_type = get_ds_type(config_dataset)
|
|
||||||
|
|
||||||
ds = load_dataset( # pylint: disable=invalid-name
|
|
||||||
ds_type,
|
|
||||||
name=config_dataset.name,
|
|
||||||
data_files=config_dataset.path,
|
|
||||||
streaming=False,
|
|
||||||
split=None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"unhandled dataset load: local path exists, but is neither a directory or a file"
|
|
||||||
)
|
|
||||||
elif ds_from_hub:
|
|
||||||
load_ds_kwargs = {}
|
|
||||||
if config_dataset.split:
|
|
||||||
load_ds_kwargs["split"] = config_dataset.split
|
|
||||||
ds = load_dataset(
|
|
||||||
config_dataset.path,
|
|
||||||
name=config_dataset.name,
|
|
||||||
streaming=False,
|
|
||||||
data_files=config_dataset.data_files,
|
|
||||||
token=auth_token,
|
|
||||||
revision=config_dataset.revision,
|
|
||||||
trust_remote_code=config_dataset.trust_remote_code,
|
|
||||||
**load_ds_kwargs,
|
|
||||||
)
|
|
||||||
elif ds_from_cloud and remote_file_system:
|
|
||||||
if remote_file_system.isdir(config_dataset.path):
|
|
||||||
ds = load_from_disk(
|
|
||||||
config_dataset.path,
|
|
||||||
storage_options=storage_options,
|
|
||||||
)
|
|
||||||
elif remote_file_system.isfile(config_dataset.path):
|
|
||||||
ds_type = get_ds_type(config_dataset)
|
|
||||||
ds = load_dataset(
|
|
||||||
ds_type,
|
|
||||||
name=config_dataset.name,
|
|
||||||
data_files=config_dataset.path,
|
|
||||||
streaming=False,
|
|
||||||
split=None,
|
|
||||||
storage_options=storage_options,
|
|
||||||
trust_remote_code=config_dataset.trust_remote_code,
|
|
||||||
)
|
|
||||||
elif config_dataset.path.startswith("https://"):
|
|
||||||
ds_type = get_ds_type(config_dataset)
|
|
||||||
ds = load_dataset(
|
|
||||||
ds_type,
|
|
||||||
name=config_dataset.name,
|
|
||||||
data_files=config_dataset.path,
|
|
||||||
streaming=False,
|
|
||||||
split=None,
|
|
||||||
storage_options=storage_options,
|
|
||||||
trust_remote_code=config_dataset.trust_remote_code,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if isinstance(config_dataset.data_files, str):
|
|
||||||
fp = hf_hub_download(
|
|
||||||
repo_id=config_dataset.path,
|
|
||||||
repo_type="dataset",
|
|
||||||
filename=config_dataset.data_files,
|
|
||||||
revision=config_dataset.revision,
|
|
||||||
)
|
|
||||||
elif isinstance(config_dataset.data_files, list):
|
|
||||||
fp = []
|
|
||||||
for file in config_dataset.data_files:
|
|
||||||
fp.append(
|
|
||||||
hf_hub_download(
|
|
||||||
repo_id=config_dataset.path,
|
|
||||||
repo_type="dataset",
|
|
||||||
filename=file,
|
|
||||||
revision=config_dataset.revision,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError("data_files must be either a string or list of strings")
|
|
||||||
ds = load_dataset(
|
|
||||||
"json",
|
|
||||||
name=config_dataset.name,
|
|
||||||
data_files=fp,
|
|
||||||
streaming=False,
|
|
||||||
split=None,
|
|
||||||
)
|
|
||||||
if not ds:
|
|
||||||
raise ValueError("unhandled dataset load")
|
|
||||||
|
|
||||||
return ds
|
|
||||||
Reference in New Issue
Block a user