Text Generation
Transformers
Safetensors
PyTorch
nvidia

Update modeling_nemotron_h.py: cast the empty-expert dummy input to the expert's weight dtype

#2
by jaeminh - opened
Files changed (1)
  1. modeling_nemotron_h.py +2 -1
modeling_nemotron_h.py CHANGED
@@ -852,7 +852,8 @@ class NemotronHMOE(nn.Module):
852
  final_hidden_states.index_add_(0, token_indices, weighted_output)
853
  else:
854
  # Local empty expert: no-op compute that still marks params as used.
855
- dummy_out = expert(torch.zeros_like(hidden_states[0]).unsqueeze(0).to(final_hidden_states.dtype))
 
856
  final_hidden_states = final_hidden_states + dummy_out
857
 
858
  # in original deepseek, the output of the experts are gathered once we leave this module
 
852
  final_hidden_states.index_add_(0, token_indices, weighted_output)
853
  else:
854
  # Local empty expert: no-op compute that still marks params as used.
855
+ expert_dtype = expert.down_proj.weight.dtype
856
+ dummy_out = expert(torch.zeros_like(hidden_states[0]).unsqueeze(0).to(expert_dtype))
857
  final_hidden_states = final_hidden_states + dummy_out
858
 
859
  # in original deepseek, the output of the experts are gathered once we leave this module