Sample pack trust remote code v2 (#1873)
* fix the multipack patch for remote code models * add deepseek v2 lite example w fsdp
This commit is contained in:
@@ -94,3 +94,5 @@ def patch_remote(model_name, config_name, modeling_name):
|
||||
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
|
||||
modeling_arch = importlib.import_module(module_name)
|
||||
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
|
||||
# workaround to make the patch stick
|
||||
modeling_arch._axolotl_multipack_patch = True # pylint: disable=protected-access
|
||||
|
||||
@@ -17,11 +17,9 @@ def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
max_num = int(torch.max(attention_mask).item())
|
||||
batch_size, _ = attention_mask.shape
|
||||
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
|
||||
|
||||
for i in range(1, max_num + 1):
|
||||
mask = attention_mask == i
|
||||
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
|
||||
|
||||
result = counts.flatten()
|
||||
nonzero_indices = torch.nonzero(result).squeeze(-1)
|
||||
return result[nonzero_indices]
|
||||
|
||||
Reference in New Issue
Block a user