Skip to content

Commit f3162fd

Browse files
committed
Performance benchmark
1 parent 742f0d7 commit f3162fd

File tree

6 files changed

+297
-29
lines changed

6 files changed

+297
-29
lines changed

pom.xml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ See file LICENSE for terms.
3535
<scala.version>2.12.12</scala.version>
3636
<scala.compat.version>2.12</scala.compat.version>
3737
<jucx.version>1.10.0-SNAPSHOT</jucx.version>
38+
<cudf.version>0.15</cudf.version>
3839
</properties>
3940

4041
<dependencies>
@@ -55,6 +56,12 @@ See file LICENSE for terms.
5556
<version>3.2.1</version>
5657
<scope>test</scope>
5758
</dependency>
59+
<dependency>
60+
<groupId>ai.rapids</groupId>
61+
<artifactId>cudf</artifactId>
62+
<version>${cudf.version}</version>
63+
<scope>test</scope>
64+
</dependency>
5865
</dependencies>
5966

6067
<build>
@@ -137,6 +144,18 @@ See file LICENSE for terms.
137144
</execution>
138145
</executions>
139146
</plugin>
147+
<plugin>
148+
<groupId>org.apache.maven.plugins</groupId>
149+
<artifactId>maven-jar-plugin</artifactId>
150+
<version>3.2.0</version>
151+
<executions>
152+
<execution>
153+
<goals>
154+
<goal>test-jar</goal>
155+
</goals>
156+
</execution>
157+
</executions>
158+
</plugin>
140159
</plugins>
141160
</build>
142161

src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package org.apache.spark.shuffle.ucx
66

77
import java.io.{Closeable, ObjectOutputStream}
8+
import java.nio.BufferOverflowException
89
import java.util.concurrent.ThreadLocalRandom
910

1011
import scala.collection.mutable
@@ -91,7 +92,7 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
9192
logInfo(s"Worker from thread ${Thread.currentThread().getName} connecting to $executorId")
9293
val endpointParams = new UcpEndpointParams()
9394
.setUcpAddress(workerAdresses.get(executorId))
94-
worker.newEndpoint(endpointParams)
95+
worker.newEndpoint(endpointParams)
9596
})
9697
}
9798

@@ -115,16 +116,16 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
115116

116117
ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize,
117118
UcxRpcMessages.PREFETCH_TAG, new UcxCallback() {
118-
override def onSuccess(request: UcpRequest): Unit = {
119-
logTrace(s"Sent prefetch ${blockIds.length} blocks to $executorId")
120-
memoryPool.put(mem)
121-
}
122-
})
119+
override def onSuccess(request: UcpRequest): Unit = {
120+
logTrace(s"Sent prefetch ${blockIds.length} blocks to $executorId")
121+
memoryPool.put(mem)
122+
}
123+
})
123124
}
124125

125126
private[ucx] def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId],
126-
resultBuffer: Seq[MemoryBlock],
127-
callbacks: Seq[OperationCallback]): Seq[Request] = {
127+
resultBuffer: Seq[MemoryBlock],
128+
callbacks: Seq[OperationCallback]): Seq[Request] = {
128129
val ep = getConnection(executorId)
129130
val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize)
130131
val buffer = UcxUtils.getByteBufferView(mem.address,
@@ -136,13 +137,20 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
136137

137138
Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos =>
138139
val out = new ObjectOutputStream(bos)
139-
out.writeObject(message)
140-
out.flush()
141-
out.close()
140+
try {
141+
out.writeObject(message)
142+
out.flush()
143+
out.close()
144+
} catch {
145+
case _: BufferOverflowException =>
146+
throw new UcxException(s"Prefetch blocks message size > " +
147+
s"${transport.ucxShuffleConf.RPC_MESSAGE_SIZE.key}:${transport.ucxShuffleConf.rpcMessageSize}")
148+
case ex: Exception => throw new UcxException(ex.getMessage)
149+
}
142150
}
143151

