Merge branch 'main' into cj_tokenizer_default_prompt_template

This commit is contained in:
Chirag Jain
2024-08-06 21:23:14 +05:30
committed by GitHub
20 changed files with 204 additions and 145 deletions

View File

@@ -8,6 +8,8 @@ repos:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:

View File

@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -54,6 +54,14 @@ conversations where `from` is `prompter` `assistant` instead of default sharegpt
{"conversations": [{"from": "...", "value": "..."}]}
```
## sharegpt.load_ultrachat
conversations where the turns field is 'messages', human is 'user' and gpt is 'assistant'.
```{.json filename="data.jsonl"}
{"messages": [{"user": "...", "assistant": "..."}]}
```
## sharegpt_jokes
creates a chat where bot is asked to tell a joke, then explain why the joke is funny

View File

@@ -43,7 +43,6 @@
},
"outputs": [],
"source": [
"!pip install torch==\"2.1.2\"\n",
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.5.0\"\n",
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""

View File

@@ -74,3 +74,5 @@ deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3.1-405B
base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16
tokenizer_type: AutoTokenizer
load_in_4bit: true
@@ -10,10 +10,11 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b
save_safetensors: true
adapter: qlora
sequence_len: 1024
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
@@ -25,7 +26,7 @@ lora_target_linear: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
num_epochs: 2
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

View File

@@ -1,9 +1,9 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.11.1
transformers==4.43.3
peft==0.12.0
transformers==4.43.4
tokenizers==0.19.1
bitsandbytes==0.43.1
bitsandbytes==0.43.3
accelerate==0.32.0
deepspeed==0.14.4
pydantic==2.6.3

View File

@@ -40,7 +40,7 @@ from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -382,6 +382,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)
normalize_cfg_datasets(cfg)

View File

@@ -11,4 +11,5 @@ MOE_ARCH_BLOCK = {
],
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
}

View File

@@ -242,6 +242,12 @@ class AxolotlTrainingMixins:
"help": "workaround to pass an alternate optimizer to the HF trainer"
},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
@dataclass
@@ -318,7 +324,23 @@ class SchedulerMixin(Trainer):
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if use_cosine_quadratic:
if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
if "pct_start" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["pct_start"] = pct_start
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["anneal_strategy"] = "cos"
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
@@ -876,37 +898,6 @@ class AxolotlMambaTrainer(AxolotlTrainer):
return lm_loss
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
tag_names = ["axolotl", "onecycle"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
pct_start=pct_start,
div_factor=6,
)
return self.lr_scheduler
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
@@ -1190,10 +1181,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def _get_trainer_cls(self):
if self.cfg.lr_scheduler == "one_cycle" and (
self.cfg.fsdp or self.cfg.adapter == "qlora"
):
return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
@@ -1243,7 +1230,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
training_arguments_kwargs["fsdp_config"] = {
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
}
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
@@ -1441,12 +1430,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler
if self.cfg.lr_scheduler
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine"
)
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[
"alternate_lr_scheduler_type"
] = self.cfg.lr_scheduler
else:
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
)
training_arguments_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)

View File

@@ -25,12 +25,12 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
]
def patch_for_multipack(model_type, model_name=None):
def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
if model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
elif hasattr(transformers, "modeling_flash_attention_utils"):
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -10,8 +10,8 @@ from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
LOG.setLevel(logging.INFO)
class ChatTemplatePrompter(Prompter):
@@ -359,7 +359,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train"),
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "last"),
}

View File

@@ -23,6 +23,7 @@ _TEMPLATES = {
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
"deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<User>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<Assistant>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<Assistant>' }}{% endif %}",
}

View File

@@ -39,6 +39,7 @@ class ChatTemplate(str, Enum):
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
mistral = "mistral" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
@@ -258,6 +259,12 @@ class LoraConfig(BaseModel):
peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
qlora_sharded_model_loading: Optional[bool] = Field(
default=False,
metadata={
"help": "load qlora model in sharded format for FSDP using answer.ai technique."
},
)
lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None
bnb_config_kwargs: Optional[Dict[str, Any]] = None
@@ -395,7 +402,7 @@ class HyperparametersConfig(BaseModel):
},
)
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[SchedulerType] = "cosine"
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None
cosine_min_lr_ratio: Optional[float] = None
@@ -978,6 +985,8 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_eval_packing(cls, data):
# TODO also should check test_datasets and val_set_size as we can skip
# if there are no eval datasets/splits
if (
data.get("sample_packing")
and data.get("eval_table_size")

View File

@@ -160,8 +160,12 @@ def load_tokenized_prepared_datasets(
use_auth_token = cfg.hf_use_auth_token
try:
if cfg.push_dataset_to_hub:
LOG.info(
f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
)
dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
cfg.push_dataset_to_hub,
ds_hash,
token=use_auth_token,
)
dataset = dataset[split]
@@ -181,7 +185,14 @@ def load_tokenized_prepared_datasets(
dataset = load_from_disk(str(prepared_ds_path))
LOG.info("Prepared dataset loaded from disk...")
else:
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
if cfg.push_dataset_to_hub:
LOG.info("Unable to find prepared dataset in Huggingface hub")
if cfg.is_preprocess:
LOG.info(
f"Skipping prepared dataset in {prepared_ds_path} for pre-processing..."
)
else:
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
LOG.info("Loading raw datasets...")
if not cfg.is_preprocess:
LOG.warning(
@@ -433,10 +444,12 @@ def load_tokenized_prepared_datasets(
dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub:
LOG.info(
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
)
dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
cfg.push_dataset_to_hub,
ds_hash,
private=True,
)
return dataset, prompters

View File

@@ -153,11 +153,11 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
if is_main_process():
value_scalar = fn()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device()
).float()
value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
)
else:
value_tensor = torch.tensor(
0.0, device=torch.cuda.current_device()
0.0, device=torch.cuda.current_device(), dtype=torch.float32
) # Placeholder tensor
# Broadcast the tensor to all processes.

View File

@@ -13,6 +13,7 @@ from fastcore.parallel import parallel
from torch import Tensor, nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers.quantizers import AutoHfQuantizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
@@ -173,6 +174,7 @@ def load_sharded_model_quant(
low_memory=True,
verbose=False,
loading_workers=2,
quantization_config=None,
):
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
@@ -186,15 +188,26 @@ def load_sharded_model_quant(
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
compress_statistics=True, # bnb_4bit_use_double_quant
skip_modules=[
"lm_head",
"embed_out",
],
)
else:
# this is the more common case with HF transformers
# TODO can we detect the model arch and dynamically set skip_modules
model.model = _replace_linear(
model.model,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
compress_statistics=True, # bnb_4bit_use_double_quant
skip_modules=[
"lm_head",
"embed_out",
],
)
model.is_loaded_in_4bit = True
@@ -251,6 +264,11 @@ def load_sharded_model_quant(
quant_method=quant_method,
)
# these attributes are needed to inform transformers/peft of the quantization
model.is_quantized = True
model.quantization_method = "bitsandbytes"
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
# cleanup any extra memory usage from parallel loading

View File

@@ -351,7 +351,11 @@ def load_model(
and cfg.flash_attention
and cfg.sample_packing
):
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
patch_for_multipack(
cfg.model_config_type,
model_name=cfg.base_model,
is_remote_code=cfg.trust_remote_code,
)
if cfg.is_llama_derived_model:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -627,14 +631,21 @@ def load_model(
elif (
qlora_fsdp
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.model_config_type == "dbrx"
and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading)
):
quant_storage = cfg.torch_dtype
quantization_config = hasattr(
model_config, "quantization_config"
) and getattr(model_config, "quantization_config")
quantization_config = (
quantization_config or model_kwargs["quantization_config"]
)
model = load_sharded_model_quant(
base_model,
model_config,
cfg,
quant_storage=quant_storage,
quantization_config=quantization_config,
)
skip_move_to_device = True
elif (
@@ -1014,7 +1025,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(linear_names)}")
LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
lora_target_modules = list(set(lora_target_modules + linear_names))
lora_config_kwargs = {}

View File

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger("axolotl")
@@ -183,88 +183,88 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len or 2,
)
with zero_first(is_main_process()):
if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids")
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
train_dataset = train_dataset.filter(
if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids")
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Group By Length",
)
if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Group By Length",
)
if cfg.use_pose:
pose_kwargs = {}
if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial(
add_pose_position_ids,
max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs,
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
train_dataset = train_dataset.sort("sequence_len")
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
train_dataset = train_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
)
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
)
if cfg.use_pose:
pose_kwargs = {}
if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial(
add_pose_position_ids,
max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs,
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
train_dataset = train_dataset.sort("sequence_len")
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
train_dataset = train_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
)
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
)
return train_dataset, eval_dataset
@@ -393,10 +393,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
def setup_deepspeed_env(cfg, stage=None):
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if cfg.bf16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
if stage:
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3:
@@ -444,6 +440,12 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
def prepare_opinionated_env(cfg):
if cfg.qlora_sharded_model_loading:
# model loading is forked after the tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)