Skip to content

Commit 70c3ede

Browse files
committed
feat: Enhance ContextBuilder with folding and token management
- Implement function folding for large files - Add token count management for context building - Improve context generation for different AI models
1 parent 72c3e12 commit 70c3ede

File tree

1 file changed

+187
-53
lines changed

1 file changed

+187
-53
lines changed

src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt

Lines changed: 187 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ import ai.devchat.storage.RecentFilesTracker
1212
import com.intellij.psi.PsiFile
1313
import com.intellij.psi.util.PsiUtilCore.getPsiFile
1414
import ai.devchat.storage.CONFIG
15+
import com.intellij.openapi.application.ApplicationManager
16+
import java.util.concurrent.atomic.AtomicInteger
17+
import com.intellij.psi.PsiElement
18+
import com.intellij.psi.PsiRecursiveElementVisitor
19+
import java.util.concurrent.ConcurrentHashMap
20+
1521

1622
/**
 * Upper bound, in tokens, for the context assembled for a completion request.
 *
 * Read from `CONFIG["complete_context_limit"]` on every access. Any numeric value is
 * accepted (not only `Int`), so a limit that the configuration layer happens to surface
 * as a `Long` or `Double` is honoured instead of being silently replaced by the default.
 * Falls back to 6000 when the entry is missing or not numeric.
 */
val MAX_CONTEXT_TOKENS: Int
    get() = (CONFIG["complete_context_limit"] as? Number)?.toInt() ?: 6000
@@ -74,6 +80,10 @@ data class CodeSnippet (
7480
)
7581

7682
class ContextBuilder(val file: PsiFile, val offset: Int) {
83+
private val CURSOR_MARKER = "<<<CURSOR>>>"
84+
private val foldedContentCache = ConcurrentHashMap<Int, Pair<String, Int>>()
85+
private val foldCounter = AtomicInteger(0)
86+
7787
val filepath: String = file.virtualFile.path
7888
val content: String by lazy {
7989
ReadAction.compute<String, Throwable> {
@@ -95,30 +105,124 @@ class ContextBuilder(val file: PsiFile, val offset: Int) {
95105

96106
/**
 * Builds the (prefix, suffix) pair of file text surrounding the cursor.
 *
 * When the file has at most 1000 lines and its token count exceeds 35% of
 * [MAX_CONTEXT_TOKENS], the folded representation of the file is used; otherwise the
 * raw content is split at [offset]. In both cases the pair is passed through
 * [adjustContextTokens] before being returned.
 */
private fun buildFileContext(): Pair<String, String> {
    val foldingThreshold = MAX_CONTEXT_TOKENS * 0.35
    val lineTotal = content.lines().size

    // The line-count check comes first so the token count is only computed for small files.
    val useFolding = lineTotal <= 1000 && content.tokenCount() > foldingThreshold

    val (before, after) = if (useFolding) {
        val (folded, cursorInFolded) = getFoldedContent()
        buildContextFromFoldedContent(folded, cursorInFolded)
    } else {
        buildOriginalContext()
    }
    return adjustContextTokens(before, after)
}
119+
120+
/**
 * Trims [prefix] and [suffix] so that together they fit within [MAX_CONTEXT_TOKENS].
 *
 * If the pair already fits it is returned unchanged. Otherwise:
 *  - when the prefix uses at most half of the limit, only the suffix is shortened;
 *  - else when the suffix uses at most half of the limit, only the prefix is shortened;
 *  - else both sides are shortened to half of the limit each.
 *
 * The text closest to the cursor is always the part that is kept (the end of the prefix,
 * the start of the suffix). Budgets are expressed in tokens and are applied in tokens:
 * the previous implementation passed the token budget to `take`/`takeLast`, which count
 * characters, and therefore discarded far more context than the limit required.
 *
 * @param prefix text before the cursor
 * @param suffix text after the cursor
 * @return the possibly shortened (prefix, suffix) pair
 */
private fun adjustContextTokens(prefix: String, suffix: String): Pair<String, String> {
    // Read the limit once: the property is re-evaluated on every access.
    val limit = MAX_CONTEXT_TOKENS
    val prefixTokens = prefix.tokenCount()
    val suffixTokens = suffix.tokenCount()
    if (prefixTokens + suffixTokens <= limit) {
        return Pair(prefix, suffix)
    }

    val half = limit / 2
    return when {
        prefixTokens <= half ->
            Pair(prefix, keepLeadingLinesWithinTokens(suffix, limit - prefixTokens))
        suffixTokens <= half ->
            Pair(keepTrailingLinesWithinTokens(prefix, limit - suffixTokens), suffix)
        else ->
            Pair(
                keepTrailingLinesWithinTokens(prefix, half),
                keepLeadingLinesWithinTokens(suffix, half)
            )
    }
}

/**
 * Returns the longest run of whole lines from the start of [text] whose summed
 * per-line token counts do not exceed [budget].
 */
private fun keepLeadingLinesWithinTokens(text: String, budget: Int): String {
    var usedTokens = 0
    var end = 0
    while (end < text.length) {
        val newline = text.indexOf('\n', end)
        val nextLineStart = if (newline < 0) text.length else newline + 1
        val lineTokens = text.substring(end, nextLineStart).tokenCount()
        if (usedTokens + lineTokens > budget) break
        usedTokens += lineTokens
        end = nextLineStart
    }
    // Not even the first line fits: fall back to the former character-based cut so the
    // caller still gets some context (assumes a token is at least one character long).
    return if (end == 0) text.take(budget.coerceAtLeast(0)) else text.substring(0, end)
}

/**
 * Returns the longest run of whole lines from the end of [text] whose summed
 * per-line token counts do not exceed [budget].
 */
private fun keepTrailingLinesWithinTokens(text: String, budget: Int): String {
    var usedTokens = 0
    var start = text.length
    while (start > 0) {
        // Search from start - 2 so the newline that terminates the line ending at
        // `start` (if any) is treated as part of that line rather than as its beginning.
        val newline = text.lastIndexOf('\n', start - 2)
        val lineStart = newline + 1
        val lineTokens = text.substring(lineStart, start).tokenCount()
        if (usedTokens + lineTokens > budget) break
        usedTokens += lineTokens
        start = lineStart
    }
    // Same fallback as for the leading case when the last line alone exceeds the budget.
    return if (start == text.length) text.takeLast(budget.coerceAtLeast(0)) else text.substring(start)
}
144+
145+
/**
 * Returns the folded file text (cursor marker removed) together with the cursor
 * position inside that folded text.
 *
 * The result is memoized under [offset]. The earlier code used a freshly incremented
 * counter as the cache key, so every call missed the cache, recomputed the fold and
 * added another entry to the map.
 *
 * If the cursor marker cannot be found after folding, the unfolded [content] and the
 * original [offset] are returned instead, so callers never receive a negative offset.
 */
private fun getFoldedContent(): Pair<String, Int> {
    return foldedContentCache.computeIfAbsent(offset) {
        val contentWithMarker = insertCursorMarker(content, offset)
        val foldedWithMarker = foldFunctions(contentWithMarker)
        val markerOffset = foldedWithMarker.indexOf(CURSOR_MARKER)
        if (markerOffset < 0) {
            // The marker was lost during folding; fall back to the raw text.
            content to offset
        } else {
            foldedWithMarker.replace(CURSOR_MARKER, "") to markerOffset
        }
    }
}
154+
155+
/**
 * Returns [text] with [CURSOR_MARKER] spliced in at character index [offset].
 */
private fun insertCursorMarker(text: String, offset: Int): String =
    "${text.substring(0, offset)}$CURSOR_MARKER${text.substring(offset)}"
158+
159+
/**
 * Replaces every function-like PSI element that does not contain the cursor with the
 * text returned by `foldTextOfLevel(1)` for that element.
 *
 * @param text the file content with [CURSOR_MARKER] already inserted at the cursor;
 *   PSI offsets located after the marker are shifted by the marker length so that
 *   they index into this string
 * @return [text] with the selected ranges replaced; [text] unchanged when the file has
 *   no document or no PSI file can be obtained for it
 */
private fun foldFunctions(text: String): String {
    return ReadAction.compute<String, Throwable> {
        // A view provider may have no document; treat that as "nothing to fold"
        // instead of failing with a NullPointerException.
        val psiFile = file.viewProvider.document?.let { document ->
            PsiDocumentManager.getInstance(file.project).getPsiFile(document)
        }
        val foldInfoList = mutableListOf<FoldInfo>()
        val cursorOffset = text.indexOf(CURSOR_MARKER)
        val markerLength = CURSOR_MARKER.length

        psiFile?.accept(object : PsiRecursiveElementVisitor() {
            override fun visitElement(element: PsiElement) {
                if (isFunctionElement(element)) {
                    val start = element.textRange.startOffset
                    val end = element.textRange.endOffset
                    val adjustedStart = adjustOffset(start, cursorOffset, markerLength)
                    val adjustedEnd = adjustOffset(end, cursorOffset, markerLength)

                    if (!elementContainsCursor(adjustedStart, adjustedEnd, cursorOffset)) {
                        val foldedText = element.foldTextOfLevel(1)
                        foldInfoList.add(FoldInfo(adjustedStart, adjustedEnd, foldedText))
                    }
                    // Function elements are not descended into, so recorded ranges never nest.
                } else {
                    super.visitElement(element)
                }
            }
        })

        // Sort by end offset in descending order so replacements run from the back of
        // the text to the front and earlier offsets stay valid.
        foldInfoList.sortByDescending { it.end }

        val sb = StringBuilder(text)
        for (foldInfo in foldInfoList) {
            sb.replace(foldInfo.start, foldInfo.end, foldInfo.foldedText)
        }

        sb.toString()
    }
}
195+
196+
/**
 * A pending text replacement: the character range from [start] to [end] is to be
 * substituted with [foldedText].
 */
private data class FoldInfo(
    val start: Int,
    val end: Int,
    val foldedText: String
)
197+
198+
/**
 * Maps an offset in the marker-free text to the corresponding offset in the text that
 * has a marker of [markerLength] characters inserted at [cursorOffset].
 *
 * Offsets strictly after the cursor are shifted by [markerLength]; offsets at or
 * before the cursor are returned unchanged.
 */
private fun adjustOffset(offset: Int, cursorOffset: Int, markerLength: Int): Int {
    val shift = when {
        offset > cursorOffset -> markerLength
        else -> 0
    }
    return offset + shift
}
201+
202+
/**
 * Tells whether [cursorOffset] lies inside the half-open range from [start]
 * (inclusive) to [end] (exclusive).
 */
private fun elementContainsCursor(start: Int, end: Int, cursorOffset: Int): Boolean =
    start <= cursorOffset && cursorOffset < end
118205

206+
/**
 * Tells whether [element] is treated as a function for folding purposes.
 *
 * The decision is made on the string form of the element's node type, which has to be
 * one of "FUNCTION", "METHOD" or "FUN". Which names apply depends on the language of
 * the file, so this list may need extending for other languages.
 */
private fun isFunctionElement(element: PsiElement): Boolean {
    // Log.info("elementType: ${element.node.elementType.toString()}")
    return when (element.node.elementType.toString()) {
        "FUNCTION", "METHOD", "FUN" -> true
        else -> false
    }
}
214+
215+
/**
 * Splits [foldedContent] at [markerOffset] into the text before the cursor (first)
 * and the text from the cursor onwards (second).
 */
private fun buildContextFromFoldedContent(foldedContent: String, markerOffset: Int): Pair<String, String> {
    val beforeCursor = foldedContent.substring(0, markerOffset)
    val fromCursor = foldedContent.substring(markerOffset)
    return beforeCursor to fromCursor
}
221+
222+
/**
 * Splits the unfolded [content] at [offset] into the text before the cursor (first)
 * and the text from the cursor onwards (second).
 */
private fun buildOriginalContext(): Pair<String, String> {
    val beforeCursor = content.substring(0, offset)
    val fromCursor = content.substring(offset)
    return beforeCursor to fromCursor
}
124228

@@ -145,34 +249,36 @@ class ContextBuilder(val file: PsiFile, val offset: Int) {
145249
}
146250

147251
private fun buildSymbolsContext(): String {
148-
return runInEdtAndGet {
252+
return ApplicationManager.getApplication().runReadAction<String> {
149253
Log.info("Starting buildSymbolsContext")
150254
val element = file.findElementAt(offset)
151255
Log.info("Found element at offset: ${element?.text}")
152256

153257
val variables = element?.findAccessibleVariables() ?: emptySequence()
154-
val variablesCount = variables.count()
258+
259+
// 使用 toList() 来触发惰性序列的计算,确保在 Read Action 中完成
260+
val variablesList = variables.toList()
261+
val variablesCount = variablesList.size
155262
Log.info("Found $variablesCount accessible variables")
156263

157264
val processedTypes = mutableSetOf<String>()
158265
val result = StringBuilder()
159266

160-
variables
161-
.onEach { Log.info("Processing variable: ${it.symbol.name}") }
162-
.forEach { variable ->
163-
val typeElement = variable.typeDeclaration.element
164-
val isLocalType = typeElement.containingFile.virtualFile.path == filepath
165-
val typeText = limitTypeText(typeElement.text)
166-
Log.info("Variable ${variable.symbol.name} type: ${typeText}")
167-
Log.info("Is local type: $isLocalType")
168-
169-
val typeFilePath = typeElement.containingFile.virtualFile.path
170-
Log.info("Actual type file: $typeFilePath")
171-
172-
// 如果typeFilePath表示了系统库的定义,那么不应该添加到上下文中,例如string的定义。
173-
if (isValidTypePath(typeFilePath)) {
174-
val typeKey = "${typeElement.text}:$typeFilePath"
175-
if (!processedTypes.contains(typeKey)) {
267+
variablesList.forEach { variable ->
268+
val typeElement = variable.typeDeclaration.element
269+
val isLocalType = typeElement.containingFile.virtualFile.path == filepath
270+
val typeText = limitTypeText(typeElement.text)
271+
val typeFilePath = typeElement.containingFile.virtualFile.path
272+
val typeKey = "${typeElement.text}:$typeFilePath"
273+
274+
Log.info("Processing variable ${variable.symbol.name}")
275+
Log.info("Is local type: $isLocalType")
276+
if (isValidTypePath(typeFilePath)) {
277+
if (!processedTypes.contains(typeKey)) {
278+
if (!isLocalType) {
279+
Log.info("Variable ${variable.symbol.name} type: $typeText")
280+
Log.info("Actual type file: $typeFilePath")
281+
176282
processedTypes.add(typeKey)
177283

178284
val snippet = CodeSnippet(
@@ -194,13 +300,18 @@ class ContextBuilder(val file: PsiFile, val offset: Int) {
194300
.append("$commentPrefix <definition>\n$commentedContent\n\n\n\n")
195301
} else {
196302
Log.info("Skipping type ${variable.symbol.name} due to token limit")
197-
return@runInEdtAndGet result.toString()
303+
return@forEach
198304
}
199305
} else {
200-
Log.info("Skipping duplicate type: ${typeText}")
306+
Log.info("Skipping type ${variable.symbol.name} due to local definition")
201307
}
308+
} else {
309+
Log.info("Skipping duplicate type: ${variable.symbol.name}")
202310
}
311+
} else {
312+
Log.info("Skipping invalid type path: ${typeFilePath}")
203313
}
314+
}
204315

205316
Log.info("buildSymbolsContext result length: ${result.length}")
206317
result.toString()
@@ -249,19 +360,42 @@ class ContextBuilder(val file: PsiFile, val offset: Int) {
249360

250361
/**
 * Assembles the fill-in-the-middle prompt for the given [model].
 *
 * The file prefix/suffix come from [buildFileContext]. While the running token count
 * stays within 90% of [MAX_CONTEXT_TOKENS], extra context is added in this order:
 * callee definitions, symbols, recent files. The first extra context that would exceed
 * the ceiling stops the process; later ones are not built.
 *
 * The prompt layout is chosen by substring match on [model]: "deepseek", then
 * "starcoder", then "codestral"; any other (or null/empty) model uses the same layout
 * as "starcoder".
 *
 * @param model model identifier, may be null
 * @return the complete prompt string
 */
fun createPrompt(model: String?): String {
    val (prefix, suffix) = buildFileContext()
    var usedTokens = prefix.tokenCount() + suffix.tokenCount()
    val tokenCeiling = (MAX_CONTEXT_TOKENS * 0.9).toInt()

    Log.info("Current token count: $usedTokens")
    Log.info("Max allowed tokens: $tokenCeiling")

    val extras = buildString {
        if (usedTokens < tokenCeiling) {
            val providers = listOf(
                ::buildCalleeDefinitionsContext,
                ::buildSymbolsContext,
                ::buildRecentFilesContext
            )
            for (provide in providers) {
                val piece = provide()
                val pieceTokens = piece.tokenCount()
                // Stop at the first piece that does not fit; remaining providers are skipped.
                if (usedTokens + pieceTokens > tokenCeiling) break
                append(piece)
                usedTokens += pieceTokens
            }
        }
    }

    // A null model behaves like an empty one: no substring check can match.
    val modelName = model.orEmpty()
    return when {
        "deepseek" in modelName ->
            "<|fim▁begin|>$extras$commentPrefix<filename>$filepath\n\n$prefix<|fim▁hole|>$suffix<|fim▁end|>"
        "starcoder" in modelName ->
            "<fim_prefix>$extras$commentPrefix<filename>$filepath\n\n$prefix<fim_suffix>$suffix<fim_middle>"
        "codestral" in modelName ->
            "<s>[SUFFIX]$suffix[PREFIX]$extras$commentPrefix<filename>$filepath\n\n$prefix"
        else ->
            "<fim_prefix>$extras$commentPrefix<filename>$filepath\n\n$prefix<fim_suffix>$suffix<fim_middle>"
    }
}
267401
}

0 commit comments

Comments
 (0)