144152
val tag = ThreadLocalRandom.current().nextLong(Long.MinValue, 0)
145-
logInfo(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag")
153+
logTrace(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag")
146154
ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag,
147155
new UcxCallback() {
148156
override def onSuccess(request: UcpRequest): Unit = {
@@ -179,7 +187,7 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
179187
}
180188

181189
private[ucx] def fetchBlockByBlockId(executorId: String, blockId: BlockId,
182-
resultBuffer: MemoryBlock, cb: OperationCallback): UcxRequest = {
190+
resultBuffer: MemoryBlock, cb: OperationCallback): UcxRequest = {
183191
val stats = new UcxStats()
184192
val ep = getConnection(executorId)
185193
val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize)
@@ -206,28 +214,28 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
206214
override def onSuccess(request: UcpRequest): Unit = {
207215
memoryPool.put(mem)
208216
}
209-
})
217+
})
210218

211219
val result = new UcxSuccessOperationResult(stats)
212220
val request = worker.recvTaggedNonBlocking(resultBuffer.address, resultBuffer.size,
213221
tag, -1L, new UcxCallback () {
214222

215-
override def onError(ucsStatus: Int, errorMsg: String): Unit = {
216-
logError(s"Failed to receive blockId $blockId on tag: $tag, from executorId: $executorId " +
217-
s" of size: ${resultBuffer.size}: $errorMsg")
218-
if (cb != null ) {
219-
cb.onComplete(new UcxFailureOperationResult(errorMsg))
223+
override def onError(ucsStatus: Int, errorMsg: String): Unit = {
224+
logError(s"Failed to receive blockId $blockId on tag: $tag, from executorId: $executorId " +
225+
s" of size: ${resultBuffer.size}: $errorMsg")
226+
if (cb != null ) {
227+
cb.onComplete(new UcxFailureOperationResult(errorMsg))
228+
}
220229
}
221-
}
222230

223-
override def onSuccess(request: UcpRequest): Unit = {
224-
stats.endTime = System.nanoTime()
225-
stats.receiveSize = request.getRecvSize
226-
if (cb != null) {
227-
cb.onComplete(result)
231+
override def onSuccess(request: UcpRequest): Unit = {
232+
stats.endTime = System.nanoTime()
233+
stats.receiveSize = request.getRecvSize
234+
if (cb != null) {
235+
cb.onComplete(result)
236+
}
228237
}
229-
}
230-
})
238+
})
231239
new UcxRequest(request, stats)
232240
}
233241

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
/*
2+
* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
3+
* See file LICENSE for terms.
4+
*/
5+
package org.apache.spark.shuffle.ucx.perf
6+
7+
import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket}
8+
import java.nio.ByteBuffer
9+
import java.util.concurrent.TimeUnit
10+
import java.util.concurrent.atomic.AtomicInteger
11+
12+
import org.apache.commons.cli.{GnuParser, HelpFormatter, Options}
13+
import org.apache.spark.SparkConf
14+
import org.apache.spark.shuffle.ucx._
15+
import org.apache.spark.shuffle.ucx.memory.MemoryPool
16+
import org.apache.spark.util.Utils
17+
18+
object UcxShuffleTransportPerfTool {
19+
private val HELP_OPTION = "h"
20+
private val ADDRESS_OPTION = "a"
21+
private val NUM_BLOCKS_OPTION = "n"
22+
private val SIZE_OPTION = "s"
23+
private val PORT_OPTION = "p"
24+
private val ITER_OPTION = "i"
25+
private val MEMORY_TYPE_OPTION = "m"
26+
private val NUM_THREADS_OPTION = "t"
27+
28+
private val ucxShuffleConf = new UcxShuffleConf(new SparkConf())
29+
private val transport = new UcxShuffleTransport(ucxShuffleConf, "e")
30+
private val workerAddress = transport.init()
31+
private var memoryPool: MemoryPool = transport.memoryPool
32+
33+
case class TestBlockId(id: Int) extends BlockId
34+
35+
case class PerfOptions(remoteAddress: InetSocketAddress, numBlocks: Int, blockSize: Long,
36+
serverPort: Int, numIterations: Int, numThreads: Int)
37+
38+
private def initOptions(): Options = {
39+
val options = new Options()
40+
options.addOption(HELP_OPTION, "help", false,
41+
"display help message")
42+
options.addOption(ADDRESS_OPTION, "address", true,
43+
"address of the remote host")
44+
options.addOption(NUM_BLOCKS_OPTION, "num-blocks", true,
45+
"number of blocks to transfer. Default: 1")
46+
options.addOption(SIZE_OPTION, "block-size", true,
47+
"size of block to transfer. Default: 4m")
48+
options.addOption(PORT_OPTION, "server-port", true,
49+
"server port. Default: 12345")
50+
options.addOption(ITER_OPTION, "num-iterations", true,
51+
"number of iterations. Default: 5")
52+
options.addOption(NUM_THREADS_OPTION, "num-threads", true,
53+
"number of threads. Default: 1")
54+
options.addOption(MEMORY_TYPE_OPTION, "memory-type", true,
55+
"memory type: host (default), cuda")
56+
}
57+
58+
private def parseOptions(args: Array[String]): PerfOptions = {
59+
val parser = new GnuParser()
60+
val options = initOptions()
61+
val cmd = parser.parse(options, args)
62+
63+
if (cmd.hasOption(HELP_OPTION)) {
64+
new HelpFormatter().printHelp("UcxShufflePerfTool", options)
65+
System.exit(0)
66+
}
67+
68+
val inetAddress = if (cmd.hasOption(ADDRESS_OPTION)) {
69+
val Array(host, port) = cmd.getOptionValue(ADDRESS_OPTION).split(":")
70+
new InetSocketAddress(host, Integer.parseInt(port))
71+
} else {
72+
null
73+
}
74+
75+
val serverPort = Integer.parseInt(cmd.getOptionValue(PORT_OPTION, "12345"))
76+
77+
val numIterations = Integer.parseInt(cmd.getOptionValue(ITER_OPTION, "5"))
78+
79+
val threadsNumber = Integer.parseInt(cmd.getOptionValue(NUM_THREADS_OPTION, "1"))
80+
81+
if (cmd.hasOption(MEMORY_TYPE_OPTION) && cmd.getOptionValue(MEMORY_TYPE_OPTION) == "cuda") {
82+
val className = "org.apache.spark.shuffle.ucx.GpuMemoryPool"
83+
val cls = Utils.classForName(className)
84+
memoryPool = cls.getConstructor().newInstance().asInstanceOf[MemoryPool]
85+
}
86+
87+
PerfOptions(inetAddress,
88+
Integer.parseInt(cmd.getOptionValue(NUM_BLOCKS_OPTION, "1")),
89+
Utils.byteStringAsBytes(cmd.getOptionValue(SIZE_OPTION, "4m")),
90+
serverPort, numIterations, threadsNumber)
91+
}
92+
93+
private def startServer(perfOptions: PerfOptions): Unit = {
94+
val blocks: Seq[Block] = (0 until perfOptions.numBlocks).map { _ =>
95+
val block = memoryPool.get(perfOptions.blockSize)
96+
new Block {
97+
override def getMemoryBlock: MemoryBlock =
98+
block
99+
}
100+
}
101+
102+
val blockIds = (0 until perfOptions.numBlocks).map(i => TestBlockId(i))
103+
blockIds.zip(blocks).foreach {
104+
case (blockId, block) => transport.register(blockId, block)
105+
}
106+
107+
val serverSocket = new ServerSocket(perfOptions.serverPort)
108+
109+
println(s"Waiting for connections on " +
110+
s"${InetAddress.getLocalHost.getHostName}:${perfOptions.serverPort} ")
111+
112+
val clientSocket = serverSocket.accept()
113+
val out = clientSocket.getOutputStream
114+
val in = clientSocket.getInputStream
115+
116+
val buf = ByteBuffer.allocate(workerAddress.capacity())
117+
buf.put(workerAddress)
118+
buf.flip()
119+
120+
out.write(buf.array())
121+
out.flush()
122+
123+
println(s"Sending worker address to ${clientSocket.getInetAddress}")
124+
125+
buf.flip()
126+
127+
in.read(buf.array())
128+
clientSocket.close()
129+
serverSocket.close()
130+
131+
blocks.foreach(block => memoryPool.put(block.getMemoryBlock))
132+
blockIds.foreach(transport.unregister)
133+
transport.close()
134+
}
135+
136+
private def startClient(perfOptions: PerfOptions): Unit = {
137+
val socket = new Socket(perfOptions.remoteAddress.getHostName,
138+
perfOptions.remoteAddress.getPort)
139+
140+
val buf = new Array[Byte](4096)
141+
val readSize = socket.getInputStream.read(buf)
142+
val executorId = "1"
143+
val workerAddress = ByteBuffer.allocateDirect(readSize)
144+
145+
workerAddress.put(buf, 0, readSize)
146+
println("Received worker address")
147+
148+
transport.addExecutor(executorId, workerAddress)
149+
150+
val resultSize = perfOptions.numBlocks * perfOptions.blockSize
151+
val resultMemory = memoryPool.get(resultSize)
152+
153+
val blockIds = (0 until perfOptions.numBlocks).map(i => TestBlockId(i))
154+
155+
(1 to perfOptions.numIterations).foreach { (i: Int) => {
156+
(0 until perfOptions.numThreads).par.foreach {(t: Int) => {
157+
val completed = new AtomicInteger(0)
158+
var elapsedTime: Long = 0L
159+
160+
val mem = new Array[MemoryBlock](perfOptions.numBlocks)
161+
val callbacks = new Array[OperationCallback](perfOptions.numBlocks)
162+
(0 until perfOptions.numBlocks).foreach(j => {
163+
mem(j) = MemoryBlock(resultMemory.address + j * perfOptions.blockSize, perfOptions.blockSize)
164+
callbacks(j) = (result: OperationResult) => {
165+
elapsedTime += result.getStats.get.getElapsedTimeNs
166+
completed.incrementAndGet()
167+
}
168+
})
169+
170+
transport.fetchBlocksByBlockIds(executorId, blockIds, mem, callbacks)
171+
172+
while (completed.get() != perfOptions.numBlocks) {
173+
transport.progress()
174+
}
175+
val totalTime = if (elapsedTime < TimeUnit.MILLISECONDS.toNanos(1)) {
176+
s"$elapsedTime ns"
177+
} else {
178+
s"${TimeUnit.NANOSECONDS.toMillis(elapsedTime)} ms"
179+
}
180+
val throughput: Double = (resultSize / 1024.0D / 1024.0D / 1024.0D) / (elapsedTime / 1e9D)
181+
182+
println(f"${s"[$i/${perfOptions.numIterations}]"}%12s" +
183+
s" numBlocks: ${perfOptions.numBlocks}" +
184+
s" size: ${Utils.bytesToString(perfOptions.blockSize)}," +
185+
s" total size: ${Utils.bytesToString(resultSize)}," +
186+
f" time: $totalTime%3s" +
187+
f" throughput: $throughput%.3f GB/s")
188+
}}
189+
}}
190+
191+
val out = socket.getOutputStream
192+
out.write(buf)
193+
out.flush()
194+
out.close()
195+
socket.close()
196+
197+
memoryPool.put(resultMemory)
198+
transport.close()
199+
}
200+
201+
def main(args: Array[String]): Unit = {
202+
val perfOptions = parseOptions(args)
203+
204+
if (perfOptions.remoteAddress == null) {
205+
startServer(perfOptions)
206+
} else {
207+
startClient(perfOptions)
208+
}
209+
}
210+
211+
}

src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, memPool: MemoryPool,
2121
setDaemon(true)
2222
setName("Ucx Shuffle Transport Progress Thread")
2323

24-
24+
2525
override def run(): Unit = {
2626
val numRecvs = transport.ucxShuffleConf.recvQueueSize
2727
val msgSize = transport.ucxShuffleConf.rpcMessageSize

0 commit comments

Comments
 (0)