5
5
package org .apache .spark .shuffle .ucx
6
6
7
7
import java .io .{Closeable , ObjectOutputStream }
8
+ import java .nio .BufferOverflowException
8
9
import java .util .concurrent .ThreadLocalRandom
9
10
10
11
import scala .collection .mutable
@@ -91,7 +92,7 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
91
92
logInfo(s " Worker from thread ${Thread .currentThread().getName} connecting to $executorId" )
92
93
val endpointParams = new UcpEndpointParams ()
93
94
.setUcpAddress(workerAdresses.get(executorId))
94
- worker.newEndpoint(endpointParams)
95
+ worker.newEndpoint(endpointParams)
95
96
})
96
97
}
97
98
@@ -115,16 +116,16 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
115
116
116
117
ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize,
117
118
UcxRpcMessages .PREFETCH_TAG , new UcxCallback () {
118
- override def onSuccess (request : UcpRequest ): Unit = {
119
- logTrace(s " Sent prefetch ${blockIds.length} blocks to $executorId" )
120
- memoryPool.put(mem)
121
- }
122
- })
119
+ override def onSuccess (request : UcpRequest ): Unit = {
120
+ logTrace(s " Sent prefetch ${blockIds.length} blocks to $executorId" )
121
+ memoryPool.put(mem)
122
+ }
123
+ })
123
124
}
124
125
125
126
private [ucx] def fetchBlocksByBlockIds (executorId : String , blockIds : Seq [BlockId ],
126
- resultBuffer : Seq [MemoryBlock ],
127
- callbacks : Seq [OperationCallback ]): Seq [Request ] = {
127
+ resultBuffer : Seq [MemoryBlock ],
128
+ callbacks : Seq [OperationCallback ]): Seq [Request ] = {
128
129
val ep = getConnection(executorId)
129
130
val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize)
130
131
val buffer = UcxUtils .getByteBufferView(mem.address,
@@ -136,13 +137,20 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
136
137
137
138
Utils .tryWithResource(new ByteBufferBackedOutputStream (buffer)) { bos =>
138
139
val out = new ObjectOutputStream (bos)
139
- out.writeObject(message)
140
- out.flush()
141
- out.close()
140
+ try {
141
+ out.writeObject(message)
142
+ out.flush()
143
+ out.close()
144
+ } catch {
145
+ case _ : BufferOverflowException =>
146
+ throw new UcxException (s " Prefetch blocks message size > " +
147
+ s " ${transport.ucxShuffleConf.RPC_MESSAGE_SIZE .key}: ${transport.ucxShuffleConf.rpcMessageSize}" )
148
+ case ex : Exception => throw new UcxException (ex.getMessage)
149
+ }
142
150
}
143
151
144
152
val tag = ThreadLocalRandom .current().nextLong(Long .MinValue , 0 )
145
- logInfo (s " Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag" )
153
+ logTrace (s " Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag" )
146
154
ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag,
147
155
new UcxCallback () {
148
156
override def onSuccess (request : UcpRequest ): Unit = {
@@ -179,7 +187,7 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
179
187
}
180
188
181
189
private [ucx] def fetchBlockByBlockId (executorId : String , blockId : BlockId ,
182
- resultBuffer : MemoryBlock , cb : OperationCallback ): UcxRequest = {
190
+ resultBuffer : MemoryBlock , cb : OperationCallback ): UcxRequest = {
183
191
val stats = new UcxStats ()
184
192
val ep = getConnection(executorId)
185
193
val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize)
@@ -206,28 +214,28 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
206
214
override def onSuccess (request : UcpRequest ): Unit = {
207
215
memoryPool.put(mem)
208
216
}
209
- })
217
+ })
210
218
211
219
val result = new UcxSuccessOperationResult (stats)
212
220
val request = worker.recvTaggedNonBlocking(resultBuffer.address, resultBuffer.size,
213
221
tag, - 1L , new UcxCallback () {
214
222
215
- override def onError (ucsStatus : Int , errorMsg : String ): Unit = {
216
- logError(s " Failed to receive blockId $blockId on tag: $tag, from executorId: $executorId " +
217
- s " of size: ${resultBuffer.size}: $errorMsg" )
218
- if (cb != null ) {
219
- cb.onComplete(new UcxFailureOperationResult (errorMsg))
223
+ override def onError (ucsStatus : Int , errorMsg : String ): Unit = {
224
+ logError(s " Failed to receive blockId $blockId on tag: $tag, from executorId: $executorId " +
225
+ s " of size: ${resultBuffer.size}: $errorMsg" )
226
+ if (cb != null ) {
227
+ cb.onComplete(new UcxFailureOperationResult (errorMsg))
228
+ }
220
229
}
221
- }
222
230
223
- override def onSuccess (request : UcpRequest ): Unit = {
224
- stats.endTime = System .nanoTime()
225
- stats.receiveSize = request.getRecvSize
226
- if (cb != null ) {
227
- cb.onComplete(result)
231
+ override def onSuccess (request : UcpRequest ): Unit = {
232
+ stats.endTime = System .nanoTime()
233
+ stats.receiveSize = request.getRecvSize
234
+ if (cb != null ) {
235
+ cb.onComplete(result)
236
+ }
228
237
}
229
- }
230
- })
238
+ })
231
239
new UcxRequest (request, stats)
232
240
}
233
241
0 commit comments