Compare commits
1 Commits
nca-pair
...
fsdp-qdora
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a7c56f018 |
5
.github/workflows/base.yml
vendored
5
.github/workflows/base.yml
vendored
@@ -32,11 +32,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.2.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "121"
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
10
.github/workflows/main.yml
vendored
10
.github/workflows/main.yml
vendored
@@ -30,11 +30,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.2.1
|
||||
axolotl_extras:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -91,11 +86,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.2.1
|
||||
axolotl_extras:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
10
.github/workflows/nightlies.yml
vendored
10
.github/workflows/nightlies.yml
vendored
@@ -29,11 +29,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.2.1
|
||||
axolotl_extras:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -91,11 +86,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.2.1
|
||||
axolotl_extras:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -133,7 +133,6 @@ venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
venv3.10/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
|
||||
@@ -138,7 +138,7 @@ test_datasets:
|
||||
data_files:
|
||||
- /workspace/data/eval.jsonl
|
||||
|
||||
# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard', 'nca_pair'
|
||||
# use RL training: 'dpo', 'ipo', 'kto_pair'
|
||||
rl:
|
||||
|
||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||
@@ -227,12 +227,6 @@ lora_modules_to_save:
|
||||
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# LoRA+ hyperparameters
|
||||
# For more details about the following options, see:
|
||||
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
|
||||
loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.
|
||||
loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6.
|
||||
|
||||
peft:
|
||||
# Configuration options for loftq initialization for LoRA
|
||||
# https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
|
||||
@@ -274,7 +268,6 @@ torch_compile_backend: # Optional[str]
|
||||
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
||||
gradient_accumulation_steps: 1
|
||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||
# Batch size per gpu = micro_batch_size * gradient_accumulation_steps
|
||||
micro_batch_size: 2
|
||||
eval_batch_size:
|
||||
num_epochs: 4
|
||||
|
||||
@@ -49,7 +49,7 @@ remove_unused_columns: false
|
||||
chat_template: chatml
|
||||
datasets:
|
||||
- path: argilla/ultrafeedback-binarized-preferences-cleaned
|
||||
type: chat_template.argilla
|
||||
type: orpo.chat_template
|
||||
```
|
||||
|
||||
#### Using local dataset files
|
||||
|
||||
@@ -22,6 +22,7 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
qlora_fsdp_alt_loader: true
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 8
|
||||
|
||||
@@ -22,6 +22,7 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
qlora_fsdp_alt_loader: true
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 8
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
rl: orpo
|
||||
orpo_alpha: 0.1
|
||||
remove_unused_columns: false
|
||||
|
||||
chat_template: chatml
|
||||
datasets:
|
||||
- path: argilla/ultrafeedback-binarized-preferences-cleaned
|
||||
type: chat_template.argilla
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./mistral-qlora-orpo-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -39,6 +39,6 @@ s3fs
|
||||
gcsfs
|
||||
# adlfs
|
||||
|
||||
trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
|
||||
trl==0.8.5
|
||||
zstandard==0.22.0
|
||||
fastcore
|
||||
|
||||
@@ -33,7 +33,7 @@ fi
|
||||
|
||||
if [ "$JUPYTER_DISABLE" != "1" ]; then
|
||||
# Run Jupyter Lab in the background
|
||||
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
|
||||
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
|
||||
fi
|
||||
|
||||
# Execute the passed arguments (CMD)
|
||||
|
||||
@@ -264,8 +264,8 @@ def do_inference_gradio(
|
||||
with torch.no_grad():
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
|
||||
temperature=cfg.get("gradio_temperature", 0.9),
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
@@ -300,13 +300,7 @@ def do_inference_gradio(
|
||||
outputs="text",
|
||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||
)
|
||||
|
||||
demo.queue().launch(
|
||||
show_api=False,
|
||||
share=cfg.get("gradio_share", True),
|
||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||
server_port=cfg.get("gradio_server_port", None),
|
||||
)
|
||||
demo.queue().launch(show_api=False, share=True)
|
||||
|
||||
|
||||
def choose_config(path: Path):
|
||||
@@ -439,23 +433,6 @@ def load_rl_datasets(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
check_dataset_labels(
|
||||
train_dataset.select(
|
||||
[
|
||||
random.randrange(0, len(train_dataset) - 1) # nosec
|
||||
for _ in range(cli_args.debug_num_examples)
|
||||
]
|
||||
),
|
||||
tokenizer,
|
||||
num_examples=cli_args.debug_num_examples,
|
||||
text_only=cli_args.debug_text_only,
|
||||
rl_mode=True,
|
||||
)
|
||||
|
||||
return TrainDatasetMeta(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
|
||||
@@ -30,7 +30,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import seed_worker
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
|
||||
from trl import DPOTrainer, ORPOConfig, ORPOTrainer
|
||||
from trl.trainer.utils import pad_to_length
|
||||
|
||||
from axolotl.loraplus import create_loraplus_optimizer
|
||||
@@ -43,7 +43,6 @@ from axolotl.utils.callbacks import (
|
||||
LossWatchDogCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
SaveModelOnTrainEndCallback,
|
||||
bench_eval_callback_factory,
|
||||
causal_lm_bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
@@ -213,10 +212,6 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=None,
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
)
|
||||
curriculum_sampling: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
@@ -352,8 +347,6 @@ class AxolotlTrainer(Trainer):
|
||||
lengths=get_dataset_lengths(self.train_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
if self.args.curriculum_sampling:
|
||||
return SequentialSampler(self.train_dataset)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def _get_eval_sampler(
|
||||
@@ -889,14 +882,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_mlflow and is_mlflow_available():
|
||||
from axolotl.utils.callbacks.mlflow_ import (
|
||||
SaveAxolotlConfigtoMlflowCallback,
|
||||
)
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
return callbacks
|
||||
|
||||
@@ -942,11 +927,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
):
|
||||
callbacks.append(SaveBetterTransformerModelCallback())
|
||||
|
||||
if self.cfg.use_mlflow and is_mlflow_available():
|
||||
from axolotl.utils.callbacks.mlflow_ import (
|
||||
SaveAxolotlConfigtoMlflowCallback,
|
||||
)
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
|
||||
callbacks.append(SaveModelOnTrainEndCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
@@ -1201,7 +1193,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
False if self.cfg.ddp else None
|
||||
)
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||
report_to = None
|
||||
if self.cfg.use_wandb:
|
||||
report_to = "wandb"
|
||||
@@ -1429,8 +1420,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
callbacks.append(SaveModelOnTrainEndCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
@@ -1466,7 +1455,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
else:
|
||||
training_args_kwargs["evaluation_strategy"] = "no"
|
||||
|
||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||
training_args_kwargs["bf16"] = True
|
||||
|
||||
@@ -1525,10 +1513,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_cls = TrainingArguments
|
||||
if self.cfg.rl == "orpo":
|
||||
training_args_cls = ORPOConfig
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]:
|
||||
training_args_cls = DPOConfig
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
training_args = training_args_cls(
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
@@ -1553,8 +1537,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
||||
if self.cfg.dpo_label_smoothing:
|
||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||
elif self.cfg.rl in ["kto_pair", "sppo_hard", "nca_pair"]:
|
||||
dpo_trainer_kwargs["loss_type"] = self.cfg.rl
|
||||
elif self.cfg.rl == "kto_pair":
|
||||
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
||||
if self.eval_dataset:
|
||||
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||
if self.cfg.adapter and self.peft_config:
|
||||
@@ -1563,7 +1547,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs[
|
||||
"precompute_ref_log_probs"
|
||||
] = self.cfg.precompute_ref_log_probs
|
||||
if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]:
|
||||
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
||||
trainer_cls = AxolotlDPOTrainer
|
||||
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
|
||||
trainer_cls_args = [self.model, self.model_ref]
|
||||
@@ -1573,8 +1557,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs["max_target_length"] = None
|
||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
dpo_trainer_kwargs["generate_during_eval"] = True
|
||||
if self.cfg.rl == "dpo":
|
||||
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
elif self.cfg.rl == "orpo":
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
"""
|
||||
DPO strategies for mistral instruct
|
||||
"""
|
||||
|
||||
|
||||
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
|
||||
sample["chosen"] = f"{sample['chosen']}"
|
||||
sample["rejected"] = f"{sample['rejected']}"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
|
||||
def argilla_chat(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
for argilla/dpo-mix-7k conversations
|
||||
"""
|
||||
|
||||
def transform_fn(sample):
|
||||
sample["prompt"] = f"[INST] {sample['chosen'][0]['content']} [/INST]"
|
||||
sample["chosen"] = f"{sample['chosen'][1]['content']}</s>"
|
||||
sample["rejected"] = f"{sample['rejected'][1]['content']}</s>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
@@ -3,7 +3,6 @@
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
@@ -128,20 +127,14 @@ def train(
|
||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||
if cfg.local_rank == 0:
|
||||
|
||||
def terminate_handler(_, __, model_weakref):
|
||||
if model_weakref() is not None:
|
||||
_model = model_weakref()
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
_model = BetterTransformer.reverse(_model)
|
||||
_model.save_pretrained(
|
||||
cfg.output_dir, safe_serialization=safe_serialization
|
||||
)
|
||||
def terminate_handler(_, __, model):
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
sys.exit(0)
|
||||
|
||||
_model_weakref = weakref.ref(model)
|
||||
signal.signal(
|
||||
signal.SIGINT,
|
||||
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
|
||||
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
||||
)
|
||||
|
||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
||||
|
||||
@@ -773,13 +773,3 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||
return control
|
||||
|
||||
|
||||
class SaveModelOnTrainEndCallback(TrainerCallback):
|
||||
"""Callback to save model on train end"""
|
||||
|
||||
def on_train_end( # pylint: disable=unused-argument
|
||||
self, args, state, control, **kwargs
|
||||
):
|
||||
control.should_save = True
|
||||
return control
|
||||
|
||||
@@ -383,9 +383,9 @@ def legacy_validate_config(cfg):
|
||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||
)
|
||||
|
||||
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
|
||||
if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
|
||||
LOG.warning(
|
||||
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
|
||||
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.revision_of_model:
|
||||
@@ -448,14 +448,10 @@ def legacy_validate_config(cfg):
|
||||
raise ValueError(
|
||||
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
|
||||
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
||||
)
|
||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||
)
|
||||
if cfg.evals_per_epoch and cfg.eval_steps:
|
||||
raise ValueError(
|
||||
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||
@@ -468,6 +464,11 @@ def legacy_validate_config(cfg):
|
||||
raise ValueError(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.evaluation_strategy
|
||||
and cfg.eval_steps
|
||||
|
||||
@@ -133,8 +133,6 @@ class RLType(str, Enum):
|
||||
ipo = "ipo" # pylint: disable=invalid-name
|
||||
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
||||
orpo = "orpo" # pylint: disable=invalid-name
|
||||
sppo_hard = "sppo_hard" # pylint: disable=invalid-name
|
||||
nca_pair = "nca_pair" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChatTemplate(str, Enum):
|
||||
@@ -190,6 +188,7 @@ class LoraConfig(BaseModel):
|
||||
peft_use_dora: Optional[bool] = None
|
||||
peft_use_rslora: Optional[bool] = None
|
||||
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
||||
qlora_fsdp_alt_loader: Optional[bool] = None
|
||||
|
||||
lora_on_cpu: Optional[bool] = None
|
||||
gptq: Optional[bool] = None
|
||||
@@ -411,17 +410,6 @@ class WandbConfig(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class GradioConfig(BaseModel):
|
||||
"""Gradio configuration subset"""
|
||||
|
||||
gradio_title: Optional[str] = None
|
||||
gradio_share: Optional[bool] = None
|
||||
gradio_server_name: Optional[str] = None
|
||||
gradio_server_port: Optional[int] = None
|
||||
gradio_max_new_tokens: Optional[int] = None
|
||||
gradio_temperature: Optional[float] = None
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods,too-many-ancestors
|
||||
class AxolotlInputConfig(
|
||||
ModelInputConfig,
|
||||
@@ -432,7 +420,6 @@ class AxolotlInputConfig(
|
||||
WandbConfig,
|
||||
MLFlowConfig,
|
||||
LISAConfig,
|
||||
GradioConfig,
|
||||
RemappedParameters,
|
||||
DeprecatedParameters,
|
||||
BaseModel,
|
||||
@@ -517,17 +504,9 @@ class AxolotlInputConfig(
|
||||
unfrozen_parameters: Optional[List[str]] = None
|
||||
|
||||
sequence_len: int = Field(default=512)
|
||||
min_sample_len: Optional[int] = None
|
||||
sample_packing: Optional[bool] = None
|
||||
eval_sample_packing: Optional[bool] = None
|
||||
pad_to_sequence_len: Optional[bool] = None
|
||||
curriculum_sampling: Optional[bool] = None
|
||||
|
||||
# for PoSE context length extension
|
||||
use_pose: Optional[bool] = None
|
||||
pose_split_on_token_ids: Optional[List[int]] = None
|
||||
pose_max_context_len: Optional[int] = None
|
||||
pose_num_chunks: Optional[int] = None
|
||||
|
||||
pretrain_multipack_buffer_size: Optional[int] = 10_000
|
||||
pretrain_multipack_attn: Optional[bool] = Field(
|
||||
@@ -576,7 +555,6 @@ class AxolotlInputConfig(
|
||||
neftune_noise_alpha: Optional[float] = None
|
||||
|
||||
orpo_alpha: Optional[float] = None
|
||||
dpo_beta: Optional[float] = None
|
||||
|
||||
max_memory: Optional[
|
||||
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
||||
@@ -795,11 +773,11 @@ class AxolotlInputConfig(
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_push_save(cls, data):
|
||||
if data.get("hub_model_id") and (
|
||||
data.get("save_strategy") not in ["steps", "epoch", None]
|
||||
if data.get("hub_model_id") and not (
|
||||
data.get("save_steps") or data.get("saves_per_epoch")
|
||||
):
|
||||
LOG.warning(
|
||||
"hub_model_id is set without any models being saved. To save a model, set save_strategy."
|
||||
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ def load_and_quantize(
|
||||
to_meta: bool = False,
|
||||
verbose: bool = False,
|
||||
quant_method: str = "bnb",
|
||||
is_dora: bool = False,
|
||||
):
|
||||
"""
|
||||
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
|
||||
@@ -108,6 +109,12 @@ def load_and_quantize(
|
||||
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
|
||||
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
|
||||
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
|
||||
if is_dora:
|
||||
setattr(
|
||||
submodule,
|
||||
"dora_scale",
|
||||
value.norm(p=2, dim=1).to(dtype=dtype).to("cpu"),
|
||||
)
|
||||
value = type(param)(
|
||||
value.to(device=device, dtype=dtype).data, **param.__dict__
|
||||
).cuda(device)
|
||||
@@ -177,6 +184,7 @@ def load_sharded_model_quant(
|
||||
with init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
model_config,
|
||||
attn_implementation=model_config._attn_implementation, # pylint: disable=protected-access
|
||||
trust_remote_code=cfg.trust_remote_code,
|
||||
)
|
||||
if hasattr(model, "transformer"):
|
||||
@@ -249,6 +257,7 @@ def load_sharded_model_quant(
|
||||
to_meta=(low_memory and cfg.local_rank != 0),
|
||||
verbose=verbose,
|
||||
quant_method=quant_method,
|
||||
is_dora=cfg.peft_use_dora,
|
||||
)
|
||||
|
||||
if cfg.local_rank == 0 and verbose:
|
||||
|
||||
@@ -34,6 +34,7 @@ from transformers import ( # noqa: F401
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.quantizers import AutoHfQuantizer
|
||||
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
@@ -568,7 +569,7 @@ def load_model(
|
||||
elif (
|
||||
qlora_fsdp
|
||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and cfg.model_config_type == "dbrx"
|
||||
and cfg.qlora_fsdp_alt_loader
|
||||
):
|
||||
quant_storage = cfg.torch_dtype
|
||||
model = load_sharded_model_quant(
|
||||
@@ -577,6 +578,11 @@ def load_model(
|
||||
cfg,
|
||||
quant_storage=quant_storage,
|
||||
)
|
||||
if model_kwargs["quantization_config"]:
|
||||
hf_quantizer = AutoHfQuantizer.from_config(
|
||||
model_kwargs["quantization_config"]
|
||||
)
|
||||
model.hf_quantizer = hf_quantizer
|
||||
skip_move_to_device = True
|
||||
elif (
|
||||
model_config.model_type == "llama"
|
||||
@@ -789,11 +795,7 @@ def load_model(
|
||||
if not reference_model or cfg.lora_model_dir:
|
||||
# if we're not loading the reference model, then we're loading the model for training
|
||||
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
||||
if (
|
||||
cfg.adapter
|
||||
and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]
|
||||
and not cfg.merge_lora
|
||||
):
|
||||
if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
|
||||
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
||||
else:
|
||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||
@@ -1007,3 +1009,10 @@ def ensure_dtype(model, dtype=torch.bfloat16):
|
||||
module.to(dtype)
|
||||
except AttributeError:
|
||||
pass
|
||||
for name, param in model.named_parameters():
|
||||
try:
|
||||
if param.data.dtype != dtype:
|
||||
print(f"Converting module {name}: {param.data.dtype} -> {dtype}")
|
||||
param.data = param.data.to(dtype)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Module for tokenization utilities"""
|
||||
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List
|
||||
@@ -9,19 +10,10 @@ from termcolor import colored
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def check_dataset_labels(
|
||||
dataset,
|
||||
tokenizer,
|
||||
num_examples=5,
|
||||
text_only=False,
|
||||
rl_mode=False,
|
||||
):
|
||||
def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
|
||||
# the dataset is already shuffled, so let's just check the first 5 elements
|
||||
for idx in range(num_examples):
|
||||
if not rl_mode:
|
||||
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
||||
else:
|
||||
check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
||||
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
|
||||
|
||||
|
||||
def check_example_labels(example, tokenizer, text_only=False):
|
||||
@@ -48,53 +40,6 @@ def check_example_labels(example, tokenizer, text_only=False):
|
||||
return " ".join(colored_tokens)
|
||||
|
||||
|
||||
def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
|
||||
"""Helper function to color tokens based on their type."""
|
||||
colored_text = colored(decoded_token, color)
|
||||
return (
|
||||
colored_text
|
||||
if text_only
|
||||
else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
|
||||
)
|
||||
|
||||
|
||||
def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
|
||||
"""Helper function to process and color tokens."""
|
||||
colored_tokens = [
|
||||
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
|
||||
for token in tokenizer.encode(tokens)
|
||||
]
|
||||
return colored_tokens
|
||||
|
||||
|
||||
def check_rl_example_labels(example, tokenizer, text_only=False):
|
||||
field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
|
||||
|
||||
input_tokens = example[field_prompt]
|
||||
labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
|
||||
|
||||
# Process and color each type of token
|
||||
colored_tokens = process_tokens_for_rl_debug(
|
||||
input_tokens, "yellow", tokenizer, text_only
|
||||
)
|
||||
colored_chosens = process_tokens_for_rl_debug(
|
||||
labels_chosen, "green", tokenizer, text_only
|
||||
)
|
||||
colored_rejecteds = process_tokens_for_rl_debug(
|
||||
labels_rejected, "red", tokenizer, text_only
|
||||
)
|
||||
|
||||
# Create a delimiter based on text_only flag
|
||||
delimiter = "" if text_only else " "
|
||||
|
||||
# Logging information
|
||||
LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
|
||||
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
|
||||
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
|
||||
|
||||
return delimiter.join(colored_tokens)
|
||||
|
||||
|
||||
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
|
||||
GLAIVE_TO_SHAREGPT_ROLE = {
|
||||
"SYSTEM": "system",
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""Module containing the Trainer class and related functions"""
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -99,89 +98,17 @@ def add_position_ids(sample):
|
||||
return sample
|
||||
|
||||
|
||||
def add_pose_position_ids(
|
||||
sample,
|
||||
max_context_len=32768,
|
||||
split_on_token_ids: Optional[List[int]] = None,
|
||||
chunks: int = 2,
|
||||
):
|
||||
"""
|
||||
use the PoSE technique to extend the context length by randomly skipping
|
||||
positions in the context. We only want to skip right before tokens in
|
||||
the split_on_token_ids list. We should attempt to randomly distribute
|
||||
the skips, but we don't need the final position_ids to be the full
|
||||
context_len. There may be multiple turns in the context, so we want to
|
||||
make sure we take into account the maximum possible number of skips
|
||||
remaining in each sample.
|
||||
"""
|
||||
|
||||
input_ids = sample["input_ids"]
|
||||
sample_len = len(input_ids)
|
||||
max_skips = max_context_len - sample_len
|
||||
|
||||
if split_on_token_ids is None:
|
||||
split_on_token_ids = []
|
||||
|
||||
if split_on_token_ids:
|
||||
split_indices = [
|
||||
i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
|
||||
]
|
||||
else:
|
||||
chunk_len = sample_len // chunks
|
||||
split_indices = [i * chunk_len for i in range(1, chunks)]
|
||||
split_indices.append(len(input_ids)) # make sure we go to the end of the sample
|
||||
if split_indices[0] < 2:
|
||||
# drop the first split index if it's too close to the beginning
|
||||
split_indices = split_indices[1:]
|
||||
|
||||
position_ids = []
|
||||
prev_index = 0
|
||||
total_skips = 0
|
||||
|
||||
for split_index in split_indices:
|
||||
num_skips = (
|
||||
random.randint(0, max_skips) # nosec B311
|
||||
if prev_index != 0 and max_skips
|
||||
else 0
|
||||
)
|
||||
max_skips -= num_skips
|
||||
total_skips += num_skips
|
||||
|
||||
segment_position_ids = list(
|
||||
range(prev_index + total_skips, split_index + total_skips)
|
||||
)
|
||||
|
||||
position_ids.extend(segment_position_ids)
|
||||
prev_index = split_index
|
||||
|
||||
sample["sequence_len"] = position_ids[-1]
|
||||
position_ids = torch.tensor(position_ids)
|
||||
|
||||
sample["position_ids"] = position_ids
|
||||
sample["length"] = len(position_ids)
|
||||
assert len(position_ids) == len(input_ids)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
def add_length(sample):
|
||||
sample["length"] = len(sample["input_ids"])
|
||||
return sample
|
||||
|
||||
|
||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
return (
|
||||
len(sample["input_ids"]) <= sequence_len
|
||||
and len(sample["input_ids"]) >= min_sequence_len
|
||||
)
|
||||
def drop_long_seq(sample, sequence_len=2048):
|
||||
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
drop_long = partial(
|
||||
drop_long_seq,
|
||||
sequence_len=cfg.sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len or 2,
|
||||
)
|
||||
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
||||
with zero_first(is_main_process()):
|
||||
if cfg.is_preprocess:
|
||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
||||
@@ -226,32 +153,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
desc="Group By Length",
|
||||
)
|
||||
|
||||
if cfg.use_pose:
|
||||
pose_kwargs = {}
|
||||
if cfg.pose_num_chunks is not None:
|
||||
pose_kwargs["chunks"] = cfg.pose_num_chunks
|
||||
pose_fn = partial(
|
||||
add_pose_position_ids,
|
||||
max_context_len=cfg.pose_max_context_len,
|
||||
split_on_token_ids=cfg.pose_split_on_token_ids,
|
||||
**pose_kwargs,
|
||||
)
|
||||
train_dataset = train_dataset.map(
|
||||
pose_fn,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Add position_id column (PoSE)",
|
||||
)
|
||||
train_dataset = train_dataset.sort("sequence_len")
|
||||
if cfg.eval_sample_packing is not False:
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.map(
|
||||
pose_fn,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Add position_id column (PoSE)",
|
||||
)
|
||||
elif cfg.sample_packing:
|
||||
if cfg.sample_packing:
|
||||
train_dataset = train_dataset.map(
|
||||
add_position_ids,
|
||||
num_proc=cfg.dataset_processes,
|
||||
@@ -438,7 +340,7 @@ def prepare_optim_env(cfg):
|
||||
|
||||
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard", "nca_pair"]:
|
||||
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
||||
trainer_builder.model_ref = model[1]
|
||||
trainer_builder.peft_config = model[2]
|
||||
|
||||
@@ -158,50 +158,3 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_orpo_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 64,
|
||||
"lora_alpha": 32,
|
||||
"lora_dropout": 0.1,
|
||||
"lora_target_linear": True,
|
||||
"special_tokens": {},
|
||||
"rl": "orpo",
|
||||
"orpo_alpha": 0.1,
|
||||
"remove_unused_columns": False,
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "argilla/ultrafeedback-binarized-preferences-cleaned",
|
||||
"type": "chat_template.argilla",
|
||||
"split": "train",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "paged_adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||
|
||||
@@ -1067,52 +1067,18 @@ class TestValidation(BaseValidation):
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
|
||||
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 1
|
||||
|
||||
def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 1
|
||||
|
||||
def test_hub_model_id_save_value_steps(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault({"hub_model_id": "test", "save_strategy": "steps"})
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
def test_hub_model_id_save_value_epochs(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault({"hub_model_id": "test", "save_strategy": "epoch"})
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
def test_hub_model_id_save_value_none(self, minimal_cfg):
|
||||
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
|
||||
def test_hub_model_id_save_value_warns(self, minimal_cfg):
|
||||
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert (
|
||||
"set without any models being saved" in self._caplog.records[0].message
|
||||
)
|
||||
|
||||
def test_hub_model_id_save_value(self, minimal_cfg):
|
||||
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
Reference in New Issue
Block a user