Commit e4ea60b

rxin authored and markhamstra committed
Merge pull request #87 from aarondav/shuffle-base
Basic shuffle file consolidation

The Spark shuffle phase can produce a large number of files, as one file is created per mapper per reducer. For large or repeated jobs, this often produces millions of shuffle files, which sees extremely degraded performance from the OS file system. This patch seeks to reduce that burden by combining multiple shuffle files into one.

This PR draws upon the work of @jason-dai in mesos/spark#669. However, it simplifies the design in order to get the majority of the gain with less overall intellectual and code burden. The vast majority of code in this pull request is a refactor that inserts a clean layer of indirection between logical block ids and physical files. This, I feel, provides some design clarity in addition to enabling shuffle file consolidation.

The main goal is to produce one shuffle file per reducer per active mapper thread. This allows us to isolate the mappers (simplifying the failure modes), while still allowing us to reduce the number of files tremendously for large tasks. In order to accomplish this, we simply create a new set of shuffle files for every parallel task, and return the files to a pool, which hands them out to the next task that runs.

I have run some ad hoc query testing on 5 m1.xlarge EC2 nodes with 2g of executor memory and the following microbenchmark:

    scala> val nums = sc.parallelize(1 to 1000, 1000).flatMap(x => (1 to 1e6.toInt))
    scala> def time(x: => Unit) = { val now = System.currentTimeMillis; x; System.currentTimeMillis - now }
    scala> (1 to 8).map(_ => time(nums.map(x => (x % 100000, 2000, x)).reduceByKey(_ + _).count) / 1000.0)

For this particular workload, with 1000 mappers and 2000 reducers, the old method ran in around 15 minutes, while the consolidated shuffle files ran in around 4 minutes. There was a very sharp increase in running time for the non-consolidated version after around 1 million total shuffle files; below this threshold, there wasn't a significant difference between the two. Better performance measurement of this patch is warranted, and I plan on doing so in the near future as part of a general investigation of our shuffle file bottlenecks and performance.

(cherry picked from commit 48952d6)

Signed-off-by: Reynold Xin <[email protected]>
1 parent bee3cea commit e4ea60b
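
To make the consolidation idea described above concrete, here is a minimal, hypothetical sketch of the file-pool scheme the commit message outlines: each concurrently running map task checks out a group of per-reducer files, and the group is returned to a pool so the next task run on that executor appends to the same files. The names (ShuffleFilePool, acquire, release) and the file-naming scheme are illustrative only, not the actual ShuffleBlockManager implementation in this commit.

    import java.io.File
    import scala.collection.mutable

    // Illustrative sketch only: one shuffle file per reducer per concurrently
    // running map task, with file groups recycled across task runs.
    class ShuffleFilePool(shuffleId: Int, numReducers: Int, rootDir: File) {
      private val freeGroups = mutable.Queue[Array[File]]()
      private var nextGroupId = 0

      /** Check out a group of per-reducer files; create a fresh group if none is free. */
      def acquire(): Array[File] = synchronized {
        if (freeGroups.nonEmpty) freeGroups.dequeue()
        else {
          val groupId = nextGroupId
          nextGroupId += 1
          Array.tabulate(numReducers) { reducerId =>
            new File(rootDir, "merged_shuffle_%d_%d_%d".format(shuffleId, groupId, reducerId))
          }
        }
      }

      /** Return the group so the next task run can append to the same files. */
      def release(group: Array[File]): Unit = synchronized { freeGroups.enqueue(group) }
    }

With this scheme the number of shuffle files is bounded by reducers times concurrent tasks rather than reducers times total map tasks, which is the effect measured in the benchmark above.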

File tree

13 files changed (+460 additions, -319 deletions)


core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java

Lines changed: 7 additions & 12 deletions
@@ -25,6 +25,7 @@
 import io.netty.channel.DefaultFileRegion;

 import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;

 class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {

@@ -37,40 +38,34 @@ public FileServerHandler(PathResolver pResolver){
   @Override
   public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
     BlockId blockId = BlockId.apply(blockIdString);
-    String path = pResolver.getAbsolutePath(blockId.name());
-    // if getFilePath returns null, close the channel
-    if (path == null) {
+    FileSegment fileSegment = pResolver.getBlockLocation(blockId);
+    // if getBlockLocation returns null, close the channel
+    if (fileSegment == null) {
       //ctx.close();
       return;
     }
-    File file = new File(path);
+    File file = fileSegment.file();
     if (file.exists()) {
       if (!file.isFile()) {
-        //logger.info("Not a file : " + file.getAbsolutePath());
         ctx.write(new FileHeader(0, blockId).buffer());
         ctx.flush();
         return;
       }
-      long length = file.length();
+      long length = fileSegment.length();
       if (length > Integer.MAX_VALUE || length <= 0) {
-        //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
         ctx.write(new FileHeader(0, blockId).buffer());
         ctx.flush();
         return;
       }
       int len = new Long(length).intValue();
-      //logger.info("Sending block "+blockId+" filelen = "+len);
-      //logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
       ctx.write((new FileHeader(len, blockId)).buffer());
       try {
         ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
-          .getChannel(), 0, file.length()));
+          .getChannel(), fileSegment.offset(), fileSegment.length()));
       } catch (Exception e) {
-        //logger.warning("Exception when sending file : " + file.getAbsolutePath());
         e.printStackTrace();
       }
     } else {
-      //logger.warning("File not found: " + file.getAbsolutePath());
       ctx.write(new FileHeader(0, blockId).buffer());
     }
     ctx.flush();

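The handler above no longer needs an absolute path; it only needs a file handle plus an offset and a length. For readers following the diff, here is a sketch of the FileSegment abstraction as inferred from its usage in this handler (the real class lives in org.apache.spark.storage and may differ in detail):

    import java.io.File

    // Inferred from the calls fileSegment.file(), .offset() and .length() above:
    // a FileSegment names a contiguous region of a physical file, so several
    // logical blocks can share one consolidated file.
    class FileSegment(val file: File, val offset: Long, val length: Long) {
      override def toString =
        "FileSegment(%s, offset=%d, length=%d)".format(file.getName, offset, length)
    }
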
core/src/main/java/org/apache/spark/network/netty/PathResolver.java

Lines changed: 4 additions & 7 deletions
@@ -17,13 +17,10 @@

 package org.apache.spark.network.netty;

+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;

 public interface PathResolver {
-  /**
-   * Get the absolute path of the file
-   *
-   * @param fileId
-   * @return the absolute path of file
-   */
-  public String getAbsolutePath(String fileId);
+  /** Get the file segment in which the given block resides. */
+  public FileSegment getBlockLocation(BlockId blockId);
 }

core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala

Lines changed: 3 additions & 4 deletions
@@ -21,7 +21,7 @@ import java.io.File

 import org.apache.spark.Logging
 import org.apache.spark.util.Utils
-import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.{BlockId, FileSegment}


 private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {

@@ -54,8 +54,7 @@ private[spark] object ShuffleSender {
     val localDirs = args.drop(2).map(new File(_))

     val pResovler = new PathResolver {
-      override def getAbsolutePath(blockIdString: String): String = {
-        val blockId = BlockId(blockIdString)
+      override def getBlockLocation(blockId: BlockId): FileSegment = {
         if (!blockId.isShuffle) {
           throw new Exception("Block " + blockId + " is not a shuffle block")
         }

@@ -65,7 +64,7 @@
         val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
         val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
         val file = new File(subDir, blockId.name)
-        return file.getAbsolutePath
+        return new FileSegment(file, 0, file.length())
       }
     }
     val sender = new ShuffleSender(port, pResovler)

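For context, the standalone sender resolves a block to a physical file by hashing the block name into one of the configured local directories and then into a sub-directory. A hedged sketch of that mapping follows; the hash and dirId lines are assumed, since only the tail of the method appears in the hunk above, and the helper name resolve is hypothetical.

    import java.io.File
    import org.apache.spark.storage.{BlockId, FileSegment}

    object PathResolutionSketch {
      // Assumed hashing scheme; only subDirId, subDir, file and the FileSegment
      // construction are visible in the diff above.
      def resolve(blockId: BlockId, localDirs: Array[File], subDirsPerLocalDir: Int): FileSegment = {
        val hash = math.abs(blockId.name.hashCode)
        val dirId = hash % localDirs.length
        val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
        val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
        val file = new File(subDir, blockId.name)
        new FileSegment(file, 0, file.length())
      }
    }
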
core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala

Lines changed: 2 additions & 2 deletions
@@ -167,8 +167,7 @@ private[spark] class ShuffleMapTask(
       var totalTime = 0L
       val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
         writer.commit()
-        writer.close()
-        val size = writer.size()
+        val size = writer.fileSegment().length
         totalBytes += size
         totalTime += writer.timeWriting()
         MapOutputTracker.compressSize(size)

@@ -191,6 +190,7 @@
     } finally {
       // Release the writers back to the shuffle block manager.
       if (shuffle != null && buckets != null) {
+        buckets.writers.foreach(_.close())
         shuffle.releaseWriters(buckets)
       }
       // Execute the callbacks on task completion.

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ import org.apache.spark.util.ByteBufferInputStream
  */
 private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {

-  def run(attemptId: Long): T = {
+  final def run(attemptId: Long): T = {
     context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
     if (_killed) {
       kill()

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 25 additions & 9 deletions
@@ -28,7 +28,7 @@ import akka.dispatch.{Await, Future}
 import akka.util.Duration
 import akka.util.duration._

-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}

 import org.apache.spark.{Logging, SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec

@@ -102,18 +102,19 @@ private[spark] class BlockManager(
   }

   val shuffleBlockManager = new ShuffleBlockManager(this)
+  val diskBlockManager = new DiskBlockManager(
+    System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))

   private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]

   private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
-  private[storage] val diskStore: DiskStore =
-    new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+  private[storage] val diskStore = new DiskStore(this, diskBlockManager)

   // If we use Netty for shuffle, start a new Netty-based shuffle sender service.
   private val nettyPort: Int = {
     val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
     val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
-    if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
+    if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
   }

   val connectionManager = new ConnectionManager(0)

@@ -567,16 +568,20 @@

   /**
    * A short circuited method to get a block writer that can write data directly to disk.
+   * The Block will be appended to the File specified by filename.
    * This is currently used for writing shuffle files out. Callers should handle error
    * cases.
    */
-  def getDiskBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
+  def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int)
     : BlockObjectWriter = {
-    val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
+    val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
+    val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true)
+    val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
     writer.registerCloseEventHandler(() => {
+      diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment())
       val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
       blockInfo.put(blockId, myInfo)
-      myInfo.markReady(writer.size())
+      myInfo.markReady(writer.fileSegment().length)
     })
     writer
   }

@@ -988,13 +993,24 @@
     if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
   }

+  /** Serializes into a stream. */
+  def dataSerializeStream(
+      blockId: BlockId,
+      outputStream: OutputStream,
+      values: Iterator[Any],
+      serializer: Serializer = defaultSerializer) {
+    val byteStream = new FastBufferedOutputStream(outputStream)
+    val ser = serializer.newInstance()
+    ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+  }
+
+  /** Serializes into a byte buffer. */
   def dataSerialize(
       blockId: BlockId,
       values: Iterator[Any],
       serializer: Serializer = defaultSerializer): ByteBuffer = {
     val byteStream = new FastByteArrayOutputStream(4096)
-    val ser = serializer.newInstance()
-    ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+    dataSerializeStream(blockId, byteStream, values, serializer)
     byteStream.trim()
     ByteBuffer.wrap(byteStream.array)
   }

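A hedged usage sketch of the new streaming entry point: dataSerializeStream writes an iterator straight into any OutputStream (buffering and optional compression are applied internally), while dataSerialize keeps the old ByteBuffer behaviour by layering on top of it. The helper name writeBlockToFile and its arguments are placeholders, not part of this commit.

    import java.io.{File, FileOutputStream}
    import org.apache.spark.storage.{BlockId, BlockManager}

    object SerializeSketch {
      // Placeholder helper: stream the values of a block directly to a file,
      // without first materializing a byte buffer in memory.
      def writeBlockToFile(blockManager: BlockManager, blockId: BlockId,
                           values: Iterator[Any], target: File) {
        val out = new FileOutputStream(target)
        try {
          blockManager.dataSerializeStream(blockId, out, values)
        } finally {
          out.close()
        }
      }
    }
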
core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala

Lines changed: 126 additions & 2 deletions
@@ -17,6 +17,13 @@

 package org.apache.spark.storage

+import java.io.{FileOutputStream, File, OutputStream}
+import java.nio.channels.FileChannel
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import org.apache.spark.Logging
+import org.apache.spark.serializer.{SerializationStream, Serializer}

 /**
  * An interface for writing JVM objects to some underlying storage. This interface allows

@@ -59,12 +66,129 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
   def write(value: Any)

   /**
-   * Size of the valid writes, in bytes.
+   * Returns the file segment of committed data that this Writer has written.
    */
-  def size(): Long
+  def fileSegment(): FileSegment

   /**
    * Cumulative time spent performing blocking writes, in ns.
    */
   def timeWriting(): Long
 }
+
+/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
+class DiskBlockObjectWriter(
+    blockId: BlockId,
+    file: File,
+    serializer: Serializer,
+    bufferSize: Int,
+    compressStream: OutputStream => OutputStream)
+  extends BlockObjectWriter(blockId)
+  with Logging
+{
+
+  /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
+  private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
+    def timeWriting = _timeWriting
+    private var _timeWriting = 0L
+
+    private def callWithTiming(f: => Unit) = {
+      val start = System.nanoTime()
+      f
+      _timeWriting += (System.nanoTime() - start)
+    }
+
+    def write(i: Int): Unit = callWithTiming(out.write(i))
+    override def write(b: Array[Byte]) = callWithTiming(out.write(b))
+    override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
+  }
+
+  private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
+
+  /** The file channel, used for repositioning / truncating the file. */
+  private var channel: FileChannel = null
+  private var bs: OutputStream = null
+  private var fos: FileOutputStream = null
+  private var ts: TimeTrackingOutputStream = null
+  private var objOut: SerializationStream = null
+  private var initialPosition = 0L
+  private var lastValidPosition = 0L
+  private var initialized = false
+  private var _timeWriting = 0L
+
+  override def open(): BlockObjectWriter = {
+    fos = new FileOutputStream(file, true)
+    ts = new TimeTrackingOutputStream(fos)
+    channel = fos.getChannel()
+    initialPosition = channel.position
+    lastValidPosition = initialPosition
+    bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
+    objOut = serializer.newInstance().serializeStream(bs)
+    initialized = true
+    this
+  }
+
+  override def close() {
+    if (initialized) {
+      if (syncWrites) {
+        // Force outstanding writes to disk and track how long it takes
+        objOut.flush()
+        val start = System.nanoTime()
+        fos.getFD.sync()
+        _timeWriting += System.nanoTime() - start
+      }
+      objOut.close()
+
+      _timeWriting += ts.timeWriting
+
+      channel = null
+      bs = null
+      fos = null
+      ts = null
+      objOut = null
+    }
+    // Invoke the close callback handler.
+    super.close()
+  }
+
+  override def isOpen: Boolean = objOut != null
+
+  override def commit(): Long = {
+    if (initialized) {
+      // NOTE: Flush the serializer first and then the compressed/buffered output stream
+      objOut.flush()
+      bs.flush()
+      val prevPos = lastValidPosition
+      lastValidPosition = channel.position()
+      lastValidPosition - prevPos
+    } else {
+      // lastValidPosition is zero if stream is uninitialized
+      lastValidPosition
+    }
+  }
+
+  override def revertPartialWrites() {
+    if (initialized) {
+      // Discard current writes. We do this by flushing the outstanding writes and
+      // truncate the file to the last valid position.
+      objOut.flush()
+      bs.flush()
+      channel.truncate(lastValidPosition)
+    }
+  }
+
+  override def write(value: Any) {
+    if (!initialized) {
+      open()
+    }
+    objOut.writeObject(value)
+  }
+
+  override def fileSegment(): FileSegment = {
+    val bytesWritten = lastValidPosition - initialPosition
+    new FileSegment(file, initialPosition, bytesWritten)
+  }
+
+  // Only valid if called after close()
+  override def timeWriting() = _timeWriting
+}

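A hedged usage sketch of the writer lifecycle introduced above: values are written and periodically committed; on failure the uncommitted tail can be reverted; and once closed, the committed region is exposed as a FileSegment. A real caller obtains the writer from BlockManager.getDiskWriter, which also registers a close handler; the helper below constructs one directly for illustration, and its arguments are placeholders.

    import java.io.File
    import org.apache.spark.serializer.Serializer
    import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, FileSegment}

    object WriterLifecycleSketch {
      // Placeholder helper showing the write -> commit -> close sequence.
      def writePartition(blockId: BlockId, file: File, serializer: Serializer,
                         records: Iterator[Any]): FileSegment = {
        val writer = new DiskBlockObjectWriter(blockId, file, serializer,
          bufferSize = 100 * 1024, compressStream = (out: java.io.OutputStream) => out)
        writer.registerCloseEventHandler(() => ())  // BlockManager registers a real handler here
        try {
          records.foreach(r => writer.write(r))  // open() happens lazily on the first write
          writer.commit()                        // advances lastValidPosition to the channel position
        } catch {
          case e: Exception =>
            writer.revertPartialWrites()         // truncate back to the last committed position
            throw e
        } finally {
          writer.close()
        }
        writer.fileSegment()                     // (file, initialPosition, committed bytes)
      }
    }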