diff --git a/docs/lr_groups.qmd b/docs/lr_groups.qmd
new file mode 100644
index 000000000..52059016c
--- /dev/null
+++ b/docs/lr_groups.qmd
@@ -0,0 +1,29 @@
+---
+title: Learning Rate Groups
+description: "Setting different learning rates by module name"
+---
+
+## Background
+
+Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for individual modules or
+groups of modules in a model.
+
+## Example
+
+```yaml
+lr_groups:
+  - name: o_proj
+    modules:
+      - self_attn.o_proj.weight
+    lr: 1e-6
+  - name: q_proj
+    modules:
+      - model.layers.2.self_attn.q_proj.weight
+    lr: 1e-5
+
+learning_rate: 2e-5
+```
+
+In this example, the default learning rate of 2e-5 applies across the entire model, while a separate learning rate
+of 1e-6 is used for the self attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd
+layer's self attention `q_proj` module.
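To make the config-to-optimizer mapping above concrete, here is a minimal standalone sketch, not Axolotl's actual implementation, of how `lr_groups` entries can be turned into PyTorch optimizer parameter groups; the toy two-layer model and the `first` group name are hypothetical:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; a real config targets transformer module names.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

lr_groups = [{"name": "first", "modules": ["0.weight"], "lr": 1e-6}]
default_lr = 2e-5


def build_param_groups(model, lr_groups, default_lr):
    # map module-name substrings to their group's learning rate
    lookup = {mod: grp["lr"] for grp in lr_groups for mod in grp["modules"]}
    grouped, default_params = {}, []
    for name, param in model.named_parameters():
        lr = next((lr for mod, lr in lookup.items() if mod in name), None)
        if lr is None:
            default_params.append(param)
        else:
            grouped.setdefault(lr, []).append(param)
    groups = [{"params": default_params, "lr": default_lr}]
    groups.extend({"params": ps, "lr": lr} for lr, ps in grouped.items())
    return groups


optimizer = torch.optim.AdamW(build_param_groups(model, lr_groups, default_lr))
print([group["lr"] for group in optimizer.param_groups])  # [2e-05, 1e-06]
```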
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index e6d3ae2b7..c7340e4f5 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -243,6 +243,10 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "Scale the learning rate for the embedding layers."},
     )
+    lr_groups: Optional[list[dict]] = field(
+        default=None,
+        metadata={"help": "Specify learning rate groups with different LRs."},
+    )
     embedding_lr: Optional[float] = field(
         default=None,
         metadata={"help": "absolute learning rate for the embedding layers."},
@@ -461,11 +465,95 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         )
         return super()._wrap_model(model, training=training, dataloader=dataloader)
 
+    def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
+        decay_parameters = self.get_decay_parameter_names(opt_model)
+        params = {
+            "to_weight_decay": {},  # LayerNorm and bias
+            "embeddings": {},  # lm_head, embed_tokens,
+            "no_weight_decay": {},
+        }
+        lr_groups_lookup = {}
+        lr_groups_learning_rates = {}
+        if self.args.lr_groups:
+            for lr_group in self.args.lr_groups:
+                group_name = lr_group["name"]
+                group_modules = lr_group["modules"]
+                for module in group_modules:
+                    lr_groups_lookup[module] = group_name
+                lr_groups_learning_rates[group_name] = lr_group["lr"]
+                params[f"to_weight_decay_{group_name}"] = {}
+
+        for name, param in opt_model.named_parameters():
+            if not param.requires_grad:
+                continue
+            if name.endswith("modules_to_save.default.weight") or any(
+                embed_name in name for embed_name in ["embed_tokens", "lm_head"]
+            ):
+                params["embeddings"][name] = param
+            elif name in decay_parameters:
+                lr_group_modules = [
+                    group_modules
+                    for group_modules in lr_groups_lookup
+                    if group_modules in name
+                ]
+                if lr_groups_lookup and any(lr_group_modules):
+                    lr_group_module = lr_group_modules[0]
+                    group_name = lr_groups_lookup[lr_group_module]
+                    params[f"to_weight_decay_{group_name}"][name] = param
+                else:
+                    params["to_weight_decay"][name] = param
+            else:
+                params["no_weight_decay"][name] = param
+        optimizer_grouped_parameters = []
+        if params["to_weight_decay"]:
+            optimizer_grouped_parameters.append(
+                {
+                    "params": list(params["to_weight_decay"].values()),
+                    "weight_decay": self.args.weight_decay,
+                    "lr": optimizer_kwargs["lr"],
+                }
+            )
+        if params["embeddings"]:
+            lr = optimizer_kwargs["lr"]  # pylint: disable=invalid-name
+            if self.args.embedding_lr_scale:
+                lr *= self.args.embedding_lr_scale  # pylint: disable=invalid-name
+            elif self.args.embedding_lr:
+                lr = self.args.embedding_lr  # pylint: disable=invalid-name
+            optimizer_grouped_parameters.append(
+                {
+                    "params": list(params["embeddings"].values()),
+                    "weight_decay": 0.0,
+                    "lr": lr,
+                }
+            )
+        if params["no_weight_decay"]:
+            optimizer_grouped_parameters.append(
+                {
+                    "params": list(params["no_weight_decay"].values()),
+                    "weight_decay": 0.0,
+                    "lr": optimizer_kwargs["lr"],
+                }
+            )
+        for group_name, group_lr in lr_groups_learning_rates.items():
+            if params[f"to_weight_decay_{group_name}"]:
+                optimizer_grouped_parameters.append(
+                    {
+                        "params": list(
+                            params[f"to_weight_decay_{group_name}"].values()
+                        ),
+                        "weight_decay": self.args.weight_decay,
+                        "lr": group_lr,
+                    }
+                )
+
+        return optimizer_grouped_parameters
+
     def create_optimizer(self):
         if (
             self.args.loraplus_lr_ratio is None
             and self.args.embedding_lr_scale is None
             and self.args.embedding_lr is None
+            and self.args.lr_groups is None
             and self.args.alternate_optimizer
             not in [
                 "optimi_adamw",
@@ -479,59 +567,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if self.optimizer is None:  # pylint: disable=access-member-before-definition
-            decay_parameters = self.get_decay_parameter_names(opt_model)
-            params = {
-                "to_weight_decay": {},  # LayerNorm and bias
-                "embeddings": {},  # lm_head, embed_tokens,
-                "no_weight_decay": {},
-            }
-
             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                 self.args,
                 opt_model,
             )
-
-            for name, param in opt_model.named_parameters():
-                if not param.requires_grad:
-                    continue
-                if name.endswith("modules_to_save.default.weight") or any(
-                    embed_name in name for embed_name in ["embed_tokens", "lm_head"]
-                ):
-                    params["embeddings"][name] = param
-                elif name in decay_parameters:
-                    params["to_weight_decay"][name] = param
-                else:
-                    params["no_weight_decay"][name] = param
-            optimizer_grouped_parameters = []
-            if params["to_weight_decay"]:
-                optimizer_grouped_parameters.append(
-                    {
-                        "params": list(params["to_weight_decay"].values()),
-                        "weight_decay": self.args.weight_decay,
-                        "lr": optimizer_kwargs["lr"],
-                    }
-                )
-            if params["embeddings"]:
-                lr = optimizer_kwargs["lr"]  # pylint: disable=invalid-name
-                if self.args.embedding_lr_scale:
-                    lr *= self.args.embedding_lr_scale  # pylint: disable=invalid-name
-                elif self.args.embedding_lr:
-                    lr = self.args.embedding_lr  # pylint: disable=invalid-name
-                optimizer_grouped_parameters.append(
-                    {
-                        "params": list(params["embeddings"].values()),
-                        "weight_decay": 0.0,
-                        "lr": lr,
-                    }
-                )
-            if params["no_weight_decay"]:
-                optimizer_grouped_parameters.append(
-                    {
-                        "params": list(params["no_weight_decay"].values()),
-                        "weight_decay": 0.0,
-                        "lr": optimizer_kwargs["lr"],
-                    }
-                )
+            optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
+                opt_model, optimizer_kwargs
+            )
 
             if self.args.loraplus_lr_ratio is not None:
                 loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -548,6 +590,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             elif (
                 self.args.embedding_lr_scale is not None
                 or self.args.embedding_lr is not None
+                or self.args.lr_groups is not None
             ):
                 self.optimizer = (  # pylint: disable=attribute-defined-outside-init
                     optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
@@ -1665,6 +1708,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ] = self.cfg.loraplus_lr_embedding
         training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
         training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
+        training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
 
         if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
             training_arguments_kwargs["lr_scheduler_type"] = "cosine"
@@ -1880,6 +1924,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if training_args.pretraining:
             if self.cfg.pretraining_sample_concatenation is False:
                 return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
+            if self.cfg.micro_batch_size > 1:
+                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
             return None
 
         if self.cfg.model_config_type == "mamba":
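The split into `to_weight_decay`, `no_weight_decay`, and `embeddings` buckets above leans on `Trainer.get_decay_parameter_names`. The sketch below approximates that convention, where biases and LayerNorm weights are exempt from weight decay, on a toy module; it illustrates the bucketing, not the Trainer API itself:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

decay, no_decay = [], []
for mod_name, module in model.named_modules():
    for param_name, _ in module.named_parameters(recurse=False):
        full_name = f"{mod_name}.{param_name}" if mod_name else param_name
        # biases and LayerNorm weights conventionally receive no weight decay
        if isinstance(module, nn.LayerNorm) or param_name == "bias":
            no_decay.append(full_name)
        else:
            decay.append(full_name)

print(decay)     # ['0.weight']
print(no_decay)  # ['0.bias', '1.weight', '1.bias']
```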
["one_cycle", "log_sweep"]: training_arguments_kwargs["lr_scheduler_type"] = "cosine" @@ -1880,6 +1924,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if training_args.pretraining: if self.cfg.pretraining_sample_concatenation is False: return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) + if self.cfg.micro_batch_size > 1: + return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) return None if self.cfg.model_config_type == "mamba": diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 98cdee009..44e247886 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -147,6 +147,14 @@ class UserDefinedPrompterType(BaseModel): field: Optional[str] = None +class LrGroup(BaseModel): + """Custom learning rate group configuration""" + + name: str + modules: List[str] + lr: float + + class SFTDataset(BaseModel): """SFT configuration subset""" @@ -475,6 +483,7 @@ class HyperparametersConfig(BaseModel): cosine_min_lr_ratio: Optional[float] = None cosine_constant_lr_ratio: Optional[float] = None lr_div_factor: Optional[float] = None + lr_groups: Optional[List[LrGroup]] = None adam_epsilon: Optional[float] = None adam_beta1: Optional[float] = None diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py index c30d62575..f20ced221 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/pretraining.py @@ -191,7 +191,7 @@ def wrap_pretraining_dataset( tokenizer, return_tensors="pt", padding=True, - pad_to_multiple_of=max_tokens * batch_size, + pad_to_multiple_of=max_tokens, multipack_attn=cfg.pretrain_multipack_attn, ) encode = functools.partial( @@ -201,8 +201,6 @@ def wrap_pretraining_dataset( max_seq_length=max_tokens, batch_size=batch_size, multipack_attn=cfg.pretrain_multipack_attn, - group_size=cfg.sample_packing_group_size, - bin_size=cfg.sample_packing_bin_size, ) # set this to 1 so downstream data_loader doesn't try to increase the batch again cfg.micro_batch_size = 1 @@ -247,9 +245,7 @@ def encode_packed_pretraining( examples: Dict[str, List], max_seq_length: int = 2048, batch_size: int = 4, - multipack_attn: Optional[bool] = False, - group_size: int = 100000, - bin_size: int = 200, + multipack_attn: Optional[bool] = True, ) -> Dict[str, List]: # pylint: disable=duplicate-code # tokenize all the examples @@ -260,6 +256,9 @@ def encode_packed_pretraining( train_dataset, max_seq_length, skip_position_ids=not multipack_attn, + # FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm + # workaround by using the position id logic for now in trainer + drop_attention_mask=multipack_attn, ) sampler = MultipackBatchSampler( @@ -267,8 +266,6 @@ def encode_packed_pretraining( lengths=get_dataset_lengths(train_dataset), batch_size=1, batch_max_len=batch_size * max_seq_length, - group_size=group_size, - bin_size=bin_size, drop_last=True, ) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 34b505ff1..bfd21703d 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -310,19 +310,22 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): def process_pretraining_datasets_for_packing( - train_dataset, sequence_len, skip_position_ids=True + train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False ): drop_long = partial(drop_long_seq, sequence_len=sequence_len) 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 34b505ff1..bfd21703d 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -310,19 +310,22 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
 
 
 def process_pretraining_datasets_for_packing(
-    train_dataset, sequence_len, skip_position_ids=True
+    train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
 ):
     drop_long = partial(drop_long_seq, sequence_len=sequence_len)
 
     train_dataset = train_dataset.filter(
         drop_long,
         desc="Dropping Long Sequences",
+        load_from_cache_file=False,
     )
 
-    if skip_position_ids:
+    if not skip_position_ids:
         train_dataset = train_dataset.map(
             add_position_ids,
             desc="Add position_id column (Pretraining Sample Packing)",
         )
 
+    if drop_attention_mask:
+        train_dataset = train_dataset.remove_columns("attention_mask")
     return train_dataset
diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py
index 117eba25d..c1f024b87 100644
--- a/tests/e2e/test_llama_pretrain.py
+++ b/tests/e2e/test_llama_pretrain.py
@@ -13,7 +13,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import check_model_output_exists
+from .utils import check_model_output_exists, check_tensorboard
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -28,19 +28,25 @@ class TestPretrainLlama:
         "sample_packing",
         [True, False],
     )
-    def test_pretrain(self, temp_dir, sample_packing):
+    @pytest.mark.parametrize(
+        "pretrain_multipack_attn",
+        [True, False],
+    )
+    def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn):
+        if not sample_packing and pretrain_multipack_attn:
+            return
+
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
                 "flash_attention": True,
                 "sequence_len": 1024,
                 "sample_packing": sample_packing,
+                "pretrain_multipack_attn": pretrain_multipack_attn,
+                "dataset_processes": 1,
                 "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
+                    "pad_token": "<|endoftext|>",
                 },
                 "pretraining_dataset": [
                     {
@@ -51,7 +57,7 @@ class TestPretrainLlama:
                 ],
                 "max_steps": 5,
                 "num_epochs": 1,
-                "micro_batch_size": 1,
+                "micro_batch_size": 2,
                 "gradient_accumulation_steps": 1,
                 "val_set_size": 0.0,
                 "output_dir": temp_dir,
@@ -60,6 +66,7 @@ class TestPretrainLlama:
                 "lr_scheduler": "cosine",
                 "save_safetensors": True,
                 "bf16": "auto",
+                "use_tensorboard": True,
             }
         )
         normalize_config(cfg)
@@ -68,3 +75,12 @@ class TestPretrainLlama:
         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
+
+        loss_threshold = 3.5
+        if sample_packing and not pretrain_multipack_attn:
+            loss_threshold = 6.5
+        check_tensorboard(
+            temp_dir + "/runs",
+            "train/train_loss",
+            loss_threshold,
+            "Train Loss is too high",
+        )
diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py
index fbb776aa5..9f9ae60fb 100644
--- a/tests/test_packed_pretraining.py
+++ b/tests/test_packed_pretraining.py
@@ -41,6 +41,7 @@ class TestPretrainingPacking(unittest.TestCase):
                 }
             ],
             "sample_packing": True,
+            "pretrain_multipack_attn": True,
             "pad_to_sequence_len": True,
             "sequence_len": 2048,
             "micro_batch_size": 2,
@@ -87,9 +88,11 @@ class TestPretrainingPacking(unittest.TestCase):
                 assert data["labels"].shape == torch.Size(
                     [1, original_bsz * cfg.sequence_len]
                 )
-                assert data["attention_mask"].shape == torch.Size(
-                    [1, original_bsz * cfg.sequence_len]
-                )
+                assert "attention_mask" not in data
+                # FIXME add back once we fix packing unpad/pad with attention mask
+                # assert data["attention_mask"].shape == torch.Size(
+                #     [1, original_bsz * cfg.sequence_len]
+                # )
                 idx += 1
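Finally, the `drop_attention_mask` workaround reduces to a plain `datasets.Dataset.remove_columns` call. A small self-contained illustration with toy data (hypothetical values; the real pipeline operates on tokenized pretraining batches):

```python
from datasets import Dataset

ds = Dataset.from_dict(
    {
        "input_ids": [[1, 2, 3]],
        "attention_mask": [[1, 1, 1]],
        "position_ids": [[0, 1, 2]],
    }
)
# dropping attention_mask forces the trainer to rely on position_ids for packing
ds = ds.remove_columns("attention_mask")
print(ds.column_names)  # ['input_ids', 'position_ids']
```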