Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
3ce9b0760b fix the lora yaml for l3 2024-04-19 07:28:07 -04:00
19 changed files with 33 additions and 205 deletions

View File

@@ -22,7 +22,6 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
qlora_fsdp_alt_loader: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
lora_r: 8 lora_r: 8

View File

@@ -22,7 +22,6 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
qlora_fsdp_alt_loader: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
lora_r: 8 lora_r: 8

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3-8B base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
@@ -64,4 +64,4 @@ weight_decay: 0.0
fsdp: fsdp:
fsdp_config: fsdp_config:
special_tokens: special_tokens:
pad_token: <|end_of_text|> pad_token: <|end_of_text|>

View File

@@ -11,7 +11,7 @@ addict
fire fire
PyYAML>=6.0 PyYAML>=6.0
requests requests
datasets==2.15.0 datasets>=2.15.0
flash-attn==2.5.5 flash-attn==2.5.5
sentencepiece sentencepiece
wandb wandb
@@ -28,7 +28,7 @@ scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml pynvml
art art
fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8 fschat==0.2.36
gradio==3.50.2 gradio==3.50.2
tensorboard tensorboard
@@ -39,6 +39,6 @@ s3fs
gcsfs gcsfs
# adlfs # adlfs
trl==0.8.5 trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
zstandard==0.22.0 zstandard==0.22.0
fastcore fastcore

View File

@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.warning(msg) LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo": if parsed_cfg.rl and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else: else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)

View File

@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
else: else:
register_chatml_template() register_chatml_template()
if cfg.rl: # and cfg.rl != "orpo": if cfg.rl and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -30,7 +30,7 @@ from transformers import (
) )
from transformers.trainer_utils import seed_worker from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer, ORPOConfig, ORPOTrainer from trl import DPOTrainer
from trl.trainer.utils import pad_to_length from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer from axolotl.loraplus import create_loraplus_optimizer
@@ -54,7 +54,6 @@ from axolotl.utils.collators import (
MambaDataCollator, MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq,
) )
from axolotl.utils.models import ensure_dtype
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import ( from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr, get_cosine_schedule_with_min_lr,
@@ -811,14 +810,6 @@ class AxolotlDPOTrainer(DPOTrainer):
return res return res
class AxolotlORPOTrainer(ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
""" """
Base class for trainer builder Base class for trainer builder
@@ -1413,7 +1404,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
) )
class HFRLTrainerBuilder(TrainerBuilderBase): class HFDPOTrainerBuilder(TrainerBuilderBase):
""" """
Trainer factory class for DPO Trainer Trainer factory class for DPO Trainer
""" """
@@ -1506,15 +1497,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined # default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch" training_args_kwargs["save_strategy"] = "epoch"
if self.cfg.orpo_alpha: training_args = TrainingArguments(
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
training_args_cls = TrainingArguments
if self.cfg.rl == "orpo":
training_args_cls = ORPOConfig
training_args = training_args_cls(
per_device_train_batch_size=self.cfg.micro_batch_size, per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=self.cfg.max_steps or total_num_steps, max_steps=self.cfg.max_steps or total_num_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
@@ -1547,32 +1530,20 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[ dpo_trainer_kwargs[
"precompute_ref_log_probs" "precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs ] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]: dpo_trainer = AxolotlDPOTrainer(
trainer_cls = AxolotlDPOTrainer self.model,
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1 self.model_ref,
trainer_cls_args = [self.model, self.model_ref]
# these aren't used for the ORPO trainer
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = True
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args, args=training_args,
beta=self.cfg.dpo_beta or 0.1,
train_dataset=self.train_dataset, train_dataset=self.train_dataset,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
max_length=self.cfg.sequence_len,
max_target_length=None,
max_prompt_length=self.cfg.sequence_len,
generate_during_eval=True,
callbacks=self.get_callbacks(), callbacks=self.get_callbacks(),
**dpo_trainer_kwargs, **dpo_trainer_kwargs,
) )
if self.cfg.fsdp:
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer) dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
for callback in self.get_post_trainer_create_callbacks(dpo_trainer): for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
dpo_trainer.add_callback(callback) dpo_trainer.add_callback(callback)

View File

@@ -123,14 +123,6 @@ def get_turns( # pylint: disable=too-many-return-statements
else: else:
yield role, "" yield role, ""
return return
if self.sep_style == SeparatorStyle.GEMMA:
if self.system_message:
raise ValueError("Gemma chat template does not support system messages")
for i, (role, message) in enumerate(self.messages):
prefix = "<bos>" if i == 0 else ""
message_str = message if message else ""
yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
return
if self.sep_style == SeparatorStyle.CHATGLM: if self.sep_style == SeparatorStyle.CHATGLM:
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926

View File

@@ -6,4 +6,4 @@ from functools import partial
from ..base import load as load_base from ..base import load as load_base
load = partial(load_base, module_base="axolotl.prompt_strategies.orpo") load = partial(load_base, module="axolotl.prompt_strategies.orpo")

View File

@@ -78,57 +78,6 @@ class ORPODatasetParsingStrategy:
) )
return MessageList(messages=messages) return MessageList(messages=messages)
def get_prompt(self, prompt) -> MessageList:
"""Map the data to extract everything up to the last turn"""
total_msg_len = len(prompt["chosen"])
total_msg_turns, remainder = divmod(total_msg_len, 2)
assert remainder == 0, "invalid number of turns"
messages: List[Message] = []
if system := prompt.get("system", None):
messages.append(Message(role="system", content=system, label=False))
for i in range(total_msg_turns):
if "prompt" in prompt:
messages.append(
Message(role="user", content=prompt["prompt"], label=False)
)
else:
messages.append(
Message(
role="user",
content=prompt["chosen"][i * 2]["content"],
label=False,
)
)
if i < total_msg_turns - 1:
messages.append(
Message(
role="assistant",
content=prompt["chosen"][i * 2 + 1]["content"],
label=False,
)
)
return MessageList(messages=messages)
def get_chosen(self, prompt) -> MessageList:
res = self.get_prompt(prompt)
res.messages.append(
Message(
role="assistant", content=prompt["chosen"][-1]["content"], label=True
)
)
return res
def get_rejected(self, prompt) -> MessageList:
res = self.get_prompt(prompt)
res.messages.append(
Message(
role="assistant", content=prompt["rejected"][-1]["content"], label=True
)
)
return res
class ORPOTokenizingStrategy(PromptTokenizingStrategy): class ORPOTokenizingStrategy(PromptTokenizingStrategy):
""" """
@@ -237,36 +186,3 @@ class ORPOPrompter(Prompter):
chat_template=self.chat_template, chat_template=self.chat_template,
tokenize=False, tokenize=False,
), True ), True
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
dataset_parser = ORPODatasetParsingStrategy()
chat_template_str = chat_templates(cfg.chat_template)
def transform_fn(sample, tokenizer=None):
res = {}
res["prompt"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=False,
)
prompt_str_len = len(res["prompt"])
res["chosen"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)[prompt_str_len:]
res["rejected"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)[prompt_str_len:]
return res
return transform_fn

