Default include_tkps to true (#3134)
* default true * force e2e * causal trainer only * fix eval logging [skip-ci] * revert setup.py * force tests * guarding * guarding * fix test case * use evaluate [skip-e2e] * use evaluate [skip-e2e] * kick off ci * fixing * reverting
This commit is contained in:
@@ -36,7 +36,6 @@ from axolotl.utils.callbacks import (
|
|||||||
SaveModelOnFirstStepCallback,
|
SaveModelOnFirstStepCallback,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||||
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
|
|
||||||
from axolotl.utils.distributed import build_parallelism_config
|
from axolotl.utils.distributed import build_parallelism_config
|
||||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
@@ -145,12 +144,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
profiler_steps_start=self.cfg.profiler_steps_start,
|
profiler_steps_start=self.cfg.profiler_steps_start,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if self.cfg.include_tkps:
|
|
||||||
callbacks.append(
|
|
||||||
TokensPerSecondCallback(
|
|
||||||
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from axolotl.utils.collators import (
|
|||||||
MambaDataCollator,
|
MambaDataCollator,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -71,6 +72,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.qat:
|
if self.cfg.qat:
|
||||||
callbacks.append(QATCallback(self.cfg.qat))
|
callbacks.append(QATCallback(self.cfg.qat))
|
||||||
|
|
||||||
|
if self.cfg.include_tkps:
|
||||||
|
callbacks.append(
|
||||||
|
TokensPerSecondCallback(
|
||||||
|
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
|
||||||
|
)
|
||||||
|
)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
|
|||||||
@@ -342,10 +342,10 @@ class AxolotlTrainer(
|
|||||||
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
||||||
if hasattr(self.state, "num_tokens"):
|
if hasattr(self.state, "num_tokens"):
|
||||||
self.state.num_tokens = (
|
self.state.num_tokens = (
|
||||||
self.state.num_tokens + (inputs[inputs_key] != -100).sum()
|
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.state.num_tokens = (inputs[inputs_key] != -100).sum()
|
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
|
||||||
|
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
return self.orpo_compute_loss(
|
return self.orpo_compute_loss(
|
||||||
|
|||||||
@@ -43,11 +43,12 @@ class TokensPerSecondCallback(TrainerCallback):
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
step_time = time.perf_counter() - self.start_time
|
if hasattr(state, "num_tokens"):
|
||||||
num_tokens_per_device = state.num_tokens.clone()
|
step_time = time.perf_counter() - self.start_time
|
||||||
# non data parallel groups have duplicated tokens, so we avoid double-counting
|
num_tokens_per_device = state.num_tokens.clone()
|
||||||
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
|
# non data parallel groups have duplicated tokens, so we avoid double-counting
|
||||||
state.last_tokens_per_second = num_tokens_per_device / step_time
|
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
|
||||||
|
state.last_tokens_per_second = num_tokens_per_device / step_time
|
||||||
|
|
||||||
def on_log(
|
def on_log(
|
||||||
self,
|
self,
|
||||||
@@ -58,5 +59,6 @@ class TokensPerSecondCallback(TrainerCallback):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
# after logging, clear the running metrics
|
# after logging, clear the running metrics
|
||||||
state.last_tokens_per_second.zero_()
|
if hasattr(state, "last_tokens_per_second"):
|
||||||
state.num_tokens = 0
|
state.last_tokens_per_second.zero_()
|
||||||
|
state.num_tokens = torch.zeros(1)
|
||||||
|
|||||||
@@ -855,9 +855,9 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
include_tkps: bool | None = Field(
|
include_tkps: bool | None = Field(
|
||||||
default=None,
|
default=True,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "bool of whether to report tokens per second during training by measuring throughput of non-padding tokens."
|
"description": "bool of whether to report tokens per second per-gpu during training by measuring throughput of non-padding tokens."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
neftune_noise_alpha: float | None = Field(
|
neftune_noise_alpha: float | None = Field(
|
||||||
|
|||||||
Reference in New Issue
Block a user