precommit

This commit is contained in:
Dan Saunders
2025-03-21 11:40:48 -04:00
committed by Dan Saunders
parent ab3b36339a
commit ce35b2a95f
6 changed files with 11 additions and 9 deletions

View File

@@ -758,13 +758,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.kd_zscore_base_temp
)
if self.cfg.kd_top_k_before_softmax is not None:
training_arguments_kwargs[
"kd_top_k_before_softmax"
] = self.cfg.kd_top_k_before_softmax
training_arguments_kwargs["kd_top_k_before_softmax"] = (
self.cfg.kd_top_k_before_softmax
)
training_arguments_kwargs[
"sequence_parallel_degree"
] = self.cfg.sequence_parallel_degree
training_arguments_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig

View File

@@ -1,4 +1,5 @@
"""Init for axolotl.core.trainers"""
# pylint: disable=unused-import
# flake8: noqa

View File

@@ -1,4 +1,5 @@
"""Module for customized trainers"""
# pylint: disable=too-many-lines
from __future__ import annotations
@@ -114,7 +115,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
drop_last=True,
)
@override
def _get_train_sampler(self) -> Sampler | None:
"""
Helper method to get the sampler for training. Handles cases for sequence
@@ -146,7 +146,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
return base_sampler
@override
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
"""
Helper method to get the sampler for evaluation. Handles sequence parallelism
@@ -591,7 +590,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
@override
def training_step(
self,
model: nn.Module,

View File

@@ -1,4 +1,5 @@
"""Init for axolotl.core.trainers.mixins"""
# pylint: disable=unused-import
# flake8: noqa

View File

@@ -1,4 +1,5 @@
"""Module with Pydantic models for configuration."""
# pylint: disable=too-many-lines
import logging

View File

@@ -1,4 +1,5 @@
"""Tests for sequence parallelism functionality."""
# pylint: disable=redefined-outer-name,unused-argument
from unittest.mock import MagicMock, patch