Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ See file LICENSE for terms.
<scala.version>2.12.12</scala.version>
<scala.compat.version>2.12</scala.compat.version>
<jucx.version>1.10.0-SNAPSHOT</jucx.version>
<cudf.version>0.16</cudf.version>
</properties>

<dependencies>
Expand All @@ -55,6 +56,12 @@ See file LICENSE for terms.
<version>3.2.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ai.rapids</groupId>
<artifactId>cudf</artifactId>
<version>${cudf.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down Expand Up @@ -137,6 +144,18 @@ See file LICENSE for terms.
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.2.0</version>
<executions>
<execution>
<goals>
<goal>test-jar</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo
private[ucx] def handlePrefetchRequest(workerId: String, workerAddress: ByteBuffer,
blockIds: Seq[BlockId]) {

logInfo(s"Prefetching blocks: ${blockIds.mkString(",")}")
logDebug(s"Prefetching blocks: ${blockIds.mkString(",")}")
clientConnections.getOrElseUpdate(workerId,
globalWorker.newEndpoint(new UcpEndpointParams().setUcpAddress(workerAddress))
)
Expand Down Expand Up @@ -184,7 +184,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo
lock.lock()
val blockMemory = block.getMemoryBlock

logInfo(s"Sending $blockId of size ${blockMemory.size} to tag: $tag")
logDebug(s"Sending $blockId of size ${blockMemory.size} to tag: $tag")
ep.sendTaggedNonBlocking(blockMemory.address, blockMemory.size, tag, new UcxCallback {
override def onSuccess(request: UcpRequest): Unit = {
if (block.isInstanceOf[UcxPinnedBlock]) {
Expand Down
62 changes: 35 additions & 27 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.apache.spark.shuffle.ucx

import java.io.{Closeable, ObjectOutputStream}
import java.nio.BufferOverflowException
import java.util.concurrent.ThreadLocalRandom

import scala.collection.mutable
Expand Down Expand Up @@ -91,7 +92,7 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
logInfo(s"Worker from thread ${Thread.currentThread().getName} connecting to $executorId")
val endpointParams = new UcpEndpointParams()
.setUcpAddress(workerAdresses.get(executorId))
worker.newEndpoint(endpointParams)
worker.newEndpoint(endpointParams)
})
}

Expand All @@ -115,16 +116,16 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon

ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize,
UcxRpcMessages.PREFETCH_TAG, new UcxCallback() {
override def onSuccess(request: UcpRequest): Unit = {
logTrace(s"Sent prefetch ${blockIds.length} blocks to $executorId")
memoryPool.put(mem)
}
})
override def onSuccess(request: UcpRequest): Unit = {
logTrace(s"Sent prefetch ${blockIds.length} blocks to $executorId")
memoryPool.put(mem)
}
})
}

private[ucx] def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId],
resultBuffer: Seq[MemoryBlock],
callbacks: Seq[OperationCallback]): Seq[Request] = {
resultBuffer: Seq[MemoryBlock],
callbacks: Seq[OperationCallback]): Seq[Request] = {
val ep = getConnection(executorId)
val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize)
val buffer = UcxUtils.getByteBufferView(mem.address,
Expand All @@ -136,13 +137,20 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon

Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos =>
val out = new ObjectOutputStream(bos)
out.writeObject(message)
out.flush()
out.close()
try {
out.writeObject(message)
out.flush()
out.close()
} catch {
case _: BufferOverflowException =>
throw new UcxException(s"Prefetch blocks message size > " +
s"${transport.ucxShuffleConf.RPC_MESSAGE_SIZE.key}:${transport.ucxShuffleConf.rpcMessageSize}")
case ex: Exception => throw new UcxException(ex.getMessage)
}
}

val tag = ThreadLocalRandom.current().nextLong(Long.MinValue, 0)
logInfo(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag")
logTrace(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag")
ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag,
new UcxCallback() {
override def onSuccess(request: UcpRequest): Unit = {
Expand Down Expand Up @@ -179,7 +187,7 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
}

private[ucx] def fetchBlockByBlockId(executorId: String, blockId: BlockId,
resultBuffer: MemoryBlock, cb: OperationCallback): UcxRequest = {
resultBuffer: MemoryBlock, cb: OperationCallback): UcxRequest = {
val stats = new UcxStats()
val ep = getConnection(executorId)
val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize)
Expand All @@ -206,28 +214,28 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
override def onSuccess(request: UcpRequest): Unit = {
memoryPool.put(mem)
}
})
})

val result = new UcxSuccessOperationResult(stats)
val request = worker.recvTaggedNonBlocking(resultBuffer.address, resultBuffer.size,
tag, -1L, new UcxCallback () {

override def onError(ucsStatus: Int, errorMsg: String): Unit = {
logError(s"Failed to receive blockId $blockId on tag: $tag, from executorId: $executorId " +
s" of size: ${resultBuffer.size}: $errorMsg")
if (cb != null ) {
cb.onComplete(new UcxFailureOperationResult(errorMsg))
override def onError(ucsStatus: Int, errorMsg: String): Unit = {
logError(s"Failed to receive blockId $blockId on tag: $tag, from executorId: $executorId " +
s" of size: ${resultBuffer.size}: $errorMsg")
if (cb != null ) {
cb.onComplete(new UcxFailureOperationResult(errorMsg))
}
}
}

override def onSuccess(request: UcpRequest): Unit = {
stats.endTime = System.nanoTime()
stats.receiveSize = request.getRecvSize
if (cb != null) {
cb.onComplete(result)
override def onSuccess(request: UcpRequest): Unit = {
stats.endTime = System.nanoTime()
stats.receiveSize = request.getRecvSize
if (cb != null) {
cb.onComplete(result)
}
}
}
})
})
new UcxRequest(request, stats)
}

Expand Down
Loading