fix trl trainer.log interfaces
This commit is contained in:
@@ -1157,6 +1157,16 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the updated to the Trainer.log method
|
||||||
|
# logs either has 'loss' or 'eval_loss'
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1165,6 +1175,16 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the updated to the Trainer.log method
|
||||||
|
# logs either has 'loss' or 'eval_loss'
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1173,6 +1193,43 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the updated to the Trainer.log method
|
||||||
|
# logs either has 'loss' or 'eval_loss'
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
# train metrics should have no prefix, eval should have 'eval_'
|
||||||
|
prefix = "eval_" if train_eval == "eval" else ""
|
||||||
|
# accumulate average metrics from sums and lengths
|
||||||
|
for split in ["chosen", "rejected"]:
|
||||||
|
if f"count/{split}" in self._stored_metrics[train_eval]:
|
||||||
|
count_sum = (
|
||||||
|
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
|
||||||
|
.sum()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
for metric in ["rewards", "logps", "logits"]:
|
||||||
|
logs[f"{prefix}{metric}/{split}"] = (
|
||||||
|
torch.Tensor(
|
||||||
|
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
||||||
|
)
|
||||||
|
.sum()
|
||||||
|
.item()
|
||||||
|
/ count_sum
|
||||||
|
)
|
||||||
|
# delete obsolete metric
|
||||||
|
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
||||||
|
del self._stored_metrics[train_eval][f"count/{split}"]
|
||||||
|
# calculate reward margin
|
||||||
|
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
||||||
|
logs[f"{prefix}rewards/margins"] = (
|
||||||
|
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
||||||
|
)
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1181,6 +1238,16 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "cpo"]
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the updated to the Trainer.log method
|
||||||
|
# logs either has 'loss' or 'eval_loss'
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1189,6 +1256,16 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "reward"]
|
tag_names = ["axolotl", "reward"]
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the updated to the Trainer.log method
|
||||||
|
# logs either has 'loss' or 'eval_loss'
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user