precommit
This commit is contained in:
committed by
Dan Saunders
parent
ab3b36339a
commit
ce35b2a95f
@@ -758,13 +758,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.kd_zscore_base_temp
|
||||
)
|
||||
if self.cfg.kd_top_k_before_softmax is not None:
|
||||
training_arguments_kwargs[
|
||||
"kd_top_k_before_softmax"
|
||||
] = self.cfg.kd_top_k_before_softmax
|
||||
training_arguments_kwargs["kd_top_k_before_softmax"] = (
|
||||
self.cfg.kd_top_k_before_softmax
|
||||
)
|
||||
|
||||
training_arguments_kwargs[
|
||||
"sequence_parallel_degree"
|
||||
] = self.cfg.sequence_parallel_degree
|
||||
training_arguments_kwargs["sequence_parallel_degree"] = (
|
||||
self.cfg.sequence_parallel_degree
|
||||
)
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Init for axolotl.core.trainers"""
|
||||
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module for customized trainers"""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -114,7 +115,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
@override
|
||||
def _get_train_sampler(self) -> Sampler | None:
|
||||
"""
|
||||
Helper method to get the sampler for training. Handles cases for sequence
|
||||
@@ -146,7 +146,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
||||
|
||||
return base_sampler
|
||||
|
||||
@override
|
||||
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
|
||||
"""
|
||||
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
||||
@@ -591,7 +590,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
@override
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Init for axolotl.core.trainers.mixins"""
|
||||
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module with Pydantic models for configuration."""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import logging
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tests for sequence parallelism functionality."""
|
||||
|
||||
# pylint: disable=redefined-outer-name,unused-argument
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
Reference in New Issue
Block a user