Skip to content
28 changes: 25 additions & 3 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import org.jetbrains.intellij.platform.gradle.IntelliJPlatformType
import org.jetbrains.intellij.platform.gradle.TestFrameworkType
import org.jetbrains.intellij.platform.gradle.models.ProductRelease
import org.jetbrains.intellij.platform.gradle.tasks.RunIdeTask
import org.jetbrains.intellij.platform.gradle.tasks.aware.SplitModeAware.SplitModeTarget
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
import java.io.FileOutputStream
import java.net.URL
Expand Down Expand Up @@ -80,7 +81,10 @@ if (spaceCredentialsProvided()) {
dependencies {
add(hasGrazieAccess.implementationConfigurationName, kotlin("stdlib"))
add(hasGrazieAccess.implementationConfigurationName, "org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.3")
add(hasGrazieAccess.implementationConfigurationName, "org.jetbrains.research:grazie-test-generation:$grazieTestGenerationVersion")
add(
hasGrazieAccess.implementationConfigurationName,
"org.jetbrains.research:grazie-test-generation:$grazieTestGenerationVersion"
)
}

tasks.register("checkCredentials") {
Expand Down Expand Up @@ -428,6 +432,7 @@ fun String?.orDefault(default: String): String = this ?: default
* @param prompt a txt file containing the LLM's prompt template
* @param out The output directory for the project.
* @param enableCoverage flag to enable/disable coverage computation
* @param methodName indicates the name of the method under test or empty for class level generation
*/
tasks.create<RunIdeTask>("headless") {
val root: String? by project
Expand All @@ -440,8 +445,22 @@ tasks.create<RunIdeTask>("headless") {
val prompt: String? by project
val out: String? by project
val enableCoverage: String? by project

args = listOfNotNull("testspark", root, file, cut, cp, junitv, llm, token, prompt, out, enableCoverage.orDefault("false"))
val methodName: String? by project

args = listOfNotNull(
"testspark",
root,
file,
cut,
cp,
junitv,
llm,
token,
prompt,
out,
enableCoverage.orDefault("false"),
methodName.orDefault("")
)

jvmArgs(
"-Xmx16G",
Expand All @@ -450,6 +469,9 @@ tasks.create<RunIdeTask>("headless") {
"java.base/jdk.internal.vm=ALL-UNNAMED",
"-Didea.system.path",
)

splitMode = false
splitModeTarget = SplitModeTarget.BACKEND
}

fun spaceCredentialsProvided() = spaceUsername.isNotEmpty() && spacePassword.isNotEmpty()
7 changes: 4 additions & 3 deletions runTestSparkHeadless.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ if [ $# -ne "12" ]; then
9) Output directory
10) Enable/disable coverage computation ('true' or 'false')
11) Space username
12) Space password"
12) Space password
13) Method under test name(or empty for class-level generation)"
exit 1
fi