View File

@@ -188,7 +188,6 @@ class LoraConfig(BaseModel):
peft_use_dora: Optional[bool] = None peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None peft_layer_replication: Optional[List[Tuple[int, int]]] = None
qlora_fsdp_alt_loader: Optional[bool] = None
lora_on_cpu: Optional[bool] = None lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None gptq: Optional[bool] = None

View File

@@ -1,11 +1,11 @@
""" """
Data processing modules Data processing modules
""" """
from axolotl.utils.data.dpo import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.pretraining import ( # noqa: F401 from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining, encode_pretraining,
wrap_pretraining_dataset, wrap_pretraining_dataset,
) )
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401 from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper, get_dataset_wrapper,
load_prepare_datasets, load_prepare_datasets,

View File

@@ -1,20 +1,17 @@
"""data handling specific to DPO""" """data handling specific to DPO"""
import inspect
import logging import logging
from functools import partial
from pathlib import Path from pathlib import Path
from typing import Any, List from typing import Any, List
import yaml import yaml
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk from datasets import concatenate_datasets, load_dataset, load_from_disk
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.dpo import load as load_dpo from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.utils import md5 from axolotl.utils.data.utils import md5
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -75,29 +72,16 @@ def load_prepare_dpo_datasets(cfg):
) )
split_datasets.insert(i, ds) split_datasets.insert(i, ds)
tokenizer = None
for i, data_set in enumerate(split_datasets): for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"] _type = dataset_cfgs[i]["type"]
if _type: if _type:
if isinstance(_type, DictDefault): if isinstance(_type, DictDefault):
_type = "user_defined.default" _type = "user_defined.default"
if _cfg.rl == "orpo": ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i) split_datasets[i] = data_set.map(
else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
sig = inspect.signature(ds_transform_fn)
if "tokenizer" in sig.parameters:
if not tokenizer:
tokenizer = load_tokenizer(_cfg)
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
data_set = data_set.map(
ds_transform_fn, ds_transform_fn,
desc="Mapping RL Dataset", desc="Mapping RL Dataset",
) )
if isinstance(data_set, DatasetDict):
data_set = data_set["train"]
split_datasets[i] = data_set
else: else:
# If no `type` is provided, assume the dataset is already in the expected format with # If no `type` is provided, assume the dataset is already in the expected format with
# "prompt", "chosen" and "rejected" already preprocessed # "prompt", "chosen" and "rejected" already preprocessed

View File

