24 changes: 22 additions & 2 deletions core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -60,6 +60,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with

val nextRunId = new AtomicInteger(0)

val runIdToStageIds = new HashMap[Int, HashSet[Int]]

val nextStageId = new AtomicInteger(0)

val idToStage = new TimeStampedHashMap[Int, Stage]
@@ -143,6 +145,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val id = nextStageId.getAndIncrement()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
idToStage(id) = stage
val stageIdSet = runIdToStageIds.getOrElseUpdate(priority, new HashSet)
stageIdSet += id
stage
}

@@ -285,6 +289,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
case StopDAGScheduler =>
// Cancel any active jobs
for (job <- activeJobs) {
removeStages(job)
val error = new SparkException("Job cancelled because SparkContext was shut down")
job.listener.jobFailed(error)
}
@@ -420,13 +425,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (!job.finished(rt.outputId)) {
job.finished(rt.outputId) = true
job.numFinished += 1
job.listener.taskSucceeded(rt.outputId, event.result)
// If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) {
activeJobs -= job
resultStageToJob -= stage
running -= stage
removeStages(job)
}
job.listener.taskSucceeded(rt.outputId, event.result)
}
case None =>
logInfo("Ignoring result from " + rt + " because its job has finished")
@@ -558,9 +564,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
job.listener.jobFailed(new SparkException("Job failed: " + reason))
activeJobs -= job
resultStageToJob -= resultStage
removeStages(job)
job.listener.jobFailed(new SparkException("Job failed: " + reason))
}
if (dependentStages.isEmpty) {
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
@@ -637,6 +644,19 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size)
}

// Remove all bookkeeping for the stages belonging to this job once it has
// finished, failed, or been cancelled, so the scheduler does not leak stage state.
def removeStages(job: ActiveJob) {
runIdToStageIds(job.runId).foreach(stageId => {
idToStage.get(stageId).foreach(stage => {
pendingTasks -= stage
waiting -= stage
running -= stage
failed -= stage
})
idToStage -= stageId
})
runIdToStageIds -= job.runId
}

def stop() {
eventQueue.put(StopDAGScheduler)
metadataCleaner.cancel()
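In brief, the change above keys every stage created for a job by that job's runId (runIdToStageIds), so that when the job completes, fails, or the scheduler is shut down, removeStages can drop all per-stage state in one pass. A minimal standalone Scala sketch of that bookkeeping pattern (illustrative names, not the actual Spark classes; the real removeStages also clears the pendingTasks, waiting, running, and failed sets):

import scala.collection.mutable.{HashMap, HashSet}

object RunIdBookkeepingSketch {
  // runId -> IDs of every stage created for that run
  val runIdToStageIds = new HashMap[Int, HashSet[Int]]
  // stands in for the scheduler's id -> Stage map
  val idToStage = new HashMap[Int, String]

  // Record a newly created stage under its run.
  def registerStage(runId: Int, stageId: Int, stage: String) {
    idToStage(stageId) = stage
    runIdToStageIds.getOrElseUpdate(runId, new HashSet) += stageId
  }

  // Drop everything recorded for the run once it is done.
  def removeStages(runId: Int) {
    for (stageId <- runIdToStageIds.getOrElse(runId, new HashSet[Int])) {
      idToStage -= stageId
    }
    runIdToStageIds -= runId
  }
}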
88 changes: 88 additions & 0 deletions core/src/test/scala/spark/DAGSchedulerSuite.scala
@@ -0,0 +1,88 @@
package spark

import org.scalatest.FunSuite
import scheduler.{DAGScheduler, TaskSchedulerListener, TaskSet, TaskScheduler}
import collection.mutable

class TaskSchedulerMock(f: (Int) => TaskEndReason) extends TaskScheduler {
// Listener object to pass upcalls into
var listener: TaskSchedulerListener = null
var taskCount = 0

override def start(): Unit = {}

// Disconnect from the cluster.
override def stop(): Unit = {}

// Submit a sequence of tasks to run. The mock "runs" each task immediately,
// reporting it back to the listener with the TaskEndReason chosen by f for the
// current running task count.
override def submitTasks(taskSet: TaskSet): Unit = {
taskSet.tasks.foreach(task => {
val m = new mutable.HashMap[Long, Any]()
m.put(task.stageId, 1)
taskCount += 1
listener.taskEnded(task, f(taskCount), 1, m)
})
}

// Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}

// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
override def defaultParallelism(): Int = {
2
}
}
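// Note: f receives a running 1-based count of the tasks the mock has reported so
// far and returns the TaskEndReason handed back to the DAGScheduler, so a test
// could make only a particular task fail. The tests below always return Success.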

class DAGSchedulerSuite extends FunSuite {
def assertDagSchedulerEmpty(dagScheduler: DAGScheduler) = {
assert(dagScheduler.pendingTasks.isEmpty)
assert(dagScheduler.activeJobs.isEmpty)
assert(dagScheduler.failed.isEmpty)
assert(dagScheduler.runIdToStageIds.isEmpty)
assert(dagScheduler.idToStage.isEmpty)
assert(dagScheduler.resultStageToJob.isEmpty)
assert(dagScheduler.running.isEmpty)
assert(dagScheduler.shuffleToMapStage.isEmpty)
assert(dagScheduler.waiting.isEmpty)
}

test("oneGoodJob") {
val sc = new SparkContext("local", "test")
val dagScheduler = new DAGScheduler(new TaskSchedulerMock(count => Success))
try {
val rdd = new ParallelCollection(sc, 1.to(100).toSeq, 5, Map.empty)
val func = (tc: TaskContext, iter: Iterator[Int]) => 1
val callSite = Utils.getSparkCallSite

val result = dagScheduler.runJob(rdd, func, 0 until rdd.splits.size, callSite, false)
assertDagSchedulerEmpty(dagScheduler)
} finally {
dagScheduler.stop()
sc.stop()
// pause to let dagScheduler stop (separate thread)
Thread.sleep(10)
}
}

test("manyGoodJobs") {
val sc = new SparkContext("local", "test")
val dagScheduler = new DAGScheduler(new TaskSchedulerMock(count => Success))
try {
val rdd = new ParallelCollection(sc, 1.to(100).toSeq, 5, Map.empty)
val func = (tc: TaskContext, iter: Iterator[Int]) => 1
val callSite = Utils.getSparkCallSite

1.to(100).foreach( v => {
val result = dagScheduler.runJob(rdd, func, 0 until rdd.splits.size, callSite, false)
})
assertDagSchedulerEmpty(dagScheduler)
} finally {
dagScheduler.stop()
sc.stop()
// pause to let dagScheduler stop (separate thread)
Thread.sleep(10)
}
}
}