Compare commits
4 Commits
lora-kerne
...
grpo-ref-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a9ebff087c | ||
|
|
b53a41372f | ||
|
|
02f45e94be | ||
|
|
954e192f38 |
@@ -12,6 +12,7 @@ to leverage operator fusion and tensor re-use in order to improve speed and redu
|
|||||||
memory usage during the forward and backward passes of these calculations.
|
memory usage during the forward and backward passes of these calculations.
|
||||||
|
|
||||||
We currently support several common model architectures, including (but not limited to):
|
We currently support several common model architectures, including (but not limited to):
|
||||||
|
|
||||||
- `llama`
|
- `llama`
|
||||||
- `mistral`
|
- `mistral`
|
||||||
- `qwen2`
|
- `qwen2`
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ liger-kernel==0.5.2
|
|||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
peft==0.14.0
|
peft==0.14.0
|
||||||
transformers==4.48.3
|
transformers==4.49.0
|
||||||
tokenizers>=0.21.0
|
tokenizers>=0.21.0
|
||||||
accelerate==1.3.0
|
accelerate==1.3.0
|
||||||
datasets==3.2.0
|
datasets==3.2.0
|
||||||
|
|||||||
@@ -39,6 +39,15 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
||||||
# pylint: enable=access-member-before-definition
|
# pylint: enable=access-member-before-definition
|
||||||
|
|
||||||
|
# cleanup the ref_model if we have a peft model passed in
|
||||||
|
# TODO remove this after next major trl release
|
||||||
|
if (
|
||||||
|
self.ref_model # pylint: disable=access-member-before-definition
|
||||||
|
and is_peft_model(self.model)
|
||||||
|
):
|
||||||
|
del self.ref_model
|
||||||
|
self.ref_model = None
|
||||||
|
|
||||||
def _enable_gradient_checkpointing(
|
def _enable_gradient_checkpointing(
|
||||||
self, model: PreTrainedModel, args: GRPOConfig
|
self, model: PreTrainedModel, args: GRPOConfig
|
||||||
) -> PreTrainedModel:
|
) -> PreTrainedModel:
|
||||||
|
|||||||
@@ -127,6 +127,8 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
**_kwargs,
|
**_kwargs,
|
||||||
):
|
):
|
||||||
|
if not optimizer:
|
||||||
|
optimizer = state.optimizer
|
||||||
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
||||||
checkpoint_folder = os.path.join(
|
checkpoint_folder = os.path.join(
|
||||||
args.output_dir,
|
args.output_dir,
|
||||||
|
|||||||
@@ -272,8 +272,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
dict(zip(feature_names, row))
|
dict(zip(feature_names, row))
|
||||||
)
|
)
|
||||||
for key, val in tokenized_prompt.items():
|
for key, val in tokenized_prompt.items():
|
||||||
for i in range(0, len(val), self.sequence_len):
|
res[key].append(val)
|
||||||
res[key].append(val[i : i + self.sequence_len])
|
|
||||||
|
|
||||||
# If there are no examples left, return an empty dictionary
|
# If there are no examples left, return an empty dictionary
|
||||||
if not res:
|
if not res:
|
||||||
|
|||||||
@@ -172,10 +172,11 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
min_input_len = np.min(get_dataset_lengths(dataset))
|
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||||
LOG.debug(f"min_input_len: {min_input_len}")
|
min_input_len = np.min(ds_lengths)
|
||||||
max_input_len = np.max(get_dataset_lengths(dataset))
|
LOG.info(f"min_input_len: {min_input_len}")
|
||||||
LOG.debug(f"max_input_len: {max_input_len}")
|
max_input_len = np.max(ds_lengths)
|
||||||
|
LOG.info(f"max_input_len: {max_input_len}")
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,17 @@ helper util to calculate dataset lengths
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def get_dataset_lengths(dataset):
|
def get_dataset_lengths(dataset, from_arrow=False):
|
||||||
if "length" in dataset.column_names:
|
if "length" in dataset.column_names:
|
||||||
lengths = np.array(dataset["length"])
|
lengths = np.array(dataset["length"])
|
||||||
elif "position_ids" in dataset.column_names:
|
elif "position_ids" in dataset.column_names:
|
||||||
position_ids = dataset["position_ids"]
|
position_ids = dataset["position_ids"]
|
||||||
lengths = np.array([x[-1] + 1 for x in position_ids])
|
lengths = np.array([x[-1] + 1 for x in position_ids])
|
||||||
else:
|
else:
|
||||||
input_ids = dataset["input_ids"]
|
if from_arrow:
|
||||||
lengths = np.array([len(seq) for seq in input_ids])
|
input_ids = dataset.data.column("input_ids")
|
||||||
|
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
||||||
|
else:
|
||||||
|
input_ids = dataset["input_ids"]
|
||||||
|
lengths = np.array([len(seq) for seq in input_ids])
|
||||||
return lengths
|
return lengths
|
||||||
|
|||||||
Reference in New Issue
Block a user