precommit
This commit is contained in:
committed by
Dan Saunders
parent
ab3b36339a
commit
ce35b2a95f
@@ -758,13 +758,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kd_zscore_base_temp
|
self.cfg.kd_zscore_base_temp
|
||||||
)
|
)
|
||||||
if self.cfg.kd_top_k_before_softmax is not None:
|
if self.cfg.kd_top_k_before_softmax is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["kd_top_k_before_softmax"] = (
|
||||||
"kd_top_k_before_softmax"
|
self.cfg.kd_top_k_before_softmax
|
||||||
] = self.cfg.kd_top_k_before_softmax
|
)
|
||||||
|
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["sequence_parallel_degree"] = (
|
||||||
"sequence_parallel_degree"
|
self.cfg.sequence_parallel_degree
|
||||||
] = self.cfg.sequence_parallel_degree
|
)
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Init for axolotl.core.trainers"""
|
"""Init for axolotl.core.trainers"""
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module for customized trainers"""
|
"""Module for customized trainers"""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -114,7 +115,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@override
|
|
||||||
def _get_train_sampler(self) -> Sampler | None:
|
def _get_train_sampler(self) -> Sampler | None:
|
||||||
"""
|
"""
|
||||||
Helper method to get the sampler for training. Handles cases for sequence
|
Helper method to get the sampler for training. Handles cases for sequence
|
||||||
@@ -146,7 +146,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
|
|
||||||
return base_sampler
|
return base_sampler
|
||||||
|
|
||||||
@override
|
|
||||||
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
|
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
|
||||||
"""
|
"""
|
||||||
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
||||||
@@ -591,7 +590,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
@override
|
|
||||||
def training_step(
|
def training_step(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Init for axolotl.core.trainers.mixins"""
|
"""Init for axolotl.core.trainers.mixins"""
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module with Pydantic models for configuration."""
|
"""Module with Pydantic models for configuration."""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Tests for sequence parallelism functionality."""
|
"""Tests for sequence parallelism functionality."""
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name,unused-argument
|
# pylint: disable=redefined-outer-name,unused-argument
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|||||||
Reference in New Issue
Block a user