train_per_sec_per_gpu metric (#3364) [skip ci]
* fix token count * guard for none n zero
This commit is contained in:
@@ -78,12 +78,19 @@ class TokensPerSecondCallback(TrainerCallback):
|
||||
**kwargs,
|
||||
): # pylint: disable=unused-argument
|
||||
tokens = getattr(state, "tokens", None)
|
||||
if tokens and "trainable_tokens" in tokens:
|
||||
step_time = time.perf_counter() - self.start_time
|
||||
num_tokens_per_device = tokens["trainable_tokens"].clone()
|
||||
# non data parallel groups have duplicated tokens, so we avoid double-counting
|
||||
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
|
||||
state.last_tokens_per_second = num_tokens_per_device / step_time
|
||||
if not (tokens and "trainable_tokens" in tokens):
|
||||
return
|
||||
step_time = time.perf_counter() - self.start_time
|
||||
if step_time <= 0:
|
||||
return
|
||||
|
||||
num_tokens = tokens["trainable_tokens"].clone() / self.non_data_parallel_size
|
||||
if torch.distributed.is_initialized():
|
||||
dp_size = max(
|
||||
1, torch.distributed.get_world_size() // self.non_data_parallel_size
|
||||
)
|
||||
num_tokens = num_tokens / dp_size
|
||||
state.last_tokens_per_second = num_tokens / step_time
|
||||
|
||||
def on_log(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user