echo -Proot="$1" -Pfile="$2" -Pcut="$3" -Pcp="$4" -Pjunitv="$5" -Pllm="$6" -Ptoken="$7" -Pprompt="$8" -Pout="$9" -PenableCoverage="${10}" -Dspace.username="${11}" -Dspace.pass="${12}"
"$DIR/gradlew" -p "$DIR" headless -Proot="$1" -Pfile="$2" -Pcut="$3" -Pcp="$4" -Pjunitv="$5" -Pllm="$6" -Ptoken="$7" -Pprompt="$8" -Pout="$9" -PenableCoverage="${10}" -Dspace.username="${11}" -Dspace.pass="${12}"
echo -Proot="$1" -Pfile="$2" -Pcut="$3" -Pcp="$4" -Pjunitv="$5" -Pllm="$6" -Ptoken="$7" -Pprompt="$8" -Pout="$9" -PenableCoverage="${10}" -Dspace.username="${11}" -Dspace.pass="${12}" -PmethodName="${13}"
"$DIR/gradlew" -p "$DIR" headless -Proot="$1" -Pfile="$2" -Pcut="$3" -Pcp="$4" -Pjunitv="$5" -Pllm="$6" -Ptoken="$7" -Pprompt="$8" -Pout="$9" -PenableCoverage="${10}" -Dspace.username="${11}" -Dspace.pass="${12}" -PmethodName="${13}"
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,21 @@ import com.intellij.openapi.vfs.LocalFileSystem
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiJavaFile
import com.intellij.psi.PsiManager
import com.intellij.psi.PsiMethod
import kotlinx.serialization.ExperimentalSerializationApi
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.research.testspark.bundles.llm.LLMDefaultsBundle
import org.jetbrains.research.testspark.core.data.JUnitVersion
import org.jetbrains.research.testspark.core.data.TestGenerationData
import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor
import org.jetbrains.research.testspark.core.test.SupportedLanguage
import org.jetbrains.research.testspark.core.test.TestCompiler
import org.jetbrains.research.testspark.core.test.data.CodeType
import org.jetbrains.research.testspark.data.FragmentToTestData
import org.jetbrains.research.testspark.data.ProjectContext
import org.jetbrains.research.testspark.data.llm.JsonEncoding
import org.jetbrains.research.testspark.java.JavaPsiMethodWrapper
import org.jetbrains.research.testspark.kotlin.KotlinPsiHelperProvider
import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider
import org.jetbrains.research.testspark.progress.HeadlessProgressIndicator
import org.jetbrains.research.testspark.services.LLMSettingsService
Expand Down Expand Up @@ -71,8 +76,13 @@ class TestSparkStarter : ApplicationStarter {
val output = args[9]
// Run coverage
val runCoverage = args[10].toBoolean()
// Method under test name(or empty string for class level generation)
val methodName = args[11]

val testsExecutionResultManager = TestsExecutionResultManager()
// TODO check for suitable refactoring
val language =
if (cutSourceFilePath.toString().endsWith(".kt")) SupportedLanguage.Kotlin else SupportedLanguage.Java

println("Test generation requested for $projectPath")

Expand Down Expand Up @@ -108,10 +118,15 @@ class TestSparkStarter : ApplicationStarter {
println("Couldn't open file $cutSourceFilePath")
exitProcess(1)
}

// get target PsiClass
val psiFile = PsiManager.getInstance(project).findFile(cutSourceVirtualFile) as PsiJavaFile
val targetPsiClass = detectPsiClass(psiFile.classes, classUnderTestName) ?: run {
val psiFile = PsiManager.getInstance(project).findFile(cutSourceVirtualFile)
val targetPsiClass = detectPsiClass(
when (language) {
SupportedLanguage.Java -> psiFile as PsiJavaFile
SupportedLanguage.Kotlin -> psiFile as KtFile
}.classes,
classUnderTestName
) ?: run {
println("Couldn't find $classUnderTestName in $cutSourceFilePath")
exitProcess(1)
}
Expand Down Expand Up @@ -159,7 +174,10 @@ class TestSparkStarter : ApplicationStarter {
val packageName = packageList.joinToString(".")

// Get PsiHelper
val psiHelper = PsiHelperProvider.getPsiHelper(psiFile)
val psiHelper = when (language) {
SupportedLanguage.Kotlin -> KotlinPsiHelperProvider().getPsiHelper(psiFile as KtFile)
SupportedLanguage.Java -> PsiHelperProvider.getPsiHelper(psiFile as PsiJavaFile)
}
if (psiHelper == null) {
// TODO exception: the support for the current language does not exist
}
Expand All @@ -183,9 +201,23 @@ class TestSparkStarter : ApplicationStarter {
psiHelper.language,
projectSDKPath.toString(),
)
val codeType = when (methodName) {
"" -> FragmentToTestData(CodeType.CLASS)
else -> {
val psiMethod = targetPsiClass.methods.find { it.name == methodName } ?: run {
println("Couldn't find method $methodName")
exitProcess(1)
}
FragmentToTestData(
CodeType.METHOD,
psiHelper.generateMethodDescriptor(JavaPsiMethodWrapper(psiMethod as PsiMethod))
)
}
}

val uiContext = llmProcessManager.runTestGenerator(
indicator,
FragmentToTestData(CodeType.CLASS),
codeType,
packageName,
projectContext,
testGenerationData,
Expand Down Expand Up @@ -257,6 +289,7 @@ class TestSparkStarter : ApplicationStarter {
val targetDirectory = "$out${File.separator}${packageList.joinToString(File.separator)}"
println("Run tests in $targetDirectory")
File(targetDirectory).walk().forEach {
// TODO Doesn't work for compiled kotlin files
if (it.name.endsWith(".class")) {
println("Running test ${it.name}")
var testcaseName = it.nameWithoutExtension.removePrefix("Generated")
Expand Down
Loading