180 lines
7.2 KiB
Diff
180 lines
7.2 KiB
Diff
diff --git a/frontier/profiling/common/layers/rotary_embedding.py b/frontier/profiling/common/layers/rotary_embedding.py
|
|
index 3f6d999..00be87b 100644
|
|
--- a/frontier/profiling/common/layers/rotary_embedding.py
|
|
+++ b/frontier/profiling/common/layers/rotary_embedding.py
|
|
@@ -576,15 +576,19 @@ def get_rope(
|
|
if not _should_prefer_torch_rope_fallback():
|
|
vllm_get_rope = _load_vllm_get_rope()
|
|
if vllm_get_rope is not None:
|
|
- return vllm_get_rope(
|
|
- head_size=head_size,
|
|
- rotary_dim=rotary_dim,
|
|
- max_position=max_position,
|
|
- base=base,
|
|
- is_neox_style=is_neox_style,
|
|
- rope_scaling=rope_scaling,
|
|
- dtype=rope_dtype,
|
|
- )
|
|
+ try:
|
|
+ return vllm_get_rope(
|
|
+ head_size=head_size,
|
|
+ rotary_dim=rotary_dim,
|
|
+ max_position=max_position,
|
|
+ base=base,
|
|
+ is_neox_style=is_neox_style,
|
|
+ rope_scaling=rope_scaling,
|
|
+ dtype=rope_dtype,
|
|
+ )
|
|
+ except TypeError as exc:
|
|
+ if "unexpected keyword argument" not in str(exc):
|
|
+ raise
|
|
|
|
if cache_key in _LOCAL_ROPE_DICT:
|
|
return _LOCAL_ROPE_DICT[cache_key]
|
|
diff --git a/frontier/profiling/moe/moe_impl.py b/frontier/profiling/moe/moe_impl.py
|
|
index f732980..79aed30 100644
|
|
--- a/frontier/profiling/moe/moe_impl.py
|
|
+++ b/frontier/profiling/moe/moe_impl.py
|
|
@@ -27,9 +27,16 @@ from frontier.profiling.common.utils import raise_if_fp8_requested
|
|
try:
|
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
|
fused_topk,
|
|
- get_config_dtype_str,
|
|
try_get_optimal_moe_config,
|
|
)
|
|
+ try:
|
|
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
|
+ get_config_dtype_str,
|
|
+ )
|
|
+ except ImportError:
|
|
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
|
+ _get_config_dtype_str as get_config_dtype_str,
|
|
+ )
|
|
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
|
moe_align_block_size,
|
|
)
|
|
@@ -128,14 +135,20 @@ class MoEGatingNetwork(nn.Module):
|
|
)
|
|
|
|
if self.use_vllm_fused_topk and HAS_VLLM_REPLICATED_LINEAR:
|
|
- # Align gating linear kernel family with vLLM runtime contract.
|
|
- # disable_tp=True avoids requiring TP group initialization in profiling jobs.
|
|
- self.gate = ReplicatedLinear(
|
|
- hidden_dim,
|
|
- num_experts,
|
|
- bias=False,
|
|
- disable_tp=True,
|
|
- )
|
|
+ try:
|
|
+ # Align gating linear kernel family with vLLM runtime contract.
|
|
+ # vLLM 0.11.x still touches TP state even with disable_tp=True in
|
|
+ # standalone profiling, so fall back to torch Linear if needed.
|
|
+ self.gate = ReplicatedLinear(
|
|
+ hidden_dim,
|
|
+ num_experts,
|
|
+ bias=False,
|
|
+ disable_tp=True,
|
|
+ )
|
|
+ except AssertionError as exc:
|
|
+ if "tensor model parallel group is not initialized" not in str(exc):
|
|
+ raise
|
|
+ self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
|
else:
|
|
# Fall back to native torch linear only when vLLM kernel alignment is disabled.
|
|
self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
|
@@ -187,13 +200,14 @@ class MoEGatingNetwork(nn.Module):
|
|
indices_type=None,
|
|
)
|
|
else:
|
|
- routing_weights, selected_experts, _ = fused_topk(
|
|
+ fused_topk_outputs = fused_topk(
|
|
hidden_states=hidden_states,
|
|
gating_output=logits,
|
|
topk=self.router_topk,
|
|
renormalize=getattr(self, "renormalize", True),
|
|
indices_type=None,
|
|
)
|
|
+ routing_weights, selected_experts = fused_topk_outputs[:2]
|
|
else:
|
|
if routing_runtime_path != "standard_fused_topk":
|
|
raise ValueError(
|
|
diff --git a/frontier/profiling/moe/moe_vllm_kernel.py b/frontier/profiling/moe/moe_vllm_kernel.py
|
|
index 7228731..726c748 100644
|
|
--- a/frontier/profiling/moe/moe_vllm_kernel.py
|
|
+++ b/frontier/profiling/moe/moe_vllm_kernel.py
|
|
@@ -36,8 +36,15 @@ try:
|
|
invoke_fused_moe_kernel,
|
|
moe_align_block_size,
|
|
try_get_optimal_moe_config,
|
|
- get_config_dtype_str,
|
|
)
|
|
+ try:
|
|
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
|
+ get_config_dtype_str,
|
|
+ )
|
|
+ except ImportError:
|
|
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
|
+ _get_config_dtype_str as get_config_dtype_str,
|
|
+ )
|
|
|
|
VLLM_API_VERSION = "0.10.x"
|
|
VLLM_AVAILABLE = True
|
|
@@ -195,6 +202,7 @@ def _invoke_kernel(
|
|
B: torch.Tensor,
|
|
C: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
+ topk_ids: torch.Tensor,
|
|
sorted_token_ids: torch.Tensor,
|
|
expert_ids: torch.Tensor,
|
|
num_tokens_post_padded: torch.Tensor,
|
|
@@ -249,6 +257,7 @@ def _invoke_kernel(
|
|
B_scale=B_scale,
|
|
B_zp=None,
|
|
topk_weights=topk_weights,
|
|
+ topk_ids=topk_ids,
|
|
sorted_token_ids=sorted_token_ids,
|
|
expert_ids=expert_ids,
|
|
num_tokens_post_padded=num_tokens_post_padded,
|
|
@@ -260,7 +269,9 @@ def _invoke_kernel(
|
|
use_int8_w8a8=False,
|
|
use_int8_w8a16=False,
|
|
use_int4_w4a16=False,
|
|
+ use_int4_w4a8=False,
|
|
per_channel_quant=per_channel_quant,
|
|
+ use_valu=False,
|
|
block_shape=block_shape,
|
|
B_bias=None,
|
|
)
|
|
@@ -273,6 +284,7 @@ def _run_fused_moe_iteration(
|
|
intermediate_cache1: torch.Tensor,
|
|
intermediate_cache2: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
+ topk_ids: torch.Tensor,
|
|
sorted_token_ids: torch.Tensor,
|
|
expert_ids: torch.Tensor,
|
|
num_tokens_post_padded: torch.Tensor,
|
|
@@ -292,6 +304,7 @@ def _run_fused_moe_iteration(
|
|
B=w1.contiguous(),
|
|
C=intermediate_cache1.contiguous(),
|
|
topk_weights=topk_weights.contiguous(),
|
|
+ topk_ids=topk_ids.contiguous(),
|
|
sorted_token_ids=sorted_token_ids.contiguous(),
|
|
expert_ids=expert_ids.contiguous(),
|
|
num_tokens_post_padded=num_tokens_post_padded.contiguous(),
|
|
@@ -321,6 +334,7 @@ def _run_fused_moe_iteration(
|
|
B=w2.contiguous(),
|
|
C=intermediate_cache2.contiguous(),
|
|
topk_weights=topk_weights.contiguous(),
|
|
+ topk_ids=topk_ids.contiguous(),
|
|
sorted_token_ids=sorted_token_ids.contiguous(),
|
|
expert_ids=expert_ids.contiguous(),
|
|
num_tokens_post_padded=num_tokens_post_padded.contiguous(),
|
|
@@ -548,6 +562,7 @@ def profile_fused_moe_kernel(
|
|
intermediate_cache1=intermediate_cache1,
|
|
intermediate_cache2=intermediate_cache2,
|
|
topk_weights=topk_weights,
|
|
+ topk_ids=topk_ids,
|
|
sorted_token_ids=sorted_token_ids,
|
|
expert_ids=expert_ids,
|
|
num_tokens_post_padded=num_tokens_post_padded,
|