From 825f66b9fd86f5d95a4694f591a201ba06661619 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 6 Dec 2024 14:52:59 -0500 Subject: [PATCH] update HF HUB env var and fix reward trainer log since it doesn't directly override log --- docker/Dockerfile-cloud | 2 +- docker/Dockerfile-cloud-no-tmux | 2 +- src/axolotl/core/trainer_builder.py | 6 ------ 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/docker/Dockerfile-cloud b/docker/Dockerfile-cloud index d7e3277d2..c8249cb79 100644 --- a/docker/Dockerfile-cloud +++ b/docker/Dockerfile-cloud @@ -2,7 +2,7 @@ ARG BASE_TAG=main FROM axolotlai/axolotl:$BASE_TAG ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets" -ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub" +ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub" ENV HF_HOME="/workspace/data/huggingface-cache/hub" ENV HF_HUB_ENABLE_HF_TRANSFER="1" diff --git a/docker/Dockerfile-cloud-no-tmux b/docker/Dockerfile-cloud-no-tmux index 6dfea4677..165063105 100644 --- a/docker/Dockerfile-cloud-no-tmux +++ b/docker/Dockerfile-cloud-no-tmux @@ -2,7 +2,7 @@ ARG BASE_TAG=main FROM axolotlai/axolotl:$BASE_TAG ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets" -ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub" +ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub" ENV HF_HOME="/workspace/data/huggingface-cache/hub" ENV HF_HUB_ENABLE_HF_TRANSFER="1" diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 5418e53bd..baac94da8 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1266,12 +1266,6 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: # TODO remove once trl supports the updated to the Trainer.log method - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() - del self._stored_metrics[train_eval] return super(RewardTrainer, self).log( # pylint: disable=bad-super-call logs, start_time )