@@ -421,7 +421,7 @@ def load_tokenized_prepared_datasets(
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(str(prepared_ds_path)) dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
LOG.info( LOG.info(
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"

View File

@@ -70,7 +70,6 @@ def load_and_quantize(
to_meta: bool = False, to_meta: bool = False,
verbose: bool = False, verbose: bool = False,
quant_method: str = "bnb", quant_method: str = "bnb",
is_dora: bool = False,
): ):
""" """
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`. Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
@@ -109,12 +108,6 @@ def load_and_quantize(
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
# workaround quantizes Params4bit to initialize quant_state on all ranks, then # workaround quantizes Params4bit to initialize quant_state on all ranks, then
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0. # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
if is_dora:
setattr(
submodule,
"dora_scale",
value.norm(p=2, dim=1).to(dtype=dtype).to("cpu"),
)
value = type(param)( value = type(param)(
value.to(device=device, dtype=dtype).data, **param.__dict__ value.to(device=device, dtype=dtype).data, **param.__dict__
).cuda(device) ).cuda(device)
@@ -184,7 +177,6 @@ def load_sharded_model_quant(
with init_empty_weights(): with init_empty_weights():
model = AutoModelForCausalLM.from_config( model = AutoModelForCausalLM.from_config(
model_config, model_config,
attn_implementation=model_config._attn_implementation, # pylint: disable=protected-access
trust_remote_code=cfg.trust_remote_code, trust_remote_code=cfg.trust_remote_code,
) )
if hasattr(model, "transformer"): if hasattr(model, "transformer"):
@@ -257,7 +249,6 @@ def load_sharded_model_quant(
to_meta=(low_memory and cfg.local_rank != 0), to_meta=(low_memory and cfg.local_rank != 0),
verbose=verbose, verbose=verbose,
quant_method=quant_method, quant_method=quant_method,
is_dora=cfg.peft_use_dora,
) )
if cfg.local_rank == 0 and verbose: if cfg.local_rank == 0 and verbose:

View File

@@ -34,7 +34,6 @@ from transformers import ( # noqa: F401
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.quantizers import AutoHfQuantizer
from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import ( from axolotl.monkeypatch.multipack import (
@@ -569,7 +568,7 @@ def load_model(
elif ( elif (
qlora_fsdp qlora_fsdp
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.qlora_fsdp_alt_loader and cfg.model_config_type == "dbrx"
): ):
quant_storage = cfg.torch_dtype quant_storage = cfg.torch_dtype
model = load_sharded_model_quant( model = load_sharded_model_quant(
@@ -578,11 +577,6 @@ def load_model(
cfg, cfg,
quant_storage=quant_storage, quant_storage=quant_storage,
) )
if model_kwargs["quantization_config"]:
hf_quantizer = AutoHfQuantizer.from_config(
model_kwargs["quantization_config"]
)
model.hf_quantizer = hf_quantizer
skip_move_to_device = True skip_move_to_device = True
elif ( elif (
model_config.model_type == "llama" model_config.model_type == "llama"
@@ -999,20 +993,3 @@ def load_lora(model, cfg, inference=False, config_only=False):
setup_quantized_peft_meta_for_training(model) setup_quantized_peft_meta_for_training(model)
return model, lora_config return model, lora_config
def ensure_dtype(model, dtype=torch.bfloat16):
for name, module in model.named_modules():
try:
if module.weight.dtype != dtype:
print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
module.to(dtype)
except AttributeError:
pass
for name, param in model.named_parameters():
try:
if param.data.dtype != dtype:
print(f"Converting module {name}: {param.data.dtype} -> {dtype}")
param.data = param.data.to(dtype)
except AttributeError:
pass

View File

@@ -13,7 +13,7 @@ from datasets import set_caching_enabled
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -340,8 +340,8 @@ def prepare_optim_env(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]: if cfg.rl in ["dpo", "ipo", "kto_pair"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1] trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2] trainer_builder.peft_config = model[2]
else: else:

View File

@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder
import pytest import pytest
from axolotl.core.trainer_builder import HFRLTrainerBuilder from axolotl.core.trainer_builder import HFDPOTrainerBuilder
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
return load_model(cfg, tokenizer) return load_model(cfg, tokenizer)
class TestHFRLTrainerBuilder: class TestHFDPOTrainerBuilder:
""" """
TestCase class for DPO trainer builder TestCase class for DPO trainer builder
""" """
def test_build_training_arguments(self, cfg, model, tokenizer): def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFRLTrainerBuilder(cfg, model, tokenizer) builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
training_arguments = builder.build_training_arguments(100) training_arguments = builder.build_training_arguments(100)
assert training_arguments.adam_beta1 == 0.998 assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9 assert training_arguments.adam_beta2 == 0.9

View File

@@ -110,7 +110,7 @@ class TestDatasetPreparation(unittest.TestCase):
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded.""" """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_name = Path(tmp_dir) / "tmp_dataset" tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
self.dataset.save_to_disk(str(tmp_ds_name)) self.dataset.save_to_disk(tmp_ds_name)
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(