Update modeling_nemotron_h.py
#2
by
jaeminh
- opened
- modeling_nemotron_h.py +2 -1
modeling_nemotron_h.py
CHANGED
|
@@ -852,7 +852,8 @@ class NemotronHMOE(nn.Module):
|
|
| 852 |
final_hidden_states.index_add_(0, token_indices, weighted_output)
|
| 853 |
else:
|
| 854 |
# Local empty expert: no-op compute that still marks params as used.
|
| 855 |
-
|
|
|
|
| 856 |
final_hidden_states = final_hidden_states + dummy_out
|
| 857 |
|
| 858 |
# in original deepseek, the output of the experts are gathered once we leave this module
|
|
|
|
| 852 |
final_hidden_states.index_add_(0, token_indices, weighted_output)
|
| 853 |
else:
|
| 854 |
# Local empty expert: no-op compute that still marks params as used.
|
| 855 |
+
expert_dtype = expert.down_proj.weight.dtype
|
| 856 |
+
dummy_out = expert(torch.zeros_like(hidden_states[0]).unsqueeze(0).to(expert_dtype))
|
| 857 |
final_hidden_states = final_hidden_states + dummy_out
|
| 858 |
|
| 859 |
# in original deepseek, the output of the experts are gathered once we leave this module
|