Files
replaysim/patches/frontier-vllm-0.11.1-profiling-compat.patch

180 lines
7.2 KiB
Diff

diff --git a/frontier/profiling/common/layers/rotary_embedding.py b/frontier/profiling/common/layers/rotary_embedding.py
index 3f6d999..00be87b 100644
--- a/frontier/profiling/common/layers/rotary_embedding.py
+++ b/frontier/profiling/common/layers/rotary_embedding.py
@@ -576,15 +576,19 @@ def get_rope(
if not _should_prefer_torch_rope_fallback():
vllm_get_rope = _load_vllm_get_rope()
if vllm_get_rope is not None:
- return vllm_get_rope(
- head_size=head_size,
- rotary_dim=rotary_dim,
- max_position=max_position,
- base=base,
- is_neox_style=is_neox_style,
- rope_scaling=rope_scaling,
- dtype=rope_dtype,
- )
+ try:
+ return vllm_get_rope(
+ head_size=head_size,
+ rotary_dim=rotary_dim,
+ max_position=max_position,
+ base=base,
+ is_neox_style=is_neox_style,
+ rope_scaling=rope_scaling,
+ dtype=rope_dtype,
+ )
+ except TypeError as exc:
+ if "unexpected keyword argument" not in str(exc):
+ raise
if cache_key in _LOCAL_ROPE_DICT:
return _LOCAL_ROPE_DICT[cache_key]
diff --git a/frontier/profiling/moe/moe_impl.py b/frontier/profiling/moe/moe_impl.py
index f732980..79aed30 100644
--- a/frontier/profiling/moe/moe_impl.py
+++ b/frontier/profiling/moe/moe_impl.py
@@ -27,9 +27,16 @@ from frontier.profiling.common.utils import raise_if_fp8_requested
try:
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk,
- get_config_dtype_str,
try_get_optimal_moe_config,
)
+ try:
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
+ get_config_dtype_str,
+ )
+ except ImportError:
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
+ _get_config_dtype_str as get_config_dtype_str,
+ )
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
@@ -128,14 +135,20 @@ class MoEGatingNetwork(nn.Module):
)
if self.use_vllm_fused_topk and HAS_VLLM_REPLICATED_LINEAR:
- # Align gating linear kernel family with vLLM runtime contract.
- # disable_tp=True avoids requiring TP group initialization in profiling jobs.
- self.gate = ReplicatedLinear(
- hidden_dim,
- num_experts,
- bias=False,
- disable_tp=True,
- )
+ try:
+ # Align gating linear kernel family with vLLM runtime contract.
+ # vLLM 0.11.x still touches TP state even with disable_tp=True in
+ # standalone profiling, so fall back to torch Linear if needed.
+ self.gate = ReplicatedLinear(
+ hidden_dim,
+ num_experts,
+ bias=False,
+ disable_tp=True,
+ )
+ except AssertionError as exc:
+ if "tensor model parallel group is not initialized" not in str(exc):
+ raise
+ self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
else:
# Fall back to native torch linear only when vLLM kernel alignment is disabled.
self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
@@ -187,13 +200,14 @@ class MoEGatingNetwork(nn.Module):
indices_type=None,
)
else:
- routing_weights, selected_experts, _ = fused_topk(
+ fused_topk_outputs = fused_topk(
hidden_states=hidden_states,
gating_output=logits,
topk=self.router_topk,
renormalize=getattr(self, "renormalize", True),
indices_type=None,
)
+ routing_weights, selected_experts = fused_topk_outputs[:2]
else:
if routing_runtime_path != "standard_fused_topk":
raise ValueError(
diff --git a/frontier/profiling/moe/moe_vllm_kernel.py b/frontier/profiling/moe/moe_vllm_kernel.py
index 7228731..726c748 100644
--- a/frontier/profiling/moe/moe_vllm_kernel.py
+++ b/frontier/profiling/moe/moe_vllm_kernel.py
@@ -36,8 +36,15 @@ try:
invoke_fused_moe_kernel,
moe_align_block_size,
try_get_optimal_moe_config,
- get_config_dtype_str,
)
+ try:
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
+ get_config_dtype_str,
+ )
+ except ImportError:
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
+ _get_config_dtype_str as get_config_dtype_str,
+ )
VLLM_API_VERSION = "0.10.x"
VLLM_AVAILABLE = True
@@ -195,6 +202,7 @@ def _invoke_kernel(
B: torch.Tensor,
C: torch.Tensor,
topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
@@ -249,6 +257,7 @@ def _invoke_kernel(
B_scale=B_scale,
B_zp=None,
topk_weights=topk_weights,
+ topk_ids=topk_ids,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,
@@ -260,7 +269,9 @@ def _invoke_kernel(
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
+ use_int4_w4a8=False,
per_channel_quant=per_channel_quant,
+ use_valu=False,
block_shape=block_shape,
B_bias=None,
)
@@ -273,6 +284,7 @@ def _run_fused_moe_iteration(
intermediate_cache1: torch.Tensor,
intermediate_cache2: torch.Tensor,
topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
@@ -292,6 +304,7 @@ def _run_fused_moe_iteration(
B=w1.contiguous(),
C=intermediate_cache1.contiguous(),
topk_weights=topk_weights.contiguous(),
+ topk_ids=topk_ids.contiguous(),
sorted_token_ids=sorted_token_ids.contiguous(),
expert_ids=expert_ids.contiguous(),
num_tokens_post_padded=num_tokens_post_padded.contiguous(),
@@ -321,6 +334,7 @@ def _run_fused_moe_iteration(
B=w2.contiguous(),
C=intermediate_cache2.contiguous(),
topk_weights=topk_weights.contiguous(),
+ topk_ids=topk_ids.contiguous(),
sorted_token_ids=sorted_token_ids.contiguous(),
expert_ids=expert_ids.contiguous(),
num_tokens_post_padded=num_tokens_post_padded.contiguous(),
@@ -548,6 +562,7 @@ def profile_fused_moe_kernel(
intermediate_cache1=intermediate_cache1,
intermediate_cache2=intermediate_cache2,
topk_weights=topk_weights,
+ topk_ids=topk_ids,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,