@dongmin-ra commented Oct 2, 2025

Motivation

Fixed an intermittent issue where the dispatch results could be corrupted when dispatch/combine are run repeatedly.

  • The issue occurs only when there is no global barrier (e.g. torch.distributed.barrier()) between dispatch and combine.

Technical Details

Issue

  • After integrating mori-EP into vLLM (refer), I hit an intermittent GPU memory access fault during multi-node EP.
    • The fault happened because the expert indices produced by dispatch were corrupted on some ranks.
  • After investigating, I found that the dispatch result can be corrupted in internode EP.

Cause

  • Combine and dispatch share the same input buffer, shmemInpTokMemObj.
    • In internode dispatch, the last warp, after collecting data from the remaining warps, sends it to the shmemInpTokMemObj buffer on the remote GPU.
      • During the recv phase, data is then copied from the local shmemInpTokMemObj to the local shmemOutTokMemObj.
    • In combine’s send phase, similar to dispatch, the last warp sends data to shmemInpTokMemObj on the remote GPU.
  • If combine starts immediately after dispatch on a fast GPU, its send phase may perform RDMA writes to shmemInpTokMemObj before dispatch’s recv phase has finished the memcpy.
    • This can overwrite the buffer and corrupt the dispatch result (see the sketch below).
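
To make the race concrete, here is a minimal sketch of the calling pattern in Python. The op handle and the dispatch/combine signatures are illustrative placeholders, not the exact mori API; only the buffer names in the comments come from the description above.

import torch.distributed as dist

def racy_dispatch_combine_loop(op, tokens, expert_ids, num_iters, use_barrier=False):
    # Hypothetical calling pattern; `op`, `dispatch`, and `combine` stand in
    # for the real mori-EP handle and methods.
    for _ in range(num_iters):
        # Dispatch send: the last warp RDMA-writes into shmemInpTokMemObj on
        # the remote GPU; dispatch recv then copies the local shmemInpTokMemObj
        # into the local shmemOutTokMemObj.
        disp_tokens, disp_expert_ids = op.dispatch(tokens, expert_ids)

        if use_barrier:
            # A global barrier hides the bug: no rank reaches combine until
            # every rank has finished dispatch's recv memcpy.
            dist.barrier()

        # Combine send: also RDMA-writes into shmemInpTokMemObj on the remote
        # GPU. Without the barrier, a fast rank's combine can overwrite data
        # that a slower rank's dispatch recv has not copied out yet.
        op.combine(disp_tokens, disp_expert_ids)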

Fix

  • Separated the input buffers used by dispatch and combine.

Test Plan

pytest ./tests/python/ops/test_dispatch_combine_internode_inconsistency.py -s

This test should be run on a single node. Internally, the MORI_DISABLE_P2P environment variable is enabled to force communication over RDMA even within the single node; a sketch of the core check follows.
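
For orientation only, a minimal sketch of the kind of check involved. How the test actually sets the environment variable and validates results may differ; the helper below is hypothetical and merely mirrors the "Invalid expert id" failure shown in the log.

import os

# Assumption: the test forces RDMA within one node roughly like this.
os.environ["MORI_DISABLE_P2P"] = "1"

def check_dispatch_expert_ids(dispatched_expert_ids, num_experts):
    # Hypothetical helper: every expert id produced by dispatch must be a
    # valid expert index; corruption shows up as huge out-of-range values.
    for expert_id in dispatched_expert_ids:
        if not (0 <= expert_id < num_experts):
            raise AssertionError(f"Invalid expert id: {expert_id}")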

Test Result

  • Before modification: incorrect expert index values are produced as the dispatch result.
=================================================================================================================================================== test session starts ====================================================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.1, pluggy-1.6.0
rootdir: /app/mori
plugins: assume-2.4.3, anyio-4.9.0, asyncio-1.0.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collecting ...
collected 1 item

tests/python/ops/test_dispatch_combine_internode_inconsistency.py Multiprocessing start method set to spawn

rank 0 RDMA devices: mlx5_0, mlx5_2, mlx5_3, mlx5_4
rank 0 rankInNode 0 select device [0] mlx5_0
rank 3 rankInNode 3 select device [3] mlx5_4
rank 6 rankInNode 6 select device [0] mlx5_0
rank 5 rankInNode 5 select device [3] mlx5_4
rank 7 rankInNode 7 select device [1] mlx5_2
rank 2 rankInNode 2 select device [2] mlx5_3
rank 1 rankInNode 1 select device [1] mlx5_2
rank 4 rankInNode 4 select device [2] mlx5_3
Passed 0/2048
Passed 1/2048
Passed 2/2048
...
Passed 33/2048
Passed 34/2048
Passed 35/2048
Invalid expert id: 1261946812
  • After modification: no error occurs.

Submission Checklist
