diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 76782be4f..4fd850368 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -120,7 +120,6 @@ def load_model(
                 base_model,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
             )
-            config.attn_config['attn_impl'] = 'triton'
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=config,
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 9ef1ac95b..3feec4d92 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -30,16 +30,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         if cfg.logging_steps is not None
         else max(min(int(0.005 * total_num_steps), 10), 1)
     )
-    save_steps = (
-        cfg.save_steps
-        if cfg.save_steps is not None
-        else min(int(0.05 * total_num_steps), 200)
-    )
-    eval_steps = (
-        cfg.eval_steps
-        if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0
-        else save_steps
-    )
+    save_steps = cfg.save_steps
+    eval_steps = cfg.eval_steps
 
     training_arguments_kwargs = {}
     if cfg.bf16 == "full":
@@ -92,13 +84,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         num_train_epochs=cfg.num_epochs,
         learning_rate=cfg.learning_rate,
         evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
-        save_strategy="steps",
+        save_strategy="steps" if save_steps else "epoch",
         eval_steps=eval_steps if cfg.val_set_size > 0 else None,
         save_steps=save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=3,
         load_best_model_at_end=True
-        if cfg.val_set_size > 0 and save_steps % eval_steps == 0 and cfg.load_in_8bit is not True
+        if cfg.val_set_size > 0 and save_steps is not None and save_steps % eval_steps == 0 and cfg.load_in_8bit is not True
         else False,
         ddp_find_unused_parameters=False if cfg.ddp else None,
         group_by_length=cfg.group_by_length,
@@ -158,6 +150,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             cfg.learning_rate,
             total_steps=total_num_steps,
             epochs=cfg.num_epochs,
+            div_factor=10,
             **lr_scheduler_kwargs,
         )
     elif cfg.lr_scheduler == "log_sweep":
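
For reviewers, a minimal self-contained sketch (not part of the patch) of the checkpointing behavior this change leaves in place. `Cfg` and `checkpoint_kwargs` are hypothetical stand-ins for axolotl's config object and the relevant slice of `setup_trainer`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Cfg:
    """Hypothetical stand-in for the axolotl config object."""

    save_steps: Optional[int] = None
    eval_steps: Optional[int] = None
    val_set_size: float = 0.0
    load_in_8bit: bool = False


def checkpoint_kwargs(cfg: Cfg) -> dict:
    # save_steps/eval_steps now pass straight through from the config;
    # the old auto-derived defaults (5% of total steps, capped at 200) are gone.
    save_steps = cfg.save_steps
    eval_steps = cfg.eval_steps
    return {
        # Fall back to per-epoch checkpoints when save_steps is unset.
        "save_strategy": "steps" if save_steps else "epoch",
        "save_steps": save_steps,
        "eval_steps": eval_steps if cfg.val_set_size > 0 else None,
        # Best-model reload still requires save_steps to be a multiple of
        # eval_steps; the truthiness check on eval_steps is the sketch's
        # own guard against a modulo on None.
        "load_best_model_at_end": bool(
            cfg.val_set_size > 0
            and save_steps is not None
            and eval_steps
            and save_steps % eval_steps == 0
            and cfg.load_in_8bit is not True
        ),
    }


print(checkpoint_kwargs(Cfg()))
# -> epoch saves, no eval steps, no best-model reload
print(checkpoint_kwargs(Cfg(save_steps=100, eval_steps=50, val_set_size=0.05)))
# -> steps-based saves every 100 steps, best-model reload enabled
```

Note that the patched condition in trainer.py computes `save_steps % eval_steps` directly, so it assumes `eval_steps` is set whenever `save_steps` is; the sketch adds a guard only to stay runnable for all inputs.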