diff --git a/ds_config.json b/ds_config.json
index 49de5f874..65955377c 100644
--- a/ds_config.json
+++ b/ds_config.json
@@ -20,10 +20,12 @@
     }
   },
   "scheduler": {
-    "type": "OneCycle",
+    "type": "WarmupDecayLR",
     "params": {
-      "cycle_min_lr": 1e-7,
-      "cycle_max_lr": 1e-4
+      "warmup_min_lr": "auto",
+      "warmup_max_lr": "auto",
+      "warmup_num_steps": "auto",
+      "total_num_steps": "auto"
     }
   },
   "zero_optimization": {
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index a780dea01..89d6f9d14 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -101,19 +101,12 @@ def load_model(
         )
         load_in_8bit = False
     elif is_llama_derived_model and "LlamaForCausalLM" in globals():
-        if not cfg.load_in_8bit:
-            model = LlamaForCausalLM.from_pretrained(
-                base_model,
-                device_map=cfg.device_map,
-            )
-        else:
-            model = LlamaForCausalLM.from_pretrained(
-                base_model,
-                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                torch_dtype=torch_dtype,
-                device_map=cfg.device_map,
-            )
-
+        model = LlamaForCausalLM.from_pretrained(
+            base_model,
+            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+            torch_dtype=torch_dtype,
+            device_map=cfg.device_map,
+        )
     elif model_type:
         model = getattr(transformers, model_type).from_pretrained(
             base_model,
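
Note on the `ds_config.json` change: the `"auto"` values are not literal settings; the Hugging Face Trainer's DeepSpeed integration resolves them at runtime from the training arguments (learning rate, warmup steps, total step count), so the config no longer hard-codes rates that can drift out of sync with the CLI. Roughly, `WarmupDecayLR` ramps the LR up over `warmup_num_steps`, then decays it back down by `total_num_steps`. Below is a minimal sketch of that shape with illustrative numbers, not values from this config; DeepSpeed's warmup curve is logarithmic by default, so the linear ramp here is a simplification:

```python
def warmup_decay_lr(step: int,
                    warmup_min_lr: float = 0.0,
                    warmup_max_lr: float = 1e-4,
                    warmup_num_steps: int = 100,
                    total_num_steps: int = 1000) -> float:
    """Approximate shape of DeepSpeed's WarmupDecayLR schedule."""
    delta = warmup_max_lr - warmup_min_lr
    if step < warmup_num_steps:
        # Ramp up toward warmup_max_lr (DeepSpeed's default ramp is log-shaped).
        gamma = step / max(1, warmup_num_steps)
    else:
        # Decay linearly back toward warmup_min_lr by total_num_steps.
        remaining = total_num_steps - step
        gamma = max(0.0, remaining / max(1, total_num_steps - warmup_num_steps))
    return warmup_min_lr + delta * gamma
```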
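
Note on the `models.py` change: collapsing the two branches is safe because `load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None` short-circuits to `False` whenever no adapter is configured, reproducing the old full-precision branch (with the small behavioral difference that `torch_dtype` is now passed in both cases). A quick illustration of the condition, using a stand-in namespace rather than axolotl's real config object:

```python
from types import SimpleNamespace

# Stand-in for the real cfg; only the two fields the condition reads.
cfg = SimpleNamespace(load_in_8bit=True, adapter=None)

# No adapter: 8-bit loading is suppressed, matching the old
# `if not cfg.load_in_8bit` path (plus torch_dtype, now always passed).
assert (cfg.load_in_8bit and cfg.adapter is not None) is False

# With an adapter (e.g. LoRA), 8-bit loading is actually requested.
cfg.adapter = "lora"
assert (cfg.load_in_8bit and cfg.adapter is not None) is True
```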