precommit

This commit is contained in:
Dan Saunders
2025-03-21 11:40:48 -04:00
committed by Dan Saunders
parent ab3b36339a
commit ce35b2a95f
6 changed files with 11 additions and 9 deletions

View File

@@ -758,13 +758,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 self.cfg.kd_zscore_base_temp
             )
         if self.cfg.kd_top_k_before_softmax is not None:
-            training_arguments_kwargs[
-                "kd_top_k_before_softmax"
-            ] = self.cfg.kd_top_k_before_softmax
+            training_arguments_kwargs["kd_top_k_before_softmax"] = (
+                self.cfg.kd_top_k_before_softmax
+            )
-        training_arguments_kwargs[
-            "sequence_parallel_degree"
-        ] = self.cfg.sequence_parallel_degree
+        training_arguments_kwargs["sequence_parallel_degree"] = (
+            self.cfg.sequence_parallel_degree
+        )
         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig

View File

@@ -1,4 +1,5 @@
 """Init for axolotl.core.trainers"""
 # pylint: disable=unused-import
 # flake8: noqa

View File

@@ -1,4 +1,5 @@
 """Module for customized trainers"""
 # pylint: disable=too-many-lines
 from __future__ import annotations
@@ -114,7 +115,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
             drop_last=True,
         )
-    @override
     def _get_train_sampler(self) -> Sampler | None:
         """
         Helper method to get the sampler for training. Handles cases for sequence
@@ -146,7 +146,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
         return base_sampler
-    @override
     def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
         """
         Helper method to get the sampler for evaluation. Handles sequence parallelism
@@ -591,7 +590,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
         os.makedirs(output_dir, exist_ok=True)
         return super()._save_checkpoint(model, trial, **kwargs)
-    @override
     def training_step(
         self,
         model: nn.Module,

View File

@@ -1,4 +1,5 @@
 """Init for axolotl.core.trainers.mixins"""
 # pylint: disable=unused-import
 # flake8: noqa

View File

@@ -1,4 +1,5 @@
 """Module with Pydantic models for configuration."""
 # pylint: disable=too-many-lines
 import logging

View File

@@ -1,4 +1,5 @@
 """Tests for sequence parallelism functionality."""
 # pylint: disable=redefined-outer-name,unused-argument
 from unittest.mock import MagicMock, patch