Changes from all commits (116 commits)
5da2773
test
gleon99 Feb 2, 2023
cca6e6b
Fetch 1 block each time
gleon99 Feb 9, 2023
619b306
Add logs
gleon99 Feb 19, 2023
740e2b7
Remove "server" mode from UcxWorkerWrapper
gleon99 Feb 19, 2023
7c9dafc
Init shuffle transport w local DPU address
gleon99 Feb 19, 2023
df8e7a4
Add ini4j dependency
gleon99 Feb 19, 2023
f629470
Modify shuffle transport init, remove obsolete code
gleon99 Feb 19, 2023
8423a75
Remove AM handlers from global RPC thread
gleon99 Feb 19, 2023
a2412f9
Log executor connection
gleon99 Feb 19, 2023
8ce0397
Send fetch block am (broken)
gleon99 Feb 21, 2023
8d14883
Typo fix (BockId -> BlockId)
gleon99 Feb 21, 2023
6f616ec
Don't send shuffleId in fetch blocks request (temporary)
gleon99 Feb 21, 2023
67f94ae
Add files: DPU utils; DPU shuffle IO; Definitions; Add logs
gleon99 Feb 22, 2023
fe49cfe
Add DpuShuffleIO and DpuShuffleExecutorComponents
gleon99 Feb 23, 2023
7a9a5ed
Add UcxShuffleMapOutputWriter (copy from UcxLocalDiskShuffleExecutorC…
gleon99 Feb 25, 2023
2b0ca76
Add TODO
gleon99 Feb 26, 2023
a66b21c
Logs in DpuShuffleMapOutputWriter
gleon99 Feb 26, 2023
07aed6e
Connect to local DPU
gleon99 Feb 26, 2023
5668dc0
Cleanup, logs, add dummy Nvkv calls
gleon99 Feb 27, 2023
9e48d82
Rename DpuShuffleMapOutputWriter to NvkvShuffleMapOutputWriter
gleon99 Feb 27, 2023
2cba8a9
SHUFFLE/EXECUTER: Create NvkvWorker instance for writing and reading and
ofirfarjun7 May 18, 2023
6123e78
SHUFFLE/EXECUTER/WRITER: Renaming
ofirfarjun7 May 21, 2023
46800fb
SHUFFLE/EXECUTER/WRITER: Write with offset
ofirfarjun7 May 21, 2023
d3416b3
SHUFFLE/EXECUTER/WRITER: change java file to scala
ofirfarjun7 May 22, 2023
53ad490
Print write info and send init exec AM
ofirfarjun7 May 28, 2023
bf9ee55
Pack init exec msg. need to test
ofirfarjun7 May 29, 2023
4c08651
Pack and send init exec msg. Tested unpacking on DPU
ofirfarjun7 May 30, 2023
34ca675
Send mappers info to the DPU and write partitions to 512 aligned address
ofirfarjun7 May 31, 2023
e2b8fc2
Activate writes to nvkv and test (read and compare)
ofirfarjun7 May 31, 2023
29b05d4
Increase bounce buffer block size
ofirfarjun7 May 31, 2023
1b571a7
Init nvkv handler in shuffle manager
ofirfarjun7 Jun 3, 2023
ca30878
Make NvkvHandler member of ucxShuffleTransport
ofirfarjun7 Jun 4, 2023
bafc8bd
Add TODO's
ofirfarjun7 Jun 4, 2023
20588da
Read local blocks from nvkv
ofirfarjun7 Jun 6, 2023
645c880
Read local blocks from nvkv
ofirfarjun7 Jun 6, 2023
4e2a40d
Read local blocks - improve nvkv storing
ofirfarjun7 Jun 6, 2023
15e358a
Read local blocks - improve code
ofirfarjun7 Jun 6, 2023
52bae56
Read local blocks from DPU
ofirfarjun7 Jun 12, 2023
bee3485
Read local blocks from DPU
ofirfarjun7 Jun 12, 2023
81000e4
Pass args to spark DPU test
ofirfarjun7 Jun 14, 2023
d93fde8
GroupByTest is working with two nodes
ofirfarjun7 Jun 21, 2023
51c4b86
Change AM ID according to new spark_service ver
ofirfarjun7 Jun 21, 2023
480e7ac
Perform connection establishment properly
ofirfarjun7 Jun 22, 2023
45baabc
Support general number of mappers/reducers
ofirfarjun7 Jun 22, 2023
95d77b7
Remove writes to HDD from writer code
ofirfarjun7 Jun 25, 2023
eb1e2d7
Change logs to debug+increase bb size+reduce mem copy
ofirfarjun7 Jun 26, 2023
713c5ca
make fetch remote block async
ofirfarjun7 Jul 1, 2023
df089ec
Measure req time
ofirfarjun7 Jul 3, 2023
4e49f9a
Use nvkv buffer when reading to save copy
ofirfarjun7 Jul 4, 2023
cd65448
Improve code, remove zcopy to BB
ofirfarjun7 Jul 5, 2023
ffdf0cd
Increase number of fetchBlock requests sent before progress
ofirfarjun7 Jul 5, 2023
bb0c211
Increase number of fetchBlock requests sent before progress
ofirfarjun7 Jul 5, 2023
5ab0f9b
Add receive mpool to Nvkv
ofirfarjun7 Jul 6, 2023
8a24b78
Add NvkvManagedBuffer extending ManagedBuffer
ofirfarjun7 Jul 6, 2023
4211a81
Add NvkvManagedBuffer
ofirfarjun7 Jul 6, 2023
90dd4a8
remove redundant code and fix bug, buffer len
ofirfarjun7 Jul 17, 2023
66e2b20
Fetch remote while fetching local
ofirfarjun7 Jul 17, 2023
aea1e09
Remove redundant code
ofirfarjun7 Jul 18, 2023
f049329
Spark writer might skip reduce partitions
ofirfarjun7 Jul 20, 2023
67c82c0
Bug fix
ofirfarjun7 Jul 22, 2023
271dad0
Treat all blocks as remote
ofirfarjun7 Jul 27, 2023
7c2bd1d
Running with multiple execs - still not stable
ofirfarjun7 Jul 27, 2023
937aa4b
Treat all blocks as remote
ofirfarjun7 Jul 31, 2023
ca8a8a3
Support multiple nvme devices
ofirfarjun7 Jul 31, 2023
fe1745b
Bug fix - writing to illegal offset
ofirfarjun7 Aug 1, 2023
90ab7c5
Throw exception when using nvkv wrongly
ofirfarjun7 Aug 2, 2023
038665e
remove redundant EPs
ofirfarjun7 Aug 2, 2023
2cbdbd6
code improvements
ofirfarjun7 Aug 8, 2023
a8b9c28
code improvements - rename NvkvHandler
ofirfarjun7 Aug 8, 2023
a95b32c
code improvements
ofirfarjun7 Aug 8, 2023
465b602
code improvements
ofirfarjun7 Aug 8, 2023
53a5248
code improvements - wait for RPC if needed
ofirfarjun7 Aug 8, 2023
154acaa
code improvements
ofirfarjun7 Aug 8, 2023
317a068
code improvements - safe polling
ofirfarjun7 Aug 8, 2023
3ab7876
code improvements - safe polling
ofirfarjun7 Aug 8, 2023
fa4329b
code improvements - safe polling
ofirfarjun7 Aug 8, 2023
20d51cf
code improvements
ofirfarjun7 Aug 8, 2023
e7d6a6b
code improvements
ofirfarjun7 Aug 8, 2023
6cdcd7d
rearrange transport init and remove redundant code
ofirfarjun7 Aug 9, 2023
69ab9c2
rearrange transport init and remove redundant code
ofirfarjun7 Aug 9, 2023
c2d23c3
code improvements
ofirfarjun7 Aug 9, 2023
48762e7
code improvements - moving progress CB to send fetch block success
ofirfarjun7 Aug 10, 2023
dd96755
code improvements - Workaround to nvkv_query issue, be advised to check
ofirfarjun7 Aug 10, 2023
f63587a
code improvements - safe polling
ofirfarjun7 Aug 12, 2023
827d85d
code improvements - safe polling
ofirfarjun7 Aug 12, 2023
25e9463
code improvements - change nvme device address to new
ofirfarjun7 Aug 14, 2023
49b43bc
Perf issue - Move fetchBlock CB to previous, revert back to nvkv_query
ofirfarjun7 Aug 14, 2023
0c7f115
code improvements - Set nvkv core_mask. temporary value, not ideal
ofirfarjun7 Aug 14, 2023
be6913e
RPC mutex mechanism to sync spdk process
ofirfarjun7 Aug 14, 2023
52e1816
Remove unused files
ofirfarjun7 Aug 15, 2023
b7a1648
Revert shuffleBlockFetcherIterator changes
ofirfarjun7 Aug 15, 2023
81f14db
patch shuffleBlockFetcherIterator to mark all blocks as remote
ofirfarjun7 Aug 15, 2023
c472486
Make sure block location was sent to DPU before moving to read stage
ofirfarjun7 Aug 16, 2023
d4dfb6f
Address PR comments - Cleanup and improvements
ofirfarjun7 Aug 19, 2023
db2ae5c
Address PR comments - Cleanup and improvements
ofirfarjun7 Aug 19, 2023
4e76198
Address PR comments - Cleanup and improvements
ofirfarjun7 Aug 19, 2023
407f4cf
Address PR comments - Cleanup and improvements
ofirfarjun7 Aug 19, 2023
d9ff8bd
Address PR comments - code improvements
ofirfarjun7 Aug 19, 2023
8220e6f
Make byte order independent of local arch
ofirfarjun7 Aug 19, 2023
353b737
Address PR comments - improve logging
ofirfarjun7 Aug 19, 2023
e696749
Address PR comments - improve code
ofirfarjun7 Aug 19, 2023
5179646
Address PR comments - improve doc
ofirfarjun7 Aug 19, 2023
cb97f8f
Address PR comments - code improvements
ofirfarjun7 Aug 21, 2023
e5ba6ef
Address PR comments - cleanup channelWrapper, not supporting it
ofirfarjun7 Aug 21, 2023
a7533e7
Address PR comments - code improvements
ofirfarjun7 Aug 21, 2023
d88b861
Address PR comments - code improvements
ofirfarjun7 Aug 21, 2023
514d25f
Address PR comments - code improvements
ofirfarjun7 Aug 21, 2023
99b9f55
Address PR comments - code improvements
ofirfarjun7 Aug 21, 2023
0b1033e
Address PR comments - Remove LEO
ofirfarjun7 Aug 21, 2023
2b90d23
Address PR comments - code improvements
ofirfarjun7 Aug 23, 2023
55a5e54
Address PR comments - code improvements
ofirfarjun7 Aug 23, 2023
edd51a2
Forcing Spark Shuffle manager to use SortShuffleWriter
ofirfarjun7 Aug 23, 2023
370bc52
Forcing Spark Shuffle manager to use SortShuffleWriter - patch
ofirfarjun7 Aug 23, 2023
a01df84
Distribute cores evenly between SPDK process - framework
ofirfarjun7 Aug 23, 2023
96a7c04
Set core mask and storage device
ofirfarjun7 Aug 23, 2023
84f39d2
Write to the executor storage device
ofirfarjun7 Aug 23, 2023
pom.xml (12 changes: 11 additions & 1 deletion)
@@ -96,7 +96,17 @@ See file LICENSE for terms.
<dependency>
<groupId>org.openucx</groupId>
<artifactId>jucx</artifactId>
<version>1.13.1</version>
<version>1.16.0</version>
</dependency>
<dependency>
<groupId>org.openucx</groupId>
<artifactId>nvkv</artifactId>
<version>1.0</version>
</dependency>
<dependency>
<groupId>org.ini4j</groupId>
<artifactId>ini4j</artifactId>
<version>0.5.4</version>
</dependency>
</dependencies>

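Note on the new dependencies: jucx is bumped from 1.13.1 to 1.16.0, and the PR adds nvkv (the NVMe key/value library the shuffle data is written to, per the commits above) and ini4j (INI parsing). A minimal sketch of how ini4j could be used to read the local DPU address during shuffle transport init (see the "Init shuffle transport w local DPU address" commit); the config path, section, and key names below are assumptions, not taken from this diff:

import java.io.File
import org.ini4j.Ini

object DpuConfig {
  // Hypothetical file layout: a [dpu] section with address = 192.168.100.2:13337
  def localDpuAddress(path: String = "/etc/spark-ucx/dpu.ini"): String = {
    val ini = new Ini(new File(path))  // ini4j parses the INI file on construction
    ini.get("dpu", "address")          // returns the option value as a String
  }
}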
UcxShuffleClient.scala (org.apache.spark.shuffle.compat.spark_2_4)
@@ -3,19 +3,19 @@ package org.apache.spark.shuffle.compat.spark_2_4
import org.openucx.jucx.UcxUtils
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ShuffleClient}
import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBockId, UcxShuffleTransport}
import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBlockId, UcxShuffleTransport}
import org.apache.spark.shuffle.utils.UnsafeUtils
import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId}

class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient{
override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String],
listener: BlockFetchingListener,
downloadFileManager: DownloadFileManager): Unit = {
val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
val ucxBlockIds = Array.ofDim[UcxShuffleBlockId](blockIds.length)
val callbacks = Array.ofDim[OperationCallback](blockIds.length)
for (i <- blockIds.indices) {
val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[SparkShuffleBlockId]
ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, blockId.mapId, blockId.reduceId)
ucxBlockIds(i) = UcxShuffleBlockId(blockId.shuffleId, blockId.mapId, blockId.reduceId)
callbacks(i) = (result: OperationResult) => {
val memBlock = result.getData
val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt)
UcxLocalDiskShuffleExecutorComponents.scala
@@ -8,10 +8,12 @@ import java.util
import java.util.Optional

import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.utils.CommonUtils
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.shuffle.sort.io.{LocalDiskShuffleExecutorComponents, LocalDiskShuffleMapOutputWriter, LocalDiskSingleSpillMapOutputWriter}
import org.apache.spark.shuffle.UcxShuffleManager
import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter}
import scala.concurrent.duration._

/**
* Entry point to UCX executor.
@@ -22,14 +24,18 @@ class UcxLocalDiskShuffleExecutorComponents(sparkConf: SparkConf)
private var blockResolver: UcxShuffleBlockResolver = _

override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = {
logDebug("UcxLocalDiskShuffleExecutorComponents initializeExecutor appId: " + appId + " execId: " + execId)
val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
while (ucxShuffleManager.ucxTransport == null) {
Thread.sleep(5)
}
CommonUtils.safePolling(() => {},
() => {ucxShuffleManager.ucxTransport == null},
10.seconds.fromNow,
"Got timeout when polling", Duration(5, "millis"))

blockResolver = ucxShuffleManager.shuffleBlockResolver
}

override def createMapOutputWriter(shuffleId: Int, mapTaskId: Long, numPartitions: Int): ShuffleMapOutputWriter = {
logDebug("UcxLocalDiskShuffleExecutorComponents createMapOutputWriter shuffleId: " + shuffleId + " mapTaskId: " + mapTaskId + " numPartitions: " + numPartitions)
if (blockResolver == null) {
throw new IllegalStateException(
"Executor components must be initialized before getting writers.")
@@ -39,10 +45,14 @@
}

override def createSingleFileMapOutputWriter(shuffleId: Int, mapId: Long): Optional[SingleSpillShuffleMapOutputWriter] = {
// Called for each mapper
logDebug("UcxLocalDiskShuffleExecutorComponents createSingleFileMapOutputWriter shuffleId: " + shuffleId + " mapId: " + mapId)
if (blockResolver == null) {
throw new IllegalStateException(
"Executor components must be initialized before getting writers.")
}

// Need to implement an alternative to LocalDiskSingleSpillMapOutputWriter?
Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver))
}

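CommonUtils.safePolling is used throughout this PR (above, and again in UcxShuffleClient and UcxShuffleReader below), but its implementation is not part of the visible hunks. A sketch consistent with the three call sites, with the signature inferred and the defaults assumed:

import scala.concurrent.duration._

object CommonUtils {
  // Inferred shape: call `progress` while `keepPolling` holds, failing once
  // `deadline` passes; `sleep` optionally backs off between iterations.
  // The real implementation lives in org.apache.spark.shuffle.utils.CommonUtils
  // and may differ.
  def safePolling(progress: () => Unit,
                  keepPolling: () => Boolean,
                  deadline: Deadline = 1.minute.fromNow,      // assumed default
                  timeoutMsg: String = "Polling timed out",   // assumed default
                  sleep: Duration = Duration.Zero): Unit = {  // assumed default
    while (keepPolling()) {
      if (deadline.isOverdue()) throw new RuntimeException(timeoutMsg)
      progress()
      if (sleep > Duration.Zero) Thread.sleep(sleep.toMillis)
    }
  }
}

This keeps the busy-wait the old code expressed with a raw while/Thread.sleep loop, but bounds it with a deadline so a missing transport or lost RPC fails loudly instead of hanging the executor.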
UcxShuffleBlockResolver.scala (org.apache.spark.shuffle.compat.spark_3_0)
@@ -5,9 +5,14 @@
package org.apache.spark.shuffle.compat.spark_3_0

import java.io.{File, RandomAccessFile}
import java.nio.ByteBuffer

import org.apache.spark.{TaskContext, SparkEnv}
import org.apache.spark.storage._
import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
import org.apache.spark.shuffle.utils.UnsafeUtils
import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleTransport, CommonUcxShuffleBlockResolver, CommonUcxShuffleManager}

import org.apache.spark.TaskContext
import org.apache.spark.shuffle.ucx.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager}

/**
* Mapper entry point for UcxShuffle plugin. Performs memory registration
@@ -16,6 +21,8 @@ import org.apache.spark.shuffle.ucx.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager}
class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
extends CommonUcxShuffleBlockResolver(ucxShuffleManager) {

val shuffleManager: CommonUcxShuffleManager = ucxShuffleManager


override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long,
lengths: Array[Long], dataTmp: File): Unit = {
@@ -29,4 +36,27 @@ class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
}
writeIndexFileAndCommitCommon(shuffleId, partitionId, lengths, new RandomAccessFile(dataFile, "r"))
}

override def getBlockData(
blockId: BlockId,
dirs: Option[Array[String]]): ManagedBuffer = {

val ucxTransport: UcxShuffleTransport = shuffleManager.ucxTransport

logInfo("UcxShuffleBlockResolver getBlockData")
val (shuffleId, mapId, startReduceId, endReduceId) = blockId match {
case id: ShuffleBlockId =>
(id.shuffleId, id.mapId, id.reduceId, id.reduceId + 1)
case batchId: ShuffleBlockBatchId =>
(batchId.shuffleId, batchId.mapId, batchId.startReduceId, batchId.endReduceId)
case _ =>
throw new IllegalArgumentException("unexpected shuffle block id format: " + blockId)
}

val length = ucxTransport.getNvkvWrapper.getPartitonLength(shuffleId, mapId, startReduceId).toInt
val offset = ucxTransport.getNvkvWrapper.getPartitonOffset(shuffleId, mapId, startReduceId)
logDebug(s"UcxShuffleBlockResolver - Reading shuffleId $shuffleId mapId $mapId reduceId $startReduceId at offset $offset with length $length from nvkv")
val resultBuffer = ucxTransport.getNvkvWrapper.read(length, offset)
new NioManagedBuffer(resultBuffer)
}
}
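The nvkv accessors called above are not defined in this diff; their apparent shape, inferred from the call sites in getBlockData (keeping the "Partiton" spelling the code uses):

import java.nio.ByteBuffer

// Interface sketch only; the object returned by UcxShuffleTransport.getNvkvWrapper
// is added elsewhere in this PR and may expose more than this.
trait NvkvWrapper {
  def getPartitonLength(shuffleId: Int, mapId: Long, reduceId: Int): Long
  def getPartitonOffset(shuffleId: Int, mapId: Long, reduceId: Int): Long
  def read(length: Int, offset: Long): ByteBuffer  // read `length` bytes at `offset` from nvkv
}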
UcxShuffleClient.scala (org.apache.spark.shuffle.compat.spark_3_0)
@@ -4,44 +4,57 @@
*/
package org.apache.spark.shuffle.compat.spark_3_0

import java.nio.ByteBuffer
import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient, DownloadFileManager}
import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBockId, UcxShuffleTransport}
import org.apache.spark.shuffle.utils.UnsafeUtils
import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBlockId, UcxShuffleTransport}
import org.apache.spark.shuffle.utils.{UnsafeUtils, CommonUtils}
import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId}
import org.apache.spark.SparkException

class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Map[Long, Int]) extends BlockStoreClient with Logging {

override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String],
listener: BlockFetchingListener,
downloadFileManager: DownloadFileManager): Unit = {
if (blockIds.length > transport.ucxShuffleConf.maxBlocksPerRequest) {
val (b1, b2) = blockIds.splitAt(blockIds.length / 2)
fetchBlocks(host, port, execId, b1, listener, downloadFileManager)
fetchBlocks(host, port, execId, b2, listener, downloadFileManager)
return
}
//TODO - check if we need to limit max number of request "on the air"
// if (blockIds.length > 32) {
// val (b1, b2) = blockIds.splitAt(blockIds.length / 2)
// fetchBlocks(host, port, execId, b1, listener, downloadFileManager)
// fetchBlocks(host, port, execId, b2, listener, downloadFileManager)
// return
// }

val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
val ucxBlockIds = Array.ofDim[UcxShuffleBlockId](blockIds.length)
val callbacks = Array.ofDim[OperationCallback](blockIds.length)
var send = 0
var receive = 0
for (i <- blockIds.indices) {
val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[SparkShuffleBlockId]
ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, mapId2PartitionId(blockId.mapId), blockId.reduceId)
callbacks(i) = (result: OperationResult) => {
val memBlock = result.getData
val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt)
listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) {
override def release: ManagedBuffer = {
memBlock.close()
this
send = send + 1
SparkBlockId.apply(blockIds(i)) match {
case blockId: SparkShuffleBlockId => {
ucxBlockIds(i) = UcxShuffleBlockId(blockId.shuffleId, mapId2PartitionId(blockId.mapId), blockId.reduceId)
callbacks(i) = (result: OperationResult) => {
val memBlock = result.getData
val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt)
listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) {
override def release: ManagedBuffer = {
memBlock.close()
this
}
})
}
})
val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size)
transport.fetchBlocksByBlockIds(execId.toLong, Array(ucxBlockIds(i)), resultBufferAllocator,
Array(callbacks(i)), () => {receive = receive + 1})
}
case _ =>
throw new SparkException("Unrecognized blockId")
}
}
val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size)
transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks)
transport.progress()

CommonUtils.safePolling(() => {transport.progress()}, () => {send != receive})
}

override def close(): Unit = {
UcxShuffleManager.scala (org.apache.spark.shuffle)
@@ -7,15 +7,16 @@ package org.apache.spark.shuffle
import scala.collection.JavaConverters._

import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.shuffle.compat.spark_3_0.{UcxLocalDiskShuffleExecutorComponents, UcxShuffleBlockResolver, UcxShuffleReader}
import org.apache.spark.shuffle.compat.spark_3_0.{UcxShuffleBlockResolver, UcxShuffleReader}
import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, SortShuffleWriter, UnsafeShuffleWriter}
import org.apache.spark.shuffle.ucx.CommonUcxShuffleManager
import org.apache.spark.shuffle.ucx.{CommonUcxShuffleManager, NvkvShuffleExecutorComponents}
import org.apache.spark.{SparkConf, SparkEnv, TaskContext}

/**
* Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin
* and injects needed logic in override methods.
*/

class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean)
extends CommonUcxShuffleManager(conf, isDriver) {

@@ -25,9 +26,11 @@ class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean)

override def getWriter[K, V](handle: ShuffleHandle, mapId: ReduceId, context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
logInfo("UcxShuffleManager getWriter")
val env = SparkEnv.get
handle match {
case unsafeShuffleHandle: SerializedShuffleHandle[K@unchecked, V@unchecked] =>
logDebug("UcxShuffleManager getWriter unsafeShuffleHandle")
new UnsafeShuffleWriter(
env.blockManager,
context.taskMemoryManager(),
@@ -38,19 +41,22 @@
metrics,
shuffleExecutorComponents)
case other: BaseShuffleHandle[K@unchecked, V@unchecked, _] =>
logDebug("UcxShuffleManager getWriter other")
new SortShuffleWriter(
shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents)
}
}

override def getReader[K, C](handle: ShuffleHandle, startPartition: MapId, endPartition: MapId,
context: TaskContext, metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
logInfo("UcxShuffleManager getReader")
new UcxShuffleReader(handle.asInstanceOf[BaseShuffleHandle[K,_,C]], startPartition, endPartition,
context, ucxTransport, readMetrics = metrics, shouldBatchFetch = false)
}

private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
val executorComponents = new UcxLocalDiskShuffleExecutorComponents(conf)
logInfo("UcxShuffleManager loadShuffleExecutorComponents")
val executorComponents = new NvkvShuffleExecutorComponents(conf, getTransport)
val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX)
.toMap
executorComponents.initializeExecutor(
@@ -59,5 +65,4 @@
extraConfigs.asJava)
executorComponents
}

}
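For context, enabling the plugin is a standard shuffle-manager swap; a minimal sketch (spark.shuffle.manager is the stock Spark key, and any UCX- or nvkv-specific settings this PR depends on are not shown here):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.shuffle.manager", "org.apache.spark.shuffle.UcxShuffleManager")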
UcxShuffleReader.scala (org.apache.spark.shuffle.compat.spark_3_0)
@@ -24,6 +24,7 @@ import org.apache.spark._
import org.apache.spark.internal.{Logging, config}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle.utils.CommonUtils
import org.apache.spark.shuffle.ucx.UcxShuffleTransport
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReadMetricsReporter, ShuffleReader}
import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockBatchId, ShuffleBlockFetcherIterator, ShuffleBlockId}
@@ -115,10 +116,11 @@ private[spark] class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C],
// Do progress if queue is empty before calling next on ShuffleIterator
val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
override def next(): (BlockId, InputStream) = {

val startTime = System.nanoTime()
while (resultQueue.isEmpty) {
transport.progress()
}
CommonUtils.safePolling(() => {transport.progress()},
() => {resultQueue.isEmpty})

val fetchWaitTime = System.nanoTime() - startTime
readMetrics.incFetchWaitTime(TimeUnit.NANOSECONDS.toMillis(fetchWaitTime))
wrappedStreams.next()