Skip to content

Commit 3bb69db

Browse files
committed
Performance benchmark
1 parent 742f0d7 commit 3bb69db

File tree

7 files changed

+312
-31
lines changed

7 files changed

+312
-31
lines changed

pom.xml

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@ See file LICENSE for terms.
3535
<scala.version>2.12.12</scala.version>
3636
<scala.compat.version>2.12</scala.compat.version>
3737
<jucx.version>1.10.0-SNAPSHOT</jucx.version>
38+
<cudf.version>0.15</cudf.version>
3839
</properties>
3940

4041
<dependencies>
@@ -55,6 +56,12 @@ See file LICENSE for terms.
5556
<version>3.2.1</version>
5657
<scope>test</scope>
5758
</dependency>
59+
<dependency>
60+
<groupId>ai.rapids</groupId>
61+
<artifactId>cudf</artifactId>
62+
<version>${cudf.version}</version>
63+
<scope>test</scope>
64+
</dependency>
5865
</dependencies>
5966

6067
<build>
@@ -137,6 +144,18 @@ See file LICENSE for terms.
137144
</execution>
138145
</executions>
139146
</plugin>
147+
<plugin>
148+
<groupId>org.apache.maven.plugins</groupId>
149+
<artifactId>maven-jar-plugin</artifactId>
150+
<version>3.2.0</version>
151+
<executions>
152+
<execution>
153+
<goals>
154+
<goal>test-jar</goal>
155+
</goals>
156+
</execution>
157+
</executions>
158+
</plugin>
140159
</plugins>
141160
</build>
142161

src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -153,7 +153,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo
153153
private[ucx] def handlePrefetchRequest(workerId: String, workerAddress: ByteBuffer,
154154
blockIds: Seq[BlockId]) {
155155

156-
logInfo(s"Prefetching blocks: ${blockIds.mkString(",")}")
156+
logDebug(s"Prefetching blocks: ${blockIds.mkString(",")}")
157157
clientConnections.getOrElseUpdate(workerId,
158158
globalWorker.newEndpoint(new UcpEndpointParams().setUcpAddress(workerAddress))
159159
)
@@ -184,7 +184,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo
184184
lock.lock()
185185
val blockMemory = block.getMemoryBlock
186186

187-
logInfo(s"Sending $blockId of size ${blockMemory.size} to tag: $tag")
187+
logDebug(s"Sending $blockId of size ${blockMemory.size} to tag: $tag")
188188
ep.sendTaggedNonBlocking(blockMemory.address, blockMemory.size, tag, new UcxCallback {
189189
override def onSuccess(request: UcpRequest): Unit = {
190190
if (block.isInstanceOf[UcxPinnedBlock]) {

src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala

Lines changed: 35 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
package org.apache.spark.shuffle.ucx
66

77
import java.io.{Closeable, ObjectOutputStream}
8+
import java.nio.BufferOverflowException
89
import java.util.concurrent.ThreadLocalRandom
910

1011
import scala.collection.mutable
@@ -91,7 +92,7 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
9192
logInfo(s"Worker from thread ${Thread.currentThread().getName} connecting to $executorId")
9293
val endpointParams = new UcpEndpointParams()
9394
.setUcpAddress(workerAdresses.get(executorId))
94-
worker.newEndpoint(endpointParams)
95+
worker.newEndpoint(endpointParams)
9596
})
9697
}
9798

@@ -115,16 +116,16 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
115116

116117
ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize,
117118
UcxRpcMessages.PREFETCH_TAG, new UcxCallback() {
118-
override def onSuccess(request: UcpRequest): Unit = {
119-
logTrace(s"Sent prefetch ${blockIds.length} blocks to $executorId")
120-
memoryPool.put(mem)
121-
}
122-
})
119+
override def onSuccess(request: UcpRequest): Unit = {
120+
logTrace(s"Sent prefetch ${blockIds.length} blocks to $executorId")
121+
memoryPool.put(mem)
122+
}
123+
})
123124
}
124125

125126
private[ucx] def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId],
126-
resultBuffer: Seq[MemoryBlock],
127-
callbacks: Seq[OperationCallback]): Seq[Request] = {
127+
resultBuffer: Seq[MemoryBlock],
128+
callbacks: Seq[OperationCallback]): Seq[Request] = {
128129
val ep = getConnection(executorId)
129130
val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize)
130131
val buffer = UcxUtils.getByteBufferView(mem.address,
@@ -136,13 +137,20 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
136137

137138
Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos =>
138139
val out = new ObjectOutputStream(bos)
139-
out.writeObject(message)
140-
out.flush()
141-
out.close()
140+
try {
141+
out.writeObject(message)
142+
out.flush()
143+
out.close()
144+
} catch {
145+
case _: BufferOverflowException =>
146+
throw new UcxException(s"Prefetch blocks message size > " +
147+
s"${transport.ucxShuffleConf.RPC_MESSAGE_SIZE.key}:${transport.ucxShuffleConf.rpcMessageSize}")
148+
case ex: Exception => throw new UcxException(ex.getMessage)
149+
}
142150
}
143151

144152
val tag = ThreadLocalRandom.current().nextLong(Long.MinValue, 0)
145-
logInfo(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag")
153+
logTrace(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag")
146154
ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag,
147155
new UcxCallback() {
148156
override def onSuccess(request: UcpRequest): Unit = {
@@ -179,7 +187,7 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
179187
}
180188

181189
private[ucx] def fetchBlockByBlockId(executorId: String, blockId: BlockId,
182-
resultBuffer: MemoryBlock, cb: OperationCallback): UcxRequest = {
190+
resultBuffer: MemoryBlock, cb: OperationCallback): UcxRequest = {
183191
val stats = new UcxStats()
184192
val ep = getConnection(executorId)
185193
val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize)
@@ -206,28 +214,28 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
206214
override def onSuccess(request: UcpRequest): Unit = {
207215
memoryPool.put(mem)
208216
}
209-
})
217+
})
210218

211219
val result = new UcxSuccessOperationResult(stats)
212220
val request = worker.recvTaggedNonBlocking(resultBuffer.address, resultBuffer.size,
213221
tag, -1L, new UcxCallback () {
214222

215-
override def onError(ucsStatus: Int, errorMsg: String): Unit = {
216-
logError(s"Failed to receive blockId $blockId on tag: $tag, from executorId: $executorId " +
217-
s" of size: ${resultBuffer.size}: $errorMsg")
218-
if (cb != null ) {
219-
cb.onComplete(new UcxFailureOperationResult(errorMsg))
223+
override def onError(ucsStatus: Int, errorMsg: String): Unit = {
224+
logError(s"Failed to receive blockId $blockId on tag: $tag, from executorId: $executorId " +
225+
s" of size: ${resultBuffer.size}: $errorMsg")
226+
if (cb != null ) {
227+
cb.onComplete(new UcxFailureOperationResult(errorMsg))
228+
}
220229
}
221-
}
222230

223-
override def onSuccess(request: UcpRequest): Unit = {
224-
stats.endTime = System.nanoTime()
225-
stats.receiveSize = request.getRecvSize
226-
if (cb != null) {
227-
cb.onComplete(result)
231+
override def onSuccess(request: UcpRequest): Unit = {
232+
stats.endTime = System.nanoTime()
233+
stats.receiveSize = request.getRecvSize
234+
if (cb != null) {
235+
cb.onComplete(result)
236+
}
228237
}
229-
}
230-
})
238+
})
231239
new UcxRequest(request, stats)
232240
}
233241

0 commit comments

Comments
 (0)