diff --git a/README.md b/README.md
index 5e8d05490..5a2df4eb1 100644
--- a/README.md
+++ b/README.md
@@ -30,4 +30,24 @@ shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
 - Create a new or update the existing YAML config (config/pythia_1_2B_alpaca.yml)[config/pythia_1_2B_alpaca.yml]
 - Install python dependencies `pip3 install -r requirements.txt`
-- Train! `python3 scripts/finetune.py`, make sure to choose the correct YAML config file
+- Configure accelerate with `accelerate config` or update `~/.cache/huggingface/accelerate/default_config.yaml`:
+
+```yaml
+compute_environment: LOCAL_MACHINE
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+```
+
+- Train! `accelerate launch scripts/finetune.py`, making sure to choose the correct YAML config file
diff --git a/scripts/finetune.py b/scripts/finetune.py
index c85994f8f..ffa4d1950 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -68,26 +68,27 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
         from axolotl.flash_attn import replace_llama_attn_with_flash_attn

         replace_llama_attn_with_flash_attn()

+    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32
     try:
         if "llama" in base_model:
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
-                torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+                torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
         else:
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
-                torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+                torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
     except:
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
-            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+            torch_dtype=torch_dtype,
             device_map=cfg.device_map,
         )
@@ -235,7 +236,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)

     training_arguments_kwargs = {}
-    training_arguments_kwargs["bf16"] = cfg.bf16
+    if cfg.bf16 == "full":
+        training_arguments_kwargs["bf16_full_eval"] = True
+    else:
+        training_arguments_kwargs["bf16"] = cfg.bf16
     training_arguments_kwargs["tf32"] = cfg.tf32
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
@@ -256,10 +260,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
+        gradient_checkpointing=cfg.gradient_checkpointing,
         **training_arguments_kwargs,
     )

-    trainer_kwargs = {}
     decay_parameters = get_parameter_names(model, [nn.LayerNorm])
     decay_parameters = [name for name in decay_parameters if "bias" not in name]
     optimizer_grouped_parameters = [
@@ -282,13 +286,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         lr=training_args.learning_rate,
     )

+    # TODO optionally use torch.optim.OneCycleLR
     lr_scheduler = transformers.get_cosine_schedule_with_warmup(
         adam_bnb_optim,
         training_args.warmup_steps,
         total_num_steps,
     )

-    trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
+    trainer_kwargs = {}
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
             cfg.early_stopping_patience,
@@ -300,6 +305,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         args=training_args,
+        optimizers=(adam_bnb_optim, lr_scheduler),
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
         ),
@@ -342,6 +348,12 @@ def train(
         cfg.gradient_accumulation_steps // cfg.world_size
     )
     setup_wandb_env_vars(cfg)
+    if cfg.device == "mps":
+        cfg.load_in_8bit = False
+        cfg.tf32 = False
+        if cfg.bf16:
+            cfg.fp16 = True
+        cfg.bf16 = False

     # Load the model and tokenizer
     model, tokenizer, lora_config = load_model(
diff --git a/setup.cfg b/setup.cfg
index 0a41a6876..0822ab495 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -28,3 +28,6 @@ install_requires =

 [options.packages.find]
 where = src
+[options.extras_require]
+gptq_cuda = alpaca_lora_4bit[cuda] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[cuda]
+gptq_triton = alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[triton]
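
For reference, a minimal sketch of how the pieces added above fit together: installing the project with one of the new GPTQ extras from `setup.cfg`, then configuring and launching training through `accelerate` as described in the README change. The editable install from the repo root (`pip3 install -e .`) is an assumption about the project layout, not something the diff specifies.

```bash
# Assumption: run from the repo root, where setup.cfg lives.
# Pick one of the extras defined above: gptq_cuda or gptq_triton.
pip3 install -e '.[gptq_triton]'

# Configure accelerate interactively, or edit
# ~/.cache/huggingface/accelerate/default_config.yaml as shown above.
accelerate config

# Launch training; make sure the intended YAML config
# (e.g. config/pythia_1_2B_alpaca.yml) is the one the script uses.
accelerate launch scripts/finetune.py
```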