Compare commits
4 Commits
multipack-
...
3e3229e2d9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e3229e2d9 | ||
|
|
1d21aa6b0a | ||
|
|
71b7ea3c05 | ||
|
|
a48dbf6561 |
@@ -53,7 +53,7 @@ resume_from_checkpoint:
|
|||||||
local_rank:
|
local_rank:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention:
|
xformers_attention:
|
||||||
flash_attention: true
|
flash_attention:
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 0.05
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ resume_from_checkpoint:
|
|||||||
local_rank:
|
local_rank:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention:
|
xformers_attention:
|
||||||
flash_attention: true
|
flash_attention:
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 0.05
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
|
from axolotl.utils.trainer import prepare_optim_env
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
@@ -71,7 +72,7 @@ def do_merge_lora(
|
|||||||
|
|
||||||
LOG.info("running merge of LoRA with base model")
|
LOG.info("running merge of LoRA with base model")
|
||||||
model = model.merge_and_unload()
|
model = model.merge_and_unload()
|
||||||
model.to(dtype=torch.float16)
|
model.to(dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||||
@@ -296,6 +297,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
prepare_optim_env(cfg)
|
||||||
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
|
|||||||
@@ -412,12 +412,19 @@ def load_model(
|
|||||||
module.to(torch.float32)
|
module.to(torch.float32)
|
||||||
|
|
||||||
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||||
|
skip_prepare_model_for_kbit_training = False
|
||||||
|
|
||||||
|
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
||||||
|
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||||
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
if (cfg.adapter == "lora" and load_in_8bit) or (
|
if (cfg.adapter == "lora" and load_in_8bit) or (
|
||||||
cfg.adapter == "qlora" and cfg.load_in_4bit
|
cfg.adapter == "qlora" and cfg.load_in_4bit
|
||||||
):
|
):
|
||||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||||
if cfg.gradient_checkpointing:
|
if cfg.gradient_checkpointing:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
|
if not skip_prepare_model_for_kbit_training:
|
||||||
model = prepare_model_for_kbit_training(
|
model = prepare_model_for_kbit_training(
|
||||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -267,12 +267,14 @@ def setup_fsdp_envs(cfg):
|
|||||||
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def prepare_optim_env(cfg):
|
||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
setup_fsdp_envs(cfg)
|
setup_fsdp_envs(cfg)
|
||||||
elif cfg.deepspeed:
|
elif cfg.deepspeed:
|
||||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
||||||
trainer_builder.train_dataset = train_dataset
|
trainer_builder.train_dataset = train_dataset
|
||||||
trainer_builder.eval_dataset = eval_dataset
|
trainer_builder.eval_dataset = eval_dataset
|
||||||
|
|||||||
Reference in New Issue
Block a user