Skip to content

Commit 22e2ecd

Browse files
authored
Merge pull request #212 from devchat-ai/code_completion_local_var
feat: Enhance code completion, error handling, and performance
2 parents bc823d3 + 3f9925f commit 22e2ecd

File tree

7 files changed

+648
-220
lines changed

7 files changed

+648
-220
lines changed

src/main/kotlin/ai/devchat/common/IDEUtils.kt

Lines changed: 140 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,39 @@ import kotlinx.coroutines.runBlocking
1818
import java.util.concurrent.CompletableFuture
1919
import java.util.concurrent.CountDownLatch
2020
import kotlin.system.measureTimeMillis
21+
import com.intellij.psi.util.PsiTreeUtil
22+
import com.intellij.codeInsight.navigation.actions.GotoTypeDeclarationAction
23+
import com.intellij.openapi.fileEditor.FileEditorManager
24+
import java.lang.ref.SoftReference
25+
import java.util.concurrent.ConcurrentHashMap
26+
import java.util.concurrent.locks.ReentrantReadWriteLock
27+
import kotlin.concurrent.read
28+
import kotlin.concurrent.write
29+
import com.intellij.psi.SmartPointerManager
30+
import com.intellij.psi.SmartPsiElementPointer
2131

2232

2333
object IDEUtils {
34+
private const val MAX_CACHE_SIZE = 1000
35+
private data class CacheEntry(val filePath: String, val offset: Int, val element: SoftReference<SymbolTypeDeclaration>)
36+
37+
private val variableCache = object : LinkedHashMap<String, CacheEntry>(MAX_CACHE_SIZE, 0.75f, true) {
38+
override fun removeEldestEntry(eldest: Map.Entry<String, CacheEntry>): Boolean {
39+
return size > MAX_CACHE_SIZE
40+
}
41+
}
42+
private val cacheLock = ReentrantReadWriteLock()
43+
44+
private data class FoldCacheEntry(
45+
val foldedText: String,
46+
val elementPointer: SmartPsiElementPointer<PsiElement>,
47+
val elementLength: Int,
48+
val elementHash: Int
49+
)
50+
51+
private val foldCache = ConcurrentHashMap<String, SoftReference<FoldCacheEntry>>()
52+
53+
2454
fun <T> runInEdtAndGet(block: () -> T): T {
2555
val app = ApplicationManager.getApplication()
2656
if (app.isDispatchThread) {
@@ -127,21 +157,120 @@ object IDEUtils {
127157
)
128158

129159
fun PsiElement.findAccessibleVariables(): Sequence<SymbolTypeDeclaration> {
130-
val projectFileIndex = ProjectFileIndex.getInstance(this.project)
131-
return generateSequence(this.parent) { it.parent }
132-
.takeWhile { it !is PsiFile }
133-
.flatMap { it.children.asSequence().filterIsInstance<PsiNameIdentifierOwner>() }
134-
.plus(this.containingFile.children.asSequence().filterIsInstance<PsiNameIdentifierOwner>())
135-
.filter { !it.name.isNullOrEmpty() && it.nameIdentifier != null }
136-
.mapNotNull {
137-
val typeDeclaration = it.getTypeDeclaration() ?: return@mapNotNull null
138-
val virtualFile = typeDeclaration.containingFile.virtualFile ?: return@mapNotNull null
139-
val isProjectContent = projectFileIndex.isInContent(virtualFile)
140-
SymbolTypeDeclaration(it, CodeNode(typeDeclaration, isProjectContent))
160+
val projectFileIndex = ProjectFileIndex.getInstance(project)
161+
162+
// 首先收集所有可能的变量
163+
val allVariables = sequence {
164+
var currentScope: PsiElement? = this@findAccessibleVariables
165+
while (currentScope != null && currentScope !is PsiFile) {
166+
val variablesInScope = PsiTreeUtil.findChildrenOfAnyType(
167+
currentScope,
168+
false,
169+
PsiNameIdentifierOwner::class.java
170+
)
171+
172+
for (variable in variablesInScope) {
173+
if (isLikelyVariable(variable) && !variable.name.isNullOrEmpty() && variable.nameIdentifier != null) {
174+
yield(variable)
175+
}
176+
}
177+
178+
currentScope = currentScope.parent
179+
}
180+
181+
yieldAll(this@findAccessibleVariables.containingFile.children
182+
.asSequence()
183+
.filterIsInstance<PsiNameIdentifierOwner>()
184+
.filter { isLikelyVariable(it) && !it.name.isNullOrEmpty() && it.nameIdentifier != null })
185+
}.distinct()
186+
187+
// 处理这些变量的类型,使用缓存
188+
return allVariables.mapNotNull { variable ->
189+
val cacheKey = "${variable.containingFile?.virtualFile?.path}:${variable.textRange.startOffset}"
190+
191+
getCachedOrCompute(cacheKey, variable)
192+
}
193+
}
194+
195+
private fun getCachedOrCompute(cacheKey: String, variable: PsiElement): SymbolTypeDeclaration? {
196+
cacheLock.read {
197+
variableCache[cacheKey]?.let { entry ->
198+
entry.element.get()?.let { cached ->
199+
if (cached.symbol.isValid) return cached
200+
}
141201
}
202+
}
203+
204+
val computed = computeSymbolTypeDeclaration(variable) ?: return null
205+
206+
cacheLock.write {
207+
variableCache[cacheKey] = CacheEntry(
208+
variable.containingFile?.virtualFile?.path ?: return null,
209+
variable.textRange.startOffset,
210+
SoftReference(computed)
211+
)
212+
}
213+
214+
return computed
215+
}
216+
217+
private fun computeSymbolTypeDeclaration(variable: PsiElement): SymbolTypeDeclaration? {
218+
val typeDeclaration = getTypeElement(variable) ?: return null
219+
val virtualFile = variable.containingFile?.virtualFile ?: return null
220+
val isProjectContent = ProjectFileIndex.getInstance(variable.project).isInContent(virtualFile)
221+
return SymbolTypeDeclaration(variable as PsiNameIdentifierOwner, CodeNode(typeDeclaration, isProjectContent))
222+
}
223+
224+
// 辅助函数,用于判断一个元素是否可能是变量
225+
private fun isLikelyVariable(element: PsiElement): Boolean {
226+
val elementClass = element.javaClass.simpleName
227+
return elementClass.contains("Variable", ignoreCase = true) ||
228+
elementClass.contains("Parameter", ignoreCase = true) ||
229+
elementClass.contains("Field", ignoreCase = true)
230+
}
231+
232+
// 辅助函数,用于获取变量的类型元素
233+
private fun getTypeElement(element: PsiElement): PsiElement? {
234+
return ReadAction.compute<PsiElement?, Throwable> {
235+
val project = element.project
236+
val editor = FileEditorManager.getInstance(project).selectedTextEditor ?: return@compute null
237+
val offset = element.textOffset
238+
239+
GotoTypeDeclarationAction.findSymbolType(editor, offset)
240+
}
142241
}
143242

144243
fun PsiElement.foldTextOfLevel(foldingLevel: Int = 1): String {
244+
var result: String
245+
val executionTime = measureTimeMillis {
246+
val cacheKey = "${containingFile.virtualFile.path}:${textRange.startOffset}:$foldingLevel"
247+
248+
// 检查缓存
249+
result = foldCache[cacheKey]?.get()?.let { cachedEntry ->
250+
val cachedElement = cachedEntry.elementPointer.element
251+
if (cachedElement != null && cachedElement.isValid &&
252+
text.length == cachedEntry.elementLength &&
253+
text.hashCode() == cachedEntry.elementHash) {
254+
cachedEntry.foldedText
255+
} else null
256+
} ?: run {
257+
// 如果缓存无效或不存在,重新计算
258+
val foldedText = computeFoldedText(foldingLevel)
259+
// 更新缓存
260+
val elementPointer = SmartPointerManager.getInstance(project).createSmartPsiElementPointer(this)
261+
foldCache[cacheKey] = SoftReference(FoldCacheEntry(foldedText, elementPointer, text.length, text.hashCode()))
262+
foldedText
263+
}
264+
}
265+
266+
// 记录执行时间
267+
Log.info("foldTextOfLevel execution time: $executionTime ms")
268+
269+
// 返回计算结果
270+
return result
271+
}
272+
273+
private fun PsiElement.computeFoldedText(foldingLevel: Int): String {
145274
val file = this.containingFile
146275
val document = file.viewProvider.document ?: return text
147276
val fileNode = file.node ?: return text

src/main/kotlin/ai/devchat/plugin/IDEServer.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,6 @@ class IDEServer(private var project: Project): Disposable {
368368

369369
fun stop() {
370370
Log.info("Stopping IDE server...")
371-
Notifier.info("Stopping IDE server...")
372371
server?.stop(1_000, 2_000)
373372
}
374373

src/main/kotlin/ai/devchat/plugin/completion/agent/Agent.kt

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -156,37 +156,52 @@ class Agent(val scope: CoroutineScope) {
156156
}
157157
}
158158

159-
private fun requestDevChatAPI(prompt: String): Flow<CodeCompletionChunk> = flow {
160-
val devChatEndpoint = CONFIG["providers.devchat.api_base"] as? String
161-
val devChatAPIKey = CONFIG["providers.devchat.api_key"] as? String
162-
val endpoint = "$devChatEndpoint/completions"
163-
val endingChunk = "[DONE]"
164-
val payload = mapOf(
165-
"model" to ((CONFIG["complete_model"] as? String) ?: defaultCompletionModel),
166-
"prompt" to prompt,
167-
"stream" to true,
168-
"stop" to listOf("<|endoftext|>", "<|EOT|>", "<file_sep>", "```", "/", "\n\n"),
169-
"temperature" to 0.2
170-
)
171-
val requestBody = gson.toJson(payload).toRequestBody("application/json; charset=utf-8".toMediaType())
172-
val requestBuilder = Request.Builder().url(endpoint).post(requestBody)
173-
requestBuilder.addHeader("Authorization", "Bearer $devChatAPIKey")
174-
requestBuilder.addHeader("Accept", "text/event-stream")
175-
requestBuilder.addHeader("Content-Type", "application/json")
176-
httpClient.newCall(requestBuilder.build()).execute().use { response ->
177-
if (!response.isSuccessful) throw IllegalArgumentException("Unexpected code $response")
178-
response.body?.charStream()?.buffered()?.use {reader ->
179-
reader.lineSequence().asFlow()
180-
.filter {it.isNotEmpty()}
181-
.takeWhile { it.startsWith("data:") }
182-
.map { it.drop(5).trim() }
183-
.takeWhile { it.uppercase() != endingChunk }
184-
.map { gson.fromJson(it, CompletionResponseChunk::class.java) }
185-
.takeWhile {it != null}
186-
.collect { emit(CodeCompletionChunk(it.id, it.choices[0].text!!)) }
159+
private fun requestDevChatAPI(prompt: String): Flow<CodeCompletionChunk> = flow {
160+
val devChatEndpoint = CONFIG["providers.devchat.api_base"] as? String
161+
val devChatAPIKey = CONFIG["providers.devchat.api_key"] as? String
162+
val endpoint = "$devChatEndpoint/completions"
163+
val endingChunk = "[DONE]"
164+
val payload = mapOf(
165+
"model" to ((CONFIG["complete_model"] as? String) ?: defaultCompletionModel),
166+
"prompt" to prompt,
167+
"stream" to true,
168+
"stop" to listOf("<|endoftext|>", "<|EOT|>", "<file_sep>", "```", "/", "\n\n"),
169+
"temperature" to 0.2
170+
)
171+
val requestBody = gson.toJson(payload).toRequestBody("application/json; charset=utf-8".toMediaType())
172+
val requestBuilder = Request.Builder().url(endpoint).post(requestBody)
173+
requestBuilder.addHeader("Authorization", "Bearer $devChatAPIKey")
174+
requestBuilder.addHeader("Accept", "text/event-stream")
175+
requestBuilder.addHeader("Content-Type", "application/json")
176+
177+
httpClient.newCall(requestBuilder.build()).execute().use { response ->
178+
if (!response.isSuccessful) {
179+
val errorBody = response.body?.string() ?: "No error body"
180+
when (response.code) {
181+
500 -> {
182+
if (errorBody.contains("Insufficient Balance")) {
183+
logger.warn("DevChat API error: Insufficient balance. Please check your account.")
184+
} else {
185+
logger.warn("DevChat API server error. Response code: ${response.code}. Body: $errorBody")
186+
}
187+
}
188+
else -> logger.warn("Unexpected response from DevChat API. Code: ${response.code}. Body: $errorBody")
187189
}
190+
return@flow
191+
}
192+
193+
response.body?.charStream()?.buffered()?.use { reader ->
194+
reader.lineSequence().asFlow()
195+
.filter { it.isNotEmpty() }
196+
.takeWhile { it.startsWith("data:") }
197+
.map { it.drop(5).trim() }
198+
.takeWhile { it.uppercase() != endingChunk }
199+
.map { gson.fromJson(it, CompletionResponseChunk::class.java) }
200+
.takeWhile { it != null }
201+
.collect { emit(CodeCompletionChunk(it.id, it.choices[0].text!!)) }
188202
}
189203
}
204+
}
190205

191206
private fun toLines(chunks: Flow<CodeCompletionChunk>): Flow<CodeCompletionChunk> = flow {
192207
var ongoingLine = ""
@@ -299,7 +314,7 @@ suspend fun provideCompletions(
299314
val llmRequestElapse = System.currentTimeMillis() - startTime
300315
val offset = completionRequest.position
301316
val replaceRange = CompletionResponse.Choice.Range(start = offset, end = offset)
302-
val text = if (completion.text != prevCompletion) completion.text else ""
317+
val text = completion.text
303318
val choice = CompletionResponse.Choice(index = 0, text = text, replaceRange = replaceRange)
304319
val response = CompletionResponse(completion.id, model, listOf(choice), promptBuildingElapse, llmRequestElapse)
305320

src/main/kotlin/ai/devchat/plugin/completion/agent/AgentService.kt

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,63 @@ import com.intellij.psi.PsiFile
1010
import io.ktor.util.*
1111
import kotlinx.coroutines.CoroutineScope
1212
import kotlinx.coroutines.Dispatchers
13+
import kotlinx.coroutines.withContext
14+
import com.intellij.openapi.application.ApplicationManager
15+
import com.intellij.openapi.application.ModalityState
16+
import kotlinx.coroutines.suspendCancellableCoroutine
17+
import kotlin.coroutines.resume
18+
import kotlinx.coroutines.CancellationException
1319

1420
@Service
1521
class AgentService : Disposable {
1622
val scope: CoroutineScope = CoroutineScope(Dispatchers.IO)
1723
private var agent: Agent = Agent(scope)
1824

19-
suspend fun provideCompletion(editor: Editor, offset: Int, manually: Boolean = false): Agent.CompletionResponse? {
20-
return ReadAction.compute<PsiFile, Throwable> {
21-
editor.project?.let { project ->
22-
PsiDocumentManager.getInstance(project).getPsiFile(editor.document)
23-
}
24-
}?.let { file ->
25-
agent.provideCompletions(
26-
Agent.CompletionRequest(
27-
file,
28-
file.getLanguageId(),
29-
offset,
30-
manually,
31-
)
32-
)
25+
suspend fun provideCompletion(editor: Editor, offset: Int, manually: Boolean = false): Agent.CompletionResponse? {
26+
println("Entering provideCompletion method")
27+
return withContext(Dispatchers.Default) {
28+
try {
29+
println("Attempting to get PsiFile")
30+
val file = suspendCancellableCoroutine<PsiFile?> { continuation ->
31+
ApplicationManager.getApplication().invokeLater({
32+
val psiFile = ReadAction.compute<PsiFile?, Throwable> {
33+
editor.project?.let { project ->
34+
PsiDocumentManager.getInstance(project).getPsiFile(editor.document)
35+
}
36+
}
37+
continuation.resume(psiFile)
38+
}, ModalityState.defaultModalityState())
39+
}
40+
41+
println("PsiFile obtained: ${file != null}")
42+
43+
file?.let { psiFile ->
44+
println("Calling agent.provideCompletions")
45+
val result = agent.provideCompletions(
46+
Agent.CompletionRequest(
47+
psiFile,
48+
psiFile.getLanguageId(),
49+
offset,
50+
manually,
51+
)
52+
)
53+
println("agent.provideCompletions returned: $result")
54+
result
55+
}
56+
} catch (e: CancellationException) {
57+
// 方案1:以较低的日志级别记录
58+
println("Completion was cancelled: ${e.message}")
59+
// 或者方案2:完全忽略
60+
// // 不做任何处理
61+
62+
null
63+
} catch (e: Exception) {
64+
println("Exception in provideCompletion: ${e.message}")
65+
e.printStackTrace()
66+
null
67+
}
3368
}
34-
}
69+
}
3570

3671
suspend fun postEvent(event: Agent.LogEventRequest) {
3772
agent.postEvent(event)

0 commit comments

Comments
 (0)