FDSP + QLoRA (#1378)

* wip qlora + fsdp fixes

* more fixes

* make sure to load the lora 🤦

* only setup quantized meta on non-zero rank:

* only run setup_quantized_peft_meta_for_training for qlora+fsdp

* more fixes for qlora+fsdp

* chore: lint

* add example yml

* support mistral too

* fix for model_type and add mixtral support too

* set cpu_offload: false to reduce vram, constrain new accleerator logic to qlora + fsdp

* refactor for duplicate code
This commit is contained in:
Wing Lian
2024-03-08 14:31:01 -05:00
committed by GitHub
parent 638c2dafb5
commit 9b6ee83a73
8 changed files with 502 additions and 9 deletions

View File

@@ -0,0 +1,70 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 512
sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
special_tokens:

View File

@@ -0,0 +1,74 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out
model_config:
output_router_logits: true
adapter: qlora
lora_model_dir:
sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
special_tokens:

View File

@@ -3,7 +3,7 @@ packaging==23.2
peft==0.9.0 peft==0.9.0
transformers==4.38.2 transformers==4.38.2
tokenizers==0.15.0 tokenizers==0.15.0
bitsandbytes>=0.41.1 bitsandbytes>=0.43.0
accelerate==0.26.1 accelerate==0.26.1
deepspeed==0.13.1 deepspeed==0.13.1
pydantic==2.6.3 pydantic==2.6.3
@@ -40,3 +40,4 @@ gcsfs
# adlfs # adlfs
trl>=0.7.9 trl>=0.7.9
fastcore>=1.5.29

View File

View File

@@ -0,0 +1,55 @@
"""module for building the auto wrap policy for FSDP"""
import functools
from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
from torch.distributed.fsdp.wrap import (
_or_policy,
lambda_auto_wrap_policy,
transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
"llama",
"mistral",
"mixtral",
]
def get_wrapping_policy_factory(model_type):
if model_type == "llama":
layer_to_wrap = LlamaDecoderLayer
elif model_type == "mistral":
layer_to_wrap = MistralDecoderLayer
elif model_type == "mixtral":
layer_to_wrap = MixtralDecoderLayer
def get_wrapping_policy():
"""This checks for lora layers (has weight and requires_grad)"""
def lambda_policy_fn(module):
return (
len(list(module.named_children())) == 0
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
)
lambda_policy = functools.partial(
lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
)
transformer_layer_name = layer_to_wrap
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=(
PrefixEncoder,
PromptEncoder,
PromptEmbedding,
transformer_layer_name,
),
)
policies = [lambda_policy, transformer_wrap_policy]
return functools.partial(_or_policy, policies=policies)
return get_wrapping_policy

View File

@@ -8,6 +8,7 @@ import importlib
import importlib.util import importlib.util
import logging import logging
import math import math
import os
import sys import sys
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -17,7 +18,10 @@ from typing import List, Optional, Type, Union
import torch import torch
import transformers import transformers
from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import str_to_bool
from datasets import Dataset from datasets import Dataset
from torch.distributed.fsdp import MixedPrecision
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import ( from transformers import (
@@ -30,6 +34,7 @@ from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer from trl import DPOTrainer
from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
from axolotl.loraplus import create_loraplus_optimizer from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -191,6 +196,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=1e-6, default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."}, metadata={"help": "loraplus learning rate for lora embedding layers."},
) )
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
class AxolotlTrainer(Trainer): class AxolotlTrainer(Trainer):
@@ -468,6 +477,56 @@ class AxolotlTrainer(Trainer):
return super().push_to_hub(*args, **kwargs) return super().push_to_hub(*args, **kwargs)
@wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self):
rank = int(os.environ.get("LOCAL_RANK", 0))
res = super().create_accelerator_and_postprocess()
if self.args.qlora is False:
return res
# the rest of this method override is specific to fsdp + qlora (for now)
sync_module_states = (
str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
)
mp_policy = None
amp = os.environ["ACCELERATE_MIXED_PRECISION"]
if amp == "fp16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
elif amp == "bf16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
# If somehow we figure out how we want to parameterize we want to autocast buffers...
# mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
# load_param_skip_names = ['inv_freq']
if self.is_fsdp_enabled:
wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrapping_policy(),
cpu_offload=False,
use_orig_params=False,
limit_all_gathers=True,
param_init_fn=lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
if (rank != 0 and sync_module_states)
else None,
mixed_precision_policy=mp_policy,
)
self.accelerator.state.fsdp_plugin = fsdp_plugin
return res
class AxolotlMambaTrainer(AxolotlTrainer): class AxolotlMambaTrainer(AxolotlTrainer):
""" """
@@ -787,6 +846,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.fsdp_config: if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config) training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
# deepspeed # deepspeed
if self.cfg.deepspeed: if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed

View File

@@ -24,9 +24,9 @@ def check_cuda_device(default_value):
or not torch.cuda.is_available() or not torch.cuda.is_available()
or device == "auto" or device == "auto"
or torch.device(device).type == "cpu" or torch.device(device).type == "cpu"
or torch.device(device).type == "meta"
): ):
return default_value return default_value
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper

View File

@@ -1,13 +1,20 @@
"""Module for models and model loading""" """Module for models and model loading"""
# pylint: disable=too-many-lines
import logging import logging
import math import math
import os import os
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401 import types
from typing import Any, Dict, List, Optional, Tuple, Type, Union # noqa: F401
import addict import addict
import bitsandbytes as bnb import bitsandbytes as bnb
import safetensors
import torch import torch
import transformers import transformers
from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit, Params4bit
from fastcore.parallel import parallel
from peft import ( from peft import (
LoftQConfig, LoftQConfig,
PeftConfig, PeftConfig,
@@ -16,6 +23,7 @@ from peft import (
prepare_model_for_kbit_training, prepare_model_for_kbit_training,
) )
from peft.tuners.lora import QuantLinear from peft.tuners.lora import QuantLinear
from torch import Tensor, nn
from transformers import ( # noqa: F401 from transformers import ( # noqa: F401
AddedToken, AddedToken,
AutoConfig, AutoConfig,
@@ -27,7 +35,9 @@ from transformers import ( # noqa: F401
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES
from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import ( from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES, SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -262,6 +272,117 @@ def load_tokenizer(cfg):
return tokenizer return tokenizer
def replace_linear(
model: nn.Module,
linear_replacement: Type[nn.Module],
quant_config: Union[dict, None] = None,
skip_modules=None,
**kwargs,
):
"""
Replace linear modules with a new Linear module.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
linear_replacement (`torch.nn.Module`):
The linear module that replaces the old one. Only expects standard arguments.
If other arguments need to be passed, use a lambda.
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
List of modules names not to convert. Defaults to `lm_head`.
"""
if skip_modules is None:
skip_modules = ["lm_head"]
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_linear(
module, linear_replacement, quant_config, skip_modules, **kwargs
)
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
if issubclass(linear_replacement, Linear4bit):
model._modules[ # pylint: disable=protected-access
name
] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
**kwargs,
)
else:
raise ValueError(
f"Unsupported linear replacement: {type(linear_replacement)}"
)
return model
def load_and_quantize(
module: nn.Module,
name: str,
value: Tensor,
device: torch.device = None,
dtype: torch.dtype = None,
skip_names: Optional[List[str]] = None,
is_meta_rank: bool = False,
low_memory: bool = True,
verbose: bool = False,
quant_method: str = "bnb",
):
"""
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
"""
if skip_names is None:
skip_names = []
def place_on_device(value):
if is_meta_rank:
device = "meta"
elif low_memory:
device = "cpu"
else:
device = "cuda"
return value.to(device=device, dtype=dtype)
if any(skip_name in name for skip_name in skip_names):
if verbose:
print(f"Skipping {name} because it is in skip_names")
return
module_key, _, value_key = name.rpartition(".")
try:
submodule = module.get_submodule(module_key)
except AttributeError as exc:
print(f"Module {module_key} not found:\n{exc}")
return
try:
if quant_method == "bnb":
param = submodule.get_parameter(value_key)
if isinstance(param, Params4bit):
# With `sync_module_states=True`, a meta device Params4bit needs to be the same
# shape as the quantized Params4bit with an initialized quant_state. However,
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
value = type(param)(
value.to(device=device, dtype=dtype).data, **param.__dict__
).cuda(device)
if is_meta_rank:
value = type(param)(value.data.to("meta"), **value.__dict__)
elif low_memory:
value = type(param)(value.data.to("cpu"), **value.__dict__)
else:
value = type(param)(place_on_device(value).data)
except AttributeError:
# it's a buffer
value = place_on_device(value)
setattr(submodule, value_key, value)
def load_model( def load_model(
cfg: DictDefault, cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
@@ -394,7 +515,7 @@ def load_model(
if max_memory is not None: if max_memory is not None:
# Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
from accelerate import infer_auto_device_map, init_empty_weights from accelerate import infer_auto_device_map
with init_empty_weights(): with init_empty_weights():
model_canvas = AutoModelForCausalLM.from_config(model_config) model_canvas = AutoModelForCausalLM.from_config(model_config)
@@ -496,8 +617,78 @@ def load_model(
model_kwargs["attn_implementation"] = "eager" model_kwargs["attn_implementation"] = "eager"
model_config._attn_implementation = "eager" # pylint: disable=protected-access model_config._attn_implementation = "eager" # pylint: disable=protected-access
qlora_fsdp = (
cfg.fsdp
and cfg.adapter == "qlora"
and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
)
try: try:
if ( if qlora_fsdp:
if cfg.bf16 or cfg.bfloat16:
torch_dtype, compute_dtype = torch.float32, torch.bfloat16
elif cfg.fp16 or cfg.float16:
torch_dtype, compute_dtype = torch.float32, torch.float16
else:
torch_dtype, compute_dtype = torch.float32, torch.float16
with init_empty_weights():
LOG.info("Loading model with empty weights.")
model = AutoModelForCausalLM.from_config(model_config)
model.model = replace_linear(
model.model,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=torch_dtype,
)
model.is_loaded_in_4bit = True
# Grab the safetensors files that hold the weights
try:
idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(base_model, idx)
except OSError:
try:
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
files = []
files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME))
except OSError as exc:
# This means the model probably doesn't have a safetensors file
raise exc
# Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
# and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
def load_and_quantize_parallel(name_param, model, **kwargs):
name, param = name_param
load_and_quantize(model, name, param, **kwargs)
param_count = sum((p.numel() for n, p in model.named_parameters()))
for filename in files:
weights = safetensors.torch.load_file(filename)
quant_method = "bnb"
devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
left = int(os.cpu_count() / torch.cuda.device_count())
right = int(
8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))
)
n_workers = min(left, right)
parallel(
load_and_quantize_parallel,
weights.items(),
n_workers=n_workers,
threadpool=True,
model=model,
dtype=torch_dtype,
device=cfg.local_rank,
skip_names=[],
is_meta_rank=(cfg.local_rank != 0),
verbose=False,
quant_method=quant_method,
)
elif (
model_config.model_type == "llama" model_config.model_type == "llama"
and not cfg.trust_remote_code and not cfg.trust_remote_code
and not cfg.gptq and not cfg.gptq
@@ -613,7 +804,7 @@ def load_model(
LOG.exception(err) LOG.exception(err)
raise err raise err
if isinstance(model, (PeftModel, PeftModelForCausalLM)): if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
model = model.merge_and_unload() model = model.merge_and_unload()
embeddings_len = ( embeddings_len = (
@@ -692,6 +883,9 @@ def load_model(
if cfg.adapter == "lora" and loftq_bits: if cfg.adapter == "lora" and loftq_bits:
skip_prepare_model_for_kbit_training = True skip_prepare_model_for_kbit_training = True
if qlora_fsdp:
skip_prepare_model_for_kbit_training = True
if cfg.adapter in ["lora", "qlora"]: if cfg.adapter in ["lora", "qlora"]:
if cfg.gradient_checkpointing: if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
@@ -706,7 +900,7 @@ def load_model(
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility. # convert them back to fp16/bf16 for flash-attn compatibility.
if needs_fa2_dtype or cfg.flash_attention: if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp:
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype) LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules(): for name, module in model.named_modules():
if "norm" in name: if "norm" in name:
@@ -724,7 +918,12 @@ def load_model(
else: else:
model, lora_config = load_adapter(model, cfg, cfg.adapter) model, lora_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit): if (
cfg.ddp
and not load_in_8bit
and not (cfg.rl and cfg.load_in_4bit)
and not qlora_fsdp
):
# TODO revaldate this conditional # TODO revaldate this conditional
model.to(f"cuda:{cfg.local_rank}") model.to(f"cuda:{cfg.local_rank}")
@@ -813,6 +1012,30 @@ def find_all_linear_names(model):
return list(lora_module_names) return list(lora_module_names)
def setup_quantized_meta_for_peft(model: nn.Module):
"""Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""
def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument
return self
for param in model.parameters():
if isinstance(param, Params4bit):
param.quant_state._orig_to = ( # pylint: disable=protected-access
param.quant_state.to
)
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
def setup_quantized_peft_meta_for_training(model: nn.Module):
"""Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
for param in model.parameters():
if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"):
param.quant_state.to = (
param.quant_state._orig_to # pylint: disable=protected-access
)
param.quant_state._orig_to = None # pylint: disable=protected-access
def load_lora(model, cfg, inference=False, config_only=False): def load_lora(model, cfg, inference=False, config_only=False):
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]] # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
@@ -849,6 +1072,11 @@ def load_lora(model, cfg, inference=False, config_only=False):
if config_only: if config_only:
return None, lora_config return None, lora_config
rank = int(os.environ.get("LOCAL_RANK", 0))
if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
setup_quantized_meta_for_peft(model)
if cfg.lora_model_dir: if cfg.lora_model_dir:
LOG.debug("Loading pretrained PEFT - LoRA") LOG.debug("Loading pretrained PEFT - LoRA")
model_kwargs: Any = {} model_kwargs: Any = {}
@@ -864,6 +1092,9 @@ def load_lora(model, cfg, inference=False, config_only=False):
else: else:
model = get_peft_model(model, lora_config) model = get_peft_model(model, lora_config)
model.print_trainable_parameters() if rank == 0:
model.print_trainable_parameters()
elif cfg.fsdp and cfg.adapter == "qlora":
setup_quantized_peft_meta_for_training(model)
return model, lora_config return model, lora_config