Compare commits
1 Commits
fsdp-qdora
...
fix-l3-lor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ce9b0760b |
@@ -22,7 +22,6 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
qlora_fsdp_alt_loader: true
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 8
|
||||
|
||||
@@ -22,7 +22,6 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
qlora_fsdp_alt_loader: true
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 8
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
base_model: meta-llama/Meta-Llama-3-8B
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
@@ -64,4 +64,4 @@ weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
pad_token: <|end_of_text|>
|
||||
|
||||
@@ -11,7 +11,7 @@ addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
requests
|
||||
datasets==2.15.0
|
||||
datasets>=2.15.0
|
||||
flash-attn==2.5.5
|
||||
sentencepiece
|
||||
wandb
|
||||
@@ -28,7 +28,7 @@ scipy
|
||||
scikit-learn==1.2.2
|
||||
pynvml
|
||||
art
|
||||
fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8
|
||||
fschat==0.2.36
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
|
||||
@@ -39,6 +39,6 @@ s3fs
|
||||
gcsfs
|
||||
# adlfs
|
||||
|
||||
trl==0.8.5
|
||||
trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
|
||||
zstandard==0.22.0
|
||||
fastcore
|
||||
|
||||
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
|
||||
if parsed_cfg.rl and parsed_cfg.rl != "orpo":
|
||||
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||
else:
|
||||
register_chatml_template()
|
||||
|
||||
if cfg.rl: # and cfg.rl != "orpo":
|
||||
if cfg.rl and cfg.rl != "orpo":
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -30,7 +30,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import seed_worker
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import DPOTrainer, ORPOConfig, ORPOTrainer
|
||||
from trl import DPOTrainer
|
||||
from trl.trainer.utils import pad_to_length
|
||||
|
||||
from axolotl.loraplus import create_loraplus_optimizer
|
||||
@@ -54,7 +54,6 @@ from axolotl.utils.collators import (
|
||||
MambaDataCollator,
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from axolotl.utils.models import ensure_dtype
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
from axolotl.utils.schedulers import (
|
||||
get_cosine_schedule_with_min_lr,
|
||||
@@ -811,14 +810,6 @@ class AxolotlDPOTrainer(DPOTrainer):
|
||||
return res
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
Base class for trainer builder
|
||||
@@ -1413,7 +1404,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
)
|
||||
|
||||
|
||||
class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
class HFDPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
Trainer factory class for DPO Trainer
|
||||
"""
|
||||
@@ -1506,15 +1497,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.orpo_alpha:
|
||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||
|
||||
training_args_cls = TrainingArguments
|
||||
if self.cfg.rl == "orpo":
|
||||
training_args_cls = ORPOConfig
|
||||
|
||||
training_args = training_args_cls(
|
||||
training_args = TrainingArguments(
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
max_steps=self.cfg.max_steps or total_num_steps,
|
||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||
@@ -1547,32 +1530,20 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs[
|
||||
"precompute_ref_log_probs"
|
||||
] = self.cfg.precompute_ref_log_probs
|
||||
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
||||
trainer_cls = AxolotlDPOTrainer
|
||||
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
|
||||
trainer_cls_args = [self.model, self.model_ref]
|
||||
|
||||
# these aren't used for the ORPO trainer
|
||||
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||
dpo_trainer_kwargs["max_target_length"] = None
|
||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
dpo_trainer_kwargs["generate_during_eval"] = True
|
||||
elif self.cfg.rl == "orpo":
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
dpo_trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
dpo_trainer = AxolotlDPOTrainer(
|
||||
self.model,
|
||||
self.model_ref,
|
||||
args=training_args,
|
||||
beta=self.cfg.dpo_beta or 0.1,
|
||||
train_dataset=self.train_dataset,
|
||||
tokenizer=self.tokenizer,
|
||||
max_length=self.cfg.sequence_len,
|
||||
max_target_length=None,
|
||||
max_prompt_length=self.cfg.sequence_len,
|
||||
generate_during_eval=True,
|
||||
callbacks=self.get_callbacks(),
|
||||
**dpo_trainer_kwargs,
|
||||
)
|
||||
if self.cfg.fsdp:
|
||||
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
|
||||
|
||||
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
|
||||
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
|
||||
dpo_trainer.add_callback(callback)
|
||||
|
||||
@@ -123,14 +123,6 @@ def get_turns( # pylint: disable=too-many-return-statements
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.GEMMA:
|
||||
if self.system_message:
|
||||
raise ValueError("Gemma chat template does not support system messages")
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
prefix = "<bos>" if i == 0 else ""
|
||||
message_str = message if message else ""
|
||||
yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATGLM:
|
||||
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||
|
||||
@@ -6,4 +6,4 @@ from functools import partial
|
||||
|
||||
from ..base import load as load_base
|
||||
|
||||
load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
|
||||
load = partial(load_base, module="axolotl.prompt_strategies.orpo")
|
||||
|
||||
@@ -78,57 +78,6 @@ class ORPODatasetParsingStrategy:
|
||||
)
|
||||
return MessageList(messages=messages)
|
||||
|
||||
def get_prompt(self, prompt) -> MessageList:
|
||||
"""Map the data to extract everything up to the last turn"""
|
||||
total_msg_len = len(prompt["chosen"])
|
||||
total_msg_turns, remainder = divmod(total_msg_len, 2)
|
||||
assert remainder == 0, "invalid number of turns"
|
||||
|
||||
messages: List[Message] = []
|
||||
if system := prompt.get("system", None):
|
||||
messages.append(Message(role="system", content=system, label=False))
|
||||
for i in range(total_msg_turns):
|
||||
if "prompt" in prompt:
|
||||
messages.append(
|
||||
Message(role="user", content=prompt["prompt"], label=False)
|
||||
)
|
||||
else:
|
||||
messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content=prompt["chosen"][i * 2]["content"],
|
||||
label=False,
|
||||
)
|
||||
)
|
||||
if i < total_msg_turns - 1:
|
||||
messages.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content=prompt["chosen"][i * 2 + 1]["content"],
|
||||
label=False,
|
||||
)
|
||||
)
|
||||
|
||||
return MessageList(messages=messages)
|
||||
|
||||
def get_chosen(self, prompt) -> MessageList:
|
||||
res = self.get_prompt(prompt)
|
||||
res.messages.append(
|
||||
Message(
|
||||
role="assistant", content=prompt["chosen"][-1]["content"], label=True
|
||||
)
|
||||
)
|
||||
return res
|
||||
|
||||
def get_rejected(self, prompt) -> MessageList:
|
||||
res = self.get_prompt(prompt)
|
||||
res.messages.append(
|
||||
Message(
|
||||
role="assistant", content=prompt["rejected"][-1]["content"], label=True
|
||||
)
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
class ORPOTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
@@ -237,36 +186,3 @@ class ORPOPrompter(Prompter):
|
||||
chat_template=self.chat_template,
|
||||
tokenize=False,
|
||||
), True
|
||||
|
||||
|
||||
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
dataset_parser = ORPODatasetParsingStrategy()
|
||||
|
||||
chat_template_str = chat_templates(cfg.chat_template)
|
||||
|
||||
def transform_fn(sample, tokenizer=None):
|
||||
res = {}
|
||||
|
||||
res["prompt"] = tokenizer.apply_chat_template(
|
||||
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
|
||||
add_generation_prompt=True,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=False,
|
||||
)
|
||||
prompt_str_len = len(res["prompt"])
|
||||
res["chosen"] = tokenizer.apply_chat_template(
|
||||
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
|
||||
add_generation_prompt=False,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=False,
|
||||
)[prompt_str_len:]
|
||||
res["rejected"] = tokenizer.apply_chat_template(
|
||||
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
|
||||
add_generation_prompt=False,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=False,
|
||||
)[prompt_str_len:]
|
||||
|
||||
return res
|
||||
|
||||
return transform_fn
|
||||
|
||||
@@ -188,7 +188,6 @@ class LoraConfig(BaseModel):
|
||||
peft_use_dora: Optional[bool] = None
|
||||
peft_use_rslora: Optional[bool] = None
|
||||
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
||||
qlora_fsdp_alt_loader: Optional[bool] = None
|
||||
|
||||
lora_on_cpu: Optional[bool] = None
|
||||
gptq: Optional[bool] = None
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""
|
||||
Data processing modules
|
||||
"""
|
||||
from axolotl.utils.data.dpo import load_prepare_dpo_datasets # noqa: F401
|
||||
from axolotl.utils.data.pretraining import ( # noqa: F401
|
||||
encode_pretraining,
|
||||
wrap_pretraining_dataset,
|
||||
)
|
||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
|
||||
from axolotl.utils.data.sft import ( # noqa: F401
|
||||
get_dataset_wrapper,
|
||||
load_prepare_datasets,
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
"""data handling specific to DPO"""
|
||||
import inspect
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
|
||||
import yaml
|
||||
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
|
||||
from datasets import concatenate_datasets, load_dataset, load_from_disk
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||
from axolotl.prompt_strategies.orpo import load as load_orpo
|
||||
from axolotl.utils.data.utils import md5
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -75,29 +72,16 @@ def load_prepare_dpo_datasets(cfg):
|
||||
)
|
||||
split_datasets.insert(i, ds)
|
||||
|
||||
tokenizer = None
|
||||
for i, data_set in enumerate(split_datasets):
|
||||
_type = dataset_cfgs[i]["type"]
|
||||
if _type:
|
||||
if isinstance(_type, DictDefault):
|
||||
_type = "user_defined.default"
|
||||
if _cfg.rl == "orpo":
|
||||
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
|
||||
else:
|
||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||
sig = inspect.signature(ds_transform_fn)
|
||||
if "tokenizer" in sig.parameters:
|
||||
if not tokenizer:
|
||||
tokenizer = load_tokenizer(_cfg)
|
||||
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
|
||||
|
||||
data_set = data_set.map(
|
||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||
split_datasets[i] = data_set.map(
|
||||
ds_transform_fn,
|
||||
desc="Mapping RL Dataset",
|
||||
)
|
||||
if isinstance(data_set, DatasetDict):
|
||||
data_set = data_set["train"]
|
||||
split_datasets[i] = data_set
|
||||
else:
|
||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||
# "prompt", "chosen" and "rejected" already preprocessed
|
||||
@@ -421,7 +421,7 @@ def load_tokenized_prepared_datasets(
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||
dataset.save_to_disk(str(prepared_ds_path))
|
||||
dataset.save_to_disk(prepared_ds_path)
|
||||
if cfg.push_dataset_to_hub:
|
||||
LOG.info(
|
||||
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||
|
||||
@@ -70,7 +70,6 @@ def load_and_quantize(
|
||||
to_meta: bool = False,
|
||||
verbose: bool = False,
|
||||
quant_method: str = "bnb",
|
||||
is_dora: bool = False,
|
||||
):
|
||||
"""
|
||||
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
|
||||
@@ -109,12 +108,6 @@ def load_and_quantize(
|
||||
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
|
||||
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
|
||||
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
|
||||
if is_dora:
|
||||
setattr(
|
||||
submodule,
|
||||
"dora_scale",
|
||||
value.norm(p=2, dim=1).to(dtype=dtype).to("cpu"),
|
||||
)
|
||||
value = type(param)(
|
||||
value.to(device=device, dtype=dtype).data, **param.__dict__
|
||||
).cuda(device)
|
||||
@@ -184,7 +177,6 @@ def load_sharded_model_quant(
|
||||
with init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
model_config,
|
||||
attn_implementation=model_config._attn_implementation, # pylint: disable=protected-access
|
||||
trust_remote_code=cfg.trust_remote_code,
|
||||
)
|
||||
if hasattr(model, "transformer"):
|
||||
@@ -257,7 +249,6 @@ def load_sharded_model_quant(
|
||||
to_meta=(low_memory and cfg.local_rank != 0),
|
||||
verbose=verbose,
|
||||
quant_method=quant_method,
|
||||
is_dora=cfg.peft_use_dora,
|
||||
)
|
||||
|
||||
if cfg.local_rank == 0 and verbose:
|
||||
|
||||
@@ -34,7 +34,6 @@ from transformers import ( # noqa: F401
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.quantizers import AutoHfQuantizer
|
||||
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
@@ -569,7 +568,7 @@ def load_model(
|
||||
elif (
|
||||
qlora_fsdp
|
||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and cfg.qlora_fsdp_alt_loader
|
||||
and cfg.model_config_type == "dbrx"
|
||||
):
|
||||
quant_storage = cfg.torch_dtype
|
||||
model = load_sharded_model_quant(
|
||||
@@ -578,11 +577,6 @@ def load_model(
|
||||
cfg,
|
||||
quant_storage=quant_storage,
|
||||
)
|
||||
if model_kwargs["quantization_config"]:
|
||||
hf_quantizer = AutoHfQuantizer.from_config(
|
||||
model_kwargs["quantization_config"]
|
||||
)
|
||||
model.hf_quantizer = hf_quantizer
|
||||
skip_move_to_device = True
|
||||
elif (
|
||||
model_config.model_type == "llama"
|
||||
@@ -999,20 +993,3 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
setup_quantized_peft_meta_for_training(model)
|
||||
|
||||
return model, lora_config
|
||||
|
||||
|
||||
def ensure_dtype(model, dtype=torch.bfloat16):
|
||||
for name, module in model.named_modules():
|
||||
try:
|
||||
if module.weight.dtype != dtype:
|
||||
print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
|
||||
module.to(dtype)
|
||||
except AttributeError:
|
||||
pass
|
||||
for name, param in model.named_parameters():
|
||||
try:
|
||||
if param.data.dtype != dtype:
|
||||
print(f"Converting module {name}: {param.data.dtype} -> {dtype}")
|
||||
param.data = param.data.to(dtype)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -13,7 +13,7 @@ from datasets import set_caching_enabled
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
|
||||
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
@@ -340,8 +340,8 @@ def prepare_optim_env(cfg):
|
||||
|
||||
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
||||
if cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
||||
trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
|
||||
trainer_builder.model_ref = model[1]
|
||||
trainer_builder.peft_config = model[2]
|
||||
else:
|
||||
|
||||
@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.core.trainer_builder import HFRLTrainerBuilder
|
||||
from axolotl.core.trainer_builder import HFDPOTrainerBuilder
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
|
||||
return load_model(cfg, tokenizer)
|
||||
|
||||
|
||||
class TestHFRLTrainerBuilder:
|
||||
class TestHFDPOTrainerBuilder:
|
||||
"""
|
||||
TestCase class for DPO trainer builder
|
||||
"""
|
||||
|
||||
def test_build_training_arguments(self, cfg, model, tokenizer):
|
||||
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
|
||||
builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
|
||||
training_arguments = builder.build_training_arguments(100)
|
||||
assert training_arguments.adam_beta1 == 0.998
|
||||
assert training_arguments.adam_beta2 == 0.9
|
||||
|
||||
@@ -110,7 +110,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
|
||||
self.dataset.save_to_disk(str(tmp_ds_name))
|
||||
self.dataset.save_to_disk(tmp_ds_name)
|
||||
|
||||
prepared_path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
|
||||
Reference in New Issue
Block a user