From 5ea5a34af4a3e00f823938c48ad230aacc41a49b Mon Sep 17 00:00:00 2001 From: wangzhigang Date: Sun, 20 Jul 2025 14:17:41 +0800 Subject: [PATCH] feat: Add shutdown watchdog configuration and tests - Introduced a new configuration property `kyuubi.session.engine.shutdown.watchdog.timeout` to manage the maximum wait time for engine shutdown. - Updated the `SparkSQLEngine` to utilize the new watchdog feature. - Minor adjustments to existing configurations and documentation to reflect the new feature. --- docs/configuration/settings.md | 1 + .../kyuubi/engine/spark/SparkSQLEngine.scala | 83 ++++- .../org/apache/kyuubi/config/KyuubiConf.scala | 11 + .../apache/kyuubi/util/ThreadDumpUtils.scala | 331 ++++++++++++++++++ .../kyuubi/util/ThreadDumpUtilsSuite.scala | 209 +++++++++++ 5 files changed, 630 insertions(+), 5 deletions(-) create mode 100644 kyuubi-common/src/main/scala/org/apache/kyuubi/util/ThreadDumpUtils.scala create mode 100644 kyuubi-common/src/test/scala/org/apache/kyuubi/util/ThreadDumpUtilsSuite.scala diff --git a/docs/configuration/settings.md b/docs/configuration/settings.md index 8bbbcfed626..373233af272 100644 --- a/docs/configuration/settings.md +++ b/docs/configuration/settings.md @@ -494,6 +494,7 @@ You can configure the Kyuubi properties in `$KYUUBI_HOME/conf/kyuubi-defaults.co | kyuubi.session.engine.open.onFailure | RETRY | The behavior when opening engine failed: | string | 1.8.1 | | kyuubi.session.engine.open.retry.wait | PT10S | How long to wait before retrying to open the engine after failure. | duration | 1.7.0 | | kyuubi.session.engine.share.level | USER | (deprecated) - Using kyuubi.engine.share.level instead | string | 1.0.0 | +| kyuubi.session.engine.shutdown.watchdog.timeout | PT1M | The maximum time to wait for the engine to shutdown gracefully before forcing termination. When an engine shutdown is initiated, this watchdog timer starts counting down. If the engine doesn't complete shutdown within this timeout period, it will be forcefully terminated to prevent hanging. Set to 0 or a negative value to disable the forced shutdown mechanism. | duration | 1.11.0 | | kyuubi.session.engine.spark.initialize.sql || The initialize sql for Spark session. It fallback to `kyuubi.engine.session.initialize.sql` | seq | 1.8.1 | | kyuubi.session.engine.spark.main.resource | <undefined> | The package used to create Spark SQL engine remote application. If it is undefined, Kyuubi will use the default | string | 1.0.0 | | kyuubi.session.engine.spark.max.initial.wait | PT1M | Max wait time for the initial connection to Spark engine. The engine will self-terminate no new incoming connection is established within this time. This setting only applies at the CONNECTION share level. 0 or negative means not to self-terminate. | duration | 1.8.0 | diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala index 02d2a7afb59..98bda27fb68 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala @@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark import java.time.Instant import java.util.{Locale, UUID} import java.util.concurrent.{CountDownLatch, ScheduledExecutorService, ThreadPoolExecutor, TimeUnit} -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import scala.concurrent.duration.Duration import scala.util.control.NonFatal @@ -45,17 +45,17 @@ import org.apache.kyuubi.ha.HighAvailabilityConf._ import org.apache.kyuubi.ha.client.RetryPolicies import org.apache.kyuubi.service.Serverable import org.apache.kyuubi.session.SessionHandle -import org.apache.kyuubi.util.{JavaUtils, SignalRegister, ThreadUtils} +import org.apache.kyuubi.util.{JavaUtils, SignalRegister, ThreadDumpUtils, ThreadUtils} import org.apache.kyuubi.util.ThreadUtils.scheduleTolerableRunnableWithFixedDelay - case class SparkSQLEngine(spark: SparkSession) extends Serverable("SparkSQLEngine") { - override val backendService = new SparkSQLBackendService(spark) override val frontendServices = Seq(new SparkTBinaryFrontendService(this)) private val shutdown = new AtomicBoolean(false) private val gracefulStopDeregistered = new AtomicBoolean(false) - + @volatile private var watchdogThreadRef: AtomicReference[Thread] = new AtomicReference[Thread]() + private val EMERGENCY_SHUTDOWN_EXIT_CODE = 99 + private val WATCHDOG_ERROR_EXIT_CODE = 98 @volatile private var lifetimeTerminatingChecker: Option[ScheduledExecutorService] = None @volatile private var stopEngineExec: Option[ThreadPoolExecutor] = None private lazy val engineSavePath = @@ -98,6 +98,7 @@ case class SparkSQLEngine(spark: SparkSession) extends Serverable("SparkSQLEngin } override def stop(): Unit = if (shutdown.compareAndSet(false, true)) { + startShutdownWatchdog() super.stop() lifetimeTerminatingChecker.foreach(checker => { val shutdownTimeout = conf.get(ENGINE_EXEC_POOL_SHUTDOWN_TIMEOUT) @@ -121,6 +122,7 @@ case class SparkSQLEngine(spark: SparkSession) extends Serverable("SparkSQLEngin } def gracefulStop(): Unit = if (gracefulStopDeregistered.compareAndSet(false, true)) { + startShutdownWatchdog() val stopTask: Runnable = () => { if (!shutdown.get) { info(s"Spark engine is de-registering from engine discovery space.") @@ -212,6 +214,76 @@ case class SparkSQLEngine(spark: SparkSession) extends Serverable("SparkSQLEngin TimeUnit.MILLISECONDS) } } + + /** + * Starts a shutdown watchdog thread as a failsafe mechanism. + * + * This thread monitors the shutdown process and will forcefully terminate + * the JVM if graceful shutdown takes too long. This prevents zombie processes + * caused by non-daemon threads that refuse to terminate. + */ + private def startShutdownWatchdog(): Unit = { + if (org.apache.kyuubi.Utils.isTesting) { + info("Shutdown Watchdog is disabled in test mode.") + return + } + + val shutdownWatchdogTimeout = conf.get(ENGINE_SHUTDOWN_WATCHDOG_TIMEOUT) + if (shutdownWatchdogTimeout <= 0) { + info("Shutdown Watchdog is disabled (timeout <= 0).") + return + } + + // Prevent multiple watchdog threads + watchdogThreadRef.synchronized { + if (watchdogThreadRef.get() != null) { + warn("Shutdown Watchdog is already running, ignoring duplicate start request") + return + } + } + + info(s"Shutdown Watchdog activated. Engine will be forcefully terminated if graceful " + + s"shutdown exceeds ${shutdownWatchdogTimeout} ms.") + + val watchdogThread = new Thread("shutdown-watchdog") { + override def run(): Unit = { + debug("Shutdown Watchdog thread started, monitoring graceful shutdown process") + try { + TimeUnit.MILLISECONDS.sleep(shutdownWatchdogTimeout) + + error(s"EMERGENCY SHUTDOWN TRIGGERED") + error(s"Graceful shutdown exceeded ${shutdownWatchdogTimeout} ms timeout") + error(s"Non-daemon threads are preventing JVM exit") + error(s"Initiating forced termination...") + + // Thread dump for diagnostics + error(s"=== THREAD DUMP FOR DIAGNOSTIC ===") + ThreadDumpUtils.dumpToLogger(logger) + error(s"=== END OF THREAD DUMP ===") + + error(s"Forcefully terminating JVM now...") + System.exit(EMERGENCY_SHUTDOWN_EXIT_CODE) + + } catch { + case _: InterruptedException => + warn("Shutdown Watchdog: Normal shutdown detected, watchdog exiting.") + case t: Throwable => + error( + s"Shutdown Watchdog error: ${t.getClass.getSimpleName}: ${t.getMessage}") + t.printStackTrace(System.err) + error("Proceeding with emergency termination...") + System.exit(WATCHDOG_ERROR_EXIT_CODE) // Watchdog error + } + } + } + + watchdogThread.setDaemon(true) + watchdogThread.start() + watchdogThreadRef.set(watchdogThread) + + debug(s"Shutdown Watchdog thread started: ${watchdogThread.getName}") + } + } object SparkSQLEngine extends Logging { @@ -407,6 +479,7 @@ object SparkSQLEngine extends Logging { startEngine(spark) // blocking main thread countDownLatch.await() + currentEngine.foreach(_.startShutdownWatchdog()) } catch { case e: KyuubiException => currentEngine match { diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala index b1589811c92..c9c710cc540 100644 --- a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala @@ -1784,6 +1784,17 @@ object KyuubiConf { .timeConf .createWithDefault(Duration.ofMinutes(30L).toMillis) + val ENGINE_SHUTDOWN_WATCHDOG_TIMEOUT: ConfigEntry[Long] = + buildConf("kyuubi.session.engine.shutdown.watchdog.timeout") + .doc("The maximum time to wait for the engine to shutdown gracefully before " + + "forcing termination. When an engine shutdown is initiated, this watchdog " + + "timer starts counting down. If the engine doesn't complete shutdown within " + + "this timeout period, it will be forcefully terminated to prevent hanging. " + + "Set to 0 or a negative value to disable the forced shutdown mechanism.") + .version("1.11.0") + .timeConf + .createWithDefault(Duration.ofMinutes(1L).toMillis) + val SESSION_CONF_IGNORE_LIST: ConfigEntry[Set[String]] = buildConf("kyuubi.session.conf.ignore.list") .doc("A comma-separated list of ignored keys. If the client connection contains any of" + diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/util/ThreadDumpUtils.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/util/ThreadDumpUtils.scala new file mode 100644 index 00000000000..beb033a720e --- /dev/null +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/util/ThreadDumpUtils.scala @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.util + +import java.io.{PrintStream, PrintWriter, StringWriter} +import java.lang.management.{ManagementFactory, ThreadInfo} +import java.time.LocalDateTime +import java.time.format.DateTimeFormatter + +import scala.collection.JavaConverters._ +import scala.util.Try + +import org.slf4j.Logger + +/** + * Utility for generating comprehensive JVM thread dumps with various configuration options. + * + * This utility provides detailed thread information including: + * - Thread states and statistics + * - Deadlock detection + * - Lock information + * - Stack traces with configurable depth + * - Separate handling of daemon vs non-daemon threads + */ +object ThreadDumpUtils { + + /** + * Configuration for thread dump generation + * + * @param stackDepth Maximum number of stack frames to show per thread (0 = unlimited) + * @param showDaemonThreads Whether to include daemon threads in the output + * @param includeLocksInfo Whether to include lock and monitor information + * @param includeSynchronizers Whether to include ownable synchronizers information + * @param sortThreadsBy How to sort threads in the output + */ + case class DumpConfig( + stackDepth: Int = 10, + showDaemonThreads: Boolean = true, + includeLocksInfo: Boolean = true, + includeSynchronizers: Boolean = true, + sortThreadsBy: ThreadSortBy = ThreadSortBy.Name) + + /** + * Thread sorting options + */ + sealed trait ThreadSortBy + object ThreadSortBy { + case object Id extends ThreadSortBy + case object Name extends ThreadSortBy + case object State extends ThreadSortBy + } + + private val DefaultConfig: DumpConfig = DumpConfig() + + /** + * Enhanced thread information combining ThreadInfo with Thread object + * to provide more comprehensive thread details + */ + private case class ExtendedThreadInfo( + threadInfo: ThreadInfo, + thread: Option[Thread] // Maybe None if thread has died between collection points + ) { + def isDaemon: Boolean = thread.exists(_.isDaemon) + def getName: String = threadInfo.getThreadName + def getId: Long = threadInfo.getThreadId + def getState: Thread.State = threadInfo.getThreadState + } + + /** + * Collects all thread information by merging data from ThreadMXBean and active Thread objects + */ + private def getAllExtendedThreadInfo( + includeLocksInfo: Boolean, + includeSynchronizers: Boolean): Array[ExtendedThreadInfo] = { + val threadBean = ManagementFactory.getThreadMXBean + val allThreadInfos = threadBean.dumpAllThreads(includeLocksInfo, includeSynchronizers) + + // Create a map of active threads by ID for additional thread properties + val activeThreadsMap: Map[Long, Thread] = Try { + Thread.getAllStackTraces.keySet().asScala.map(t => t.getId -> t).toMap + }.getOrElse(Map.empty) + + // Merge ThreadInfo with Thread object to get complete information + allThreadInfos.map { threadInfo => + val thread = activeThreadsMap.get(threadInfo.getThreadId) + ExtendedThreadInfo(threadInfo, thread) + } + } + + /** + * Dumps thread information to console (System.err) + */ + def dumpToConsole(config: DumpConfig = DefaultConfig): Unit = { + dumpToStream(System.err, config) + } + + /** + * Generates thread dump as a string + */ + def dumpToString(config: DumpConfig = DefaultConfig): String = { + val stringWriter = new StringWriter(8192) + val printWriter = new PrintWriter(stringWriter) + try { + dumpToWriter(printWriter, config) + stringWriter.toString + } finally { + printWriter.close() + } + } + + /** + * Dumps thread information to a PrintStream + */ + private def dumpToStream(out: PrintStream, config: DumpConfig): Unit = { + val printWriter = new PrintWriter(out, true) + dumpToWriter(printWriter, config) + } + + /** + * Dumps thread information to a SLF4J Logger + */ + def dumpToLogger(logger: Logger, config: DumpConfig = DefaultConfig): Unit = { + try { + val dump = dumpToString(config) + logger.error("\n" + dump) + } catch { + case t: Throwable => + t.printStackTrace(System.err) + dumpToConsole(config) + } + } + + /** + * Core method that generates the formatted thread dump output + */ + private def dumpToWriter(writer: PrintWriter, config: DumpConfig): Unit = { + // scalastyle:off println + def writeLine(line: String = ""): Unit = writer.println(line) + // scalastyle:on println + + try { + val allExtendedThreads = + getAllExtendedThreadInfo(config.includeLocksInfo, config.includeSynchronizers) + val timestamp = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")) + + // Header section + writeLine("================== Thread Dump Start ==================") + writeLine(s"Timestamp: $timestamp") + writeLine(s"Total threads: ${allExtendedThreads.length}") + writeLine() + + // Critical: Check for deadlocks first + checkDeadlocks(writeLine) + + // Overview statistics + showThreadStatistics(allExtendedThreads, writeLine) + + // Separate threads by daemon status - non-daemon threads are more critical + val (nonDaemonThreads, daemonThreads) = allExtendedThreads.partition(!_.isDaemon) + + // Non-daemon threads section (these prevent JVM shutdown) + writeLine() + writeLine("==================== Non-Daemon Threads ====================") + writeLine("(These threads prevent JVM from exiting)") + showThreadDetails(nonDaemonThreads, config, writeLine) + + // Daemon threads section (optional) + if (config.showDaemonThreads) { + writeLine() + writeLine("====================== Daemon Threads ======================") + showThreadDetails(daemonThreads, config, writeLine) + } + + // Summary table + writeLine() + writeLine("======================== Summary ========================") + showThreadSummary(allExtendedThreads, writeLine) + writeLine("================== Thread Dump End ==================") + + } catch { + case t: Throwable => + writeLine(s"*** ERROR: Failed to generate thread dump: ${t.getMessage} ***") + t.printStackTrace(writer) + + // Emergency fallback - provide basic thread information + performEmergencyDump(writeLine) + } + } + + /** + * Emergency fallback method when main thread dump fails + */ + private def performEmergencyDump(writeLine: String => Unit): Unit = { + Try { + val basicThreads = Thread.getAllStackTraces.keySet().toArray(new Array[Thread](0)) + writeLine(s"*** Emergency fallback: Found ${basicThreads.length} threads ***") + basicThreads.foreach { thread => + val threadType = if (thread.isDaemon) "daemon" else "user" + writeLine(s"Thread: ${thread.getName} [${thread.getState}] $threadType") + } + }.recover { case ex => + writeLine(s"*** Even emergency fallback failed: ${ex.getMessage} ***") + } + } + + /** + * Detects and reports deadlocks + */ + private def checkDeadlocks(writeLine: String => Unit): Unit = { + Try { + val threadBean = ManagementFactory.getThreadMXBean + val deadlockedThreads = threadBean.findDeadlockedThreads() + if (deadlockedThreads != null && deadlockedThreads.nonEmpty) { + writeLine("*** DEADLOCK DETECTED ***") + writeLine(s"Deadlocked thread IDs: ${deadlockedThreads.mkString(", ")}") + writeLine("") + } + }.recover { case ex => + writeLine(s"Warning: Could not check for deadlocks: ${ex.getMessage}") + } + } + + /** + * Shows high-level thread statistics and state distribution + */ + private def showThreadStatistics( + allThreads: Array[ExtendedThreadInfo], + writeLine: String => Unit): Unit = { + val threadsByState = allThreads.groupBy(_.getState.toString) + val (nonDaemonThreads, daemonThreads) = allThreads.partition(!_.isDaemon) + + writeLine("Thread Statistics:") + writeLine(f" Non-daemon threads: ${nonDaemonThreads.length}%3d") + writeLine(f" Daemon threads: ${daemonThreads.length}%3d") + writeLine("") + writeLine("Threads by state:") + + // Sort states alphabetically for consistent output + threadsByState.toSeq.sortBy(_._1).foreach { case (state, threads) => + writeLine(f" $state%-15s: ${threads.length}%3d") + } + } + + /** + * Shows detailed information for each thread including stack traces + */ + private def showThreadDetails( + threads: Array[ExtendedThreadInfo], + config: DumpConfig, + writeLine: String => Unit): Unit = { + if (threads.isEmpty) { + writeLine(" (No threads in this category)") + return + } + + // Sort threads according to configuration + val sortedThreads = config.sortThreadsBy match { + case ThreadSortBy.Id => threads.sortBy(_.getId) + case ThreadSortBy.Name => threads.sortBy(_.getName) + case ThreadSortBy.State => threads.sortBy(t => (t.getState.toString, t.getName)) + } + + sortedThreads.foreach { extThreadInfo => + val threadInfo = extThreadInfo.threadInfo + + // Thread header + val daemonLabel = if (extThreadInfo.isDaemon) "daemon" else "" + writeLine("") + writeLine( + s"""Thread: "${threadInfo.getThreadName}" #${threadInfo.getThreadId} $daemonLabel""") + writeLine(s" State: ${threadInfo.getThreadState}") + + // Lock information (if enabled and available) + if (config.includeLocksInfo) { + Option(threadInfo.getLockName).foreach { lockName => + writeLine(s" Waiting on: <$lockName>") + } + Option(threadInfo.getLockOwnerName).foreach { ownerName => + writeLine(s""" Lock owned by "${ownerName}" #${threadInfo.getLockOwnerId}""") + } + } + + // Stack trace with depth limit + val stackTrace = threadInfo.getStackTrace + val actualDepth = if (config.stackDepth <= 0) stackTrace.length + else math.min(config.stackDepth, stackTrace.length) + + stackTrace.take(actualDepth).foreach { element => + writeLine(s" at $element") + } + + // Indicate if stack trace was truncated + if (stackTrace.length > actualDepth) { + writeLine(s" ... (${stackTrace.length - actualDepth} more stack frames)") + } + } + } + + /** + * Shows a compact summary table of all threads + */ + private def showThreadSummary( + allThreads: Array[ExtendedThreadInfo], + writeLine: String => Unit): Unit = { + writeLine("Thread ID | Type | State | Name") + writeLine("----------|------|-----------------|---------------------------") + + allThreads.sortBy(_.getId).foreach { extThreadInfo => + val threadType = if (extThreadInfo.isDaemon) "D" else "U" // D=Daemon, U=User + val line = + f"${extThreadInfo.getId}%8d | $threadType%4s | " + + f"${extThreadInfo.getState}%-15s | ${extThreadInfo.getName}" + writeLine(line) + } + } +} diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/util/ThreadDumpUtilsSuite.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/util/ThreadDumpUtilsSuite.scala new file mode 100644 index 00000000000..00cc266d597 --- /dev/null +++ b/kyuubi-common/src/test/scala/org/apache/kyuubi/util/ThreadDumpUtilsSuite.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.util + +import java.io.{ByteArrayOutputStream, PrintStream} +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import org.apache.kyuubi.KyuubiFunSuite + +class ThreadDumpUtilsSuite extends KyuubiFunSuite { + + test("dumpToString should return non-empty thread dump") { + val dump = ThreadDumpUtils.dumpToString() + assert(dump != null && dump.nonEmpty) + assert(dump.contains("Thread Dump Start")) + assert(dump.contains("Thread Dump End")) + } + + test("dumpToConsole should print to provided stream") { + val baos = new ByteArrayOutputStream() + val ps = new PrintStream(baos) + val oldErr = System.err + try { + System.setErr(ps) + ThreadDumpUtils.dumpToConsole() + ps.flush() + } finally { + System.setErr(oldErr) + } + val output = baos.toString("UTF-8") + assert(output.contains("Thread Dump Start")) + assert(output.contains("Thread Dump End")) + } + + test("dumpToString should respect showDaemonThreads config") { + val config = ThreadDumpUtils.DumpConfig(showDaemonThreads = false) + val dump = ThreadDumpUtils.dumpToString(config) + assert(dump.contains("Thread Dump Start")) + assert(dump.contains("Thread Dump End")) + assert(!dump.contains("====================== Daemon Threads ======================")) + } + + /** + * Helper to find a specific thread's section in the dump. + * This is useful for targeted assertions on a single thread's output. + */ + private def findThreadSection(dump: String, threadName: String): Option[String] = { + val pattern = s"""(?s)Thread: "$threadName".*?(?=\\n\\nThread:|\\n\\n================)""".r + pattern.findFirstIn(dump) + } + + test("dumpToString should correctly limit stack depth and show truncation message") { + val readyLatch = new CountDownLatch(1) + var deepStackThread: Thread = null + val stackDepth = 100 + val limit = 10 + + // This function is INTENTIONALLY NOT tail-recursive. + // By performing an operation after the recursive call, we prevent the compiler + // from optimizing it into a loop, thus forcing a deep stack trace. + def deepStack(n: Int): Int = { + if (n > 0) { + val res = deepStack(n - 1) + res + 1 // This operation breaks tail-call optimization + } else { + readyLatch.countDown() // Signal that the thread has reached the bottom of the stack + try { + Thread.sleep(5000) // Keep thread alive so we can dump it + } catch { case _: InterruptedException => Thread.currentThread().interrupt() } + 0 // Base case return + } + } + + try { + deepStackThread = new Thread(() => deepStack(stackDepth), "deep-stack-test-thread") + deepStackThread.setDaemon(true) + deepStackThread.start() + + assert(readyLatch.await(5, TimeUnit.SECONDS), "Test thread did not initialize in time") + + // 1. Get the full, unlimited dump to determine the actual total stack depth. + val unlimitedConfig = ThreadDumpUtils.DumpConfig(stackDepth = 0) + val unlimitedDump = ThreadDumpUtils.dumpToString(unlimitedConfig) + val unlimitedSection = findThreadSection(unlimitedDump, "deep-stack-test-thread") + assert( + unlimitedSection.isDefined, + "Thread 'deep-stack-test-thread' not found in unlimited dump") + val totalFrames = unlimitedSection.get.linesIterator.count(_.trim.startsWith("at ")) + assert(totalFrames > stackDepth, "Full stack depth is not as deep as expected.") + + // 2. Get the limited dump and verify its contents against the full one. + val limitedConfig = ThreadDumpUtils.DumpConfig(stackDepth = limit) + val limitedDump = ThreadDumpUtils.dumpToString(limitedConfig) + val limitedSection = findThreadSection(limitedDump, "deep-stack-test-thread") + assert(limitedSection.isDefined, "Thread 'deep-stack-test-thread' not found in limited dump") + + // Verify the number of "at" lines matches the configured limit. + val stackTraceLines = limitedSection.get.linesIterator.count(_.trim.startsWith("at ")) + assert(stackTraceLines == limit) + + // Verify the truncation message is present and mathematically correct. + val expectedMoreFrames = totalFrames - limit + assert( + limitedSection.get.contains(s"... (${expectedMoreFrames} more stack frames)"), + s"Dump did not contain the expected truncation message. " + + s"Expected '... ($expectedMoreFrames more stack frames)'.") + + } finally { + if (deepStackThread != null) deepStackThread.interrupt() + } + } + + test("dumpToString should sort threads by ID when configured") { + val config = ThreadDumpUtils.DumpConfig(sortThreadsBy = ThreadDumpUtils.ThreadSortBy.Id) + val dump = ThreadDumpUtils.dumpToString(config) + + // Extract the summary table for easier parsing + val summarySection = + dump.substring(dump.indexOf("======================== Summary ========================")) + + // Regex to extract thread IDs from the summary table lines + val idPattern = """^\s*(\d+)\s*\|.*""".r + val ids = summarySection.linesIterator.flatMap { line => + idPattern.findFirstMatchIn(line).map(_.group(1).toLong) + }.toList + + assert(ids.nonEmpty, "No thread IDs found in the summary") + // Verify that the list of IDs is sorted, which proves the sorting logic worked + assert(ids == ids.sorted, s"Thread IDs are not sorted: $ids") + } + + test("dumpToString should detect and report deadlocks") { + val lock1 = new Object() + val lock2 = new Object() + // Latch to ensure both threads are in a deadlock state before we take the dump + val deadlockSetupLatch = new CountDownLatch(2) + + val thread1 = new Thread( + () => { + lock1.synchronized { + deadlockSetupLatch.countDown() + Thread.sleep(200) // Wait for thread2 to acquire lock2 + lock2.synchronized { + // This line will never be reached + } + } + }, + "kyuubi-deadlock-thread-1") + + val thread2 = new Thread( + () => { + lock2.synchronized { + deadlockSetupLatch.countDown() + Thread.sleep(200) // Wait for thread1 to acquire lock1 + lock1.synchronized { + // This line will never be reached + } + } + }, + "kyuubi-deadlock-thread-2") + + // Use daemon threads so they don't block JVM exit if the test fails + thread1.setDaemon(true) + thread2.setDaemon(true) + + try { + thread1.start() + thread2.start() + + // Wait for both threads to acquire their first lock + assert( + deadlockSetupLatch.await(5, TimeUnit.SECONDS), + "Deadlock condition was not met in time") + + // Give the JVM time to officially recognize the deadlock state + Thread.sleep(500) + + val dump = ThreadDumpUtils.dumpToString() + + assert(dump.contains("*** DEADLOCK DETECTED ***")) + // Check that both threads involved in the deadlock are mentioned + assert(dump.contains(""""kyuubi-deadlock-thread-1"""")) + assert(dump.contains(""""kyuubi-deadlock-thread-2"""")) + // Check for lock details which are crucial for debugging deadlocks + assert(dump.contains("Waiting on:")) + assert(dump.contains("Lock owned by")) + + } finally { + // Clean up the threads + thread1.interrupt() + thread2.interrupt() + } + } +}