Skip to content

Commit 9d8bc8a

Browse files
committed
refactor(search): improve symbol search and references
Improves symbol search algorithm by using trigram-based similarity matching instead of simple term matching. This provides better accuracy for fuzzy matches and handles compound words better. Key changes: - Replace basic term matching with trigram-based Jaccard similarity - Normalize and weight symbol matches separately from filename matches - Add reference list display in chat UI - Move authentication status messages to more logical locations - Fix various edge cases and improve code organization Signed-off-by: Tomas Slusny <slusnucky@gmail.com>
1 parent 754c971 commit 9d8bc8a

5 files changed

Lines changed: 140 additions & 94 deletions

File tree

lua/CopilotChat/client.lua

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
---@field temperature number?
99
---@field on_progress nil|fun(response: string):nil
1010

11+
local async = require('plenary.async')
1112
local log = require('plenary.log')
1213
local tiktoken = require('CopilotChat.tiktoken')
1314
local notify = require('CopilotChat.notify')
@@ -187,21 +188,6 @@ local function generate_embedding_request(inputs, threshold)
187188
end, inputs)
188189
end
189190

190-
local function generate_references(references)
191-
local out = ''
192-
193-
for _, reference in ipairs(references) do
194-
out = out .. '\n[' .. reference.name .. '](' .. reference.url .. ')'
195-
end
196-
197-
if out == '' then
198-
return out
199-
end
200-
201-
out = '\n\n**`References:`**' .. out
202-
return out
203-
end
204-
205191
---@class CopilotChat.Client : Class
206192
---@field providers table<string, CopilotChat.Provider>
207193
---@field provider_cache table<string, table>
@@ -235,7 +221,6 @@ function Client:authenticate(provider_name)
235221
if not headers or (expires_at and expires_at <= math.floor(os.time())) then
236222
local token
237223
if provider.get_token then
238-
notify.publish(notify.STATUS, 'Authenticating to provider ' .. provider_name)
239224
token, expires_at = provider.get_token()
240225
end
241226

@@ -260,8 +245,8 @@ function Client:fetch_models()
260245
for _, provider_name in ipairs(provider_order) do
261246
local provider = self.providers[provider_name]
262247
if not provider.disabled and provider.get_models then
263-
local headers = self:authenticate(provider_name)
264248
notify.publish(notify.STATUS, 'Fetching models from ' .. provider_name)
249+
local headers = self:authenticate(provider_name)
265250
local ok, provider_models = pcall(provider.get_models, headers)
266251
if ok then
267252
for _, model in ipairs(provider_models) do
@@ -296,8 +281,8 @@ function Client:fetch_agents()
296281
for _, provider_name in ipairs(provider_order) do
297282
local provider = self.providers[provider_name]
298283
if not provider.disabled and provider.get_agents then
299-
local headers = self:authenticate(provider_name)
300284
notify.publish(notify.STATUS, 'Fetching agents from ' .. provider_name)
285+
local headers = self:authenticate(provider_name)
301286
local ok, provider_agents = pcall(provider.get_agents, headers)
302287
if ok then
303288
for _, agent in ipairs(provider_agents) do
@@ -320,6 +305,7 @@ end
320305
--- Ask a question to Copilot
321306
---@param prompt string: The prompt to send to Copilot
322307
---@param opts CopilotChat.Client.ask: Options for the request
308+
---@return string, table, number, number
323309
function Client:ask(prompt, opts)
324310
opts = opts or {}
325311
prompt = vim.trim(prompt)
@@ -369,10 +355,13 @@ function Client:ask(prompt, opts)
369355
log.debug('Tokenizer: ', tokenizer)
370356
tiktoken.load(tokenizer)
371357

358+
notify.publish(notify.STATUS, 'Generating request')
359+
360+
async.util.scheduler()
372361
local references = {}
373362
for _, embed in ipairs(embeddings) do
374363
table.insert(references, {
375-
name = embed.filename,
364+
name = utils.filename(embed.filename),
376365
url = embed.filename,
377366
})
378367
end
@@ -540,6 +529,8 @@ function Client:ask(prompt, opts)
540529
parse_stream_line(line, job)
541530
end
542531

532+
notify.publish(notify.STATUS, 'Thinking')
533+
543534
opts.agent = opts.agent and opts.agent:gsub(':' .. provider_name .. '$', '')
544535
opts.model = opts.model:gsub(':' .. provider_name .. '$', '')
545536
local headers = self:authenticate(provider_name)
@@ -560,12 +551,10 @@ function Client:ask(prompt, opts)
560551
args.stream = stream_func
561552
end
562553

563-
notify.publish(notify.STATUS, 'Thinking')
564-
565554
local response, err = utils.curl_post(provider.get_url(opts), args)
566555

567556
if self.current_job ~= job_id then
568-
return nil, nil, nil
557+
return
569558
end
570559

571560
self.current_job = nil
@@ -628,20 +617,12 @@ function Client:ask(prompt, opts)
628617
return
629618
end
630619

631-
local full_references = generate_references(references)
632-
log.info('References: ', full_references)
633-
if full_references ~= '' then
634-
full_response = full_response .. full_references
635-
if on_progress then
636-
on_progress(full_references)
637-
end
638-
end
639-
640620
log.trace('Full response: ', full_response)
641621
log.debug('Last message: ', last_message)
642622

643623
return full_response,
644-
last_message and last_message.usage and last_message.usage.total_tokens,
624+
references,
625+
last_message and last_message.usage and last_message.usage.total_tokens or 0,
645626
max_tokens
646627
end
647628

lua/CopilotChat/context.lua

Lines changed: 90 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
---@field outline string?
1515
---@field symbols table<string, CopilotChat.context.symbol>?
1616
---@field embedding table<number>?
17-
---@field score?
17+
---@field score number?
1818

1919
local async = require('plenary.async')
2020
local log = require('plenary.log')
@@ -66,8 +66,8 @@ local OFF_SIDE_RULE_LANGUAGES = {
6666
'fsharp',
6767
}
6868

69-
local MIN_SYMBOL_SIMILARITY = 0.4 -- Symbol-based matching can be more lenient
70-
local MIN_SEMANTIC_SIMILARITY = 0.3 -- Cosine similarity should be stricter for relevance
69+
local MIN_SYMBOL_SIMILARITY = 0.3
70+
local MIN_SEMANTIC_SIMILARITY = 0.4
7171
local MULTI_FILE_THRESHOLD = 5
7272
local MAX_FILES = 2500
7373

@@ -115,65 +115,106 @@ local function data_ranked_by_relatedness(query, data, min_similarity)
115115
return results
116116
end
117117

118-
--- Rank data by symbols
119-
---@param prompt string
120-
---@param data table<CopilotChat.context.embed>
121-
---@param min_similarity number
122-
---@return table<CopilotChat.context.embed>
123-
local function data_ranked_by_symbols(prompt, data, min_similarity)
124-
local query_terms = {}
125-
for term in prompt:lower():gmatch('%w+') do
126-
query_terms[term] = true
118+
-- Create trigrams from text (e.g., "hello" -> {"hel", "ell", "llo"})
119+
local function get_trigrams(text)
120+
local trigrams = {}
121+
text = text:lower()
122+
for i = 1, #text - 2 do
123+
trigrams[text:sub(i, i + 2)] = true
124+
end
125+
return trigrams
126+
end
127+
128+
-- Calculate Jaccard similarity between two trigram sets
129+
local function trigram_similarity(set1, set2)
130+
local intersection = 0
131+
local union = 0
132+
133+
-- Count intersection and union
134+
for trigram in pairs(set1) do
135+
if set2[trigram] then
136+
intersection = intersection + 1
137+
end
138+
union = union + 1
139+
end
140+
141+
for trigram in pairs(set2) do
142+
if not set1[trigram] then
143+
union = union + 1
144+
end
145+
end
146+
147+
return intersection / union
148+
end
149+
150+
local function data_ranked_by_symbols(query, data, min_similarity)
151+
-- Get query trigrams including compound versions
152+
local query_trigrams = {}
153+
154+
-- Add trigrams for each word
155+
for term in query:lower():gmatch('%w+') do
156+
for trigram in pairs(get_trigrams(term)) do
157+
query_trigrams[trigram] = true
158+
end
159+
end
160+
161+
-- Add trigrams for compound query
162+
local compound_query = query:lower():gsub('[^%w]', '')
163+
for trigram in pairs(get_trigrams(compound_query)) do
164+
query_trigrams[trigram] = true
127165
end
128166

129167
local results = {}
130-
for _, entry in ipairs(data) do
131-
local total_terms = 0
132-
local matched_terms = 0
133-
local filename = entry.filename and entry.filename:lower() or ''
134-
135-
-- Calculate similarity score based on term matches
136-
for term in pairs(query_terms) do
137-
total_terms = total_terms + 1
138-
139-
-- Filename matches
140-
if filename:find(term, 1, true) then
141-
matched_terms = matched_terms + 1
142-
if vim.fn.fnamemodify(filename, ':t'):gsub('%..*$', '') == term then
143-
matched_terms = matched_terms + 0.5 -- Bonus for exact filename match
144-
end
145-
end
168+
local max_score = 0
146169

147-
-- Symbol matches
148-
if entry.symbols then
149-
for _, symbol in ipairs(entry.symbols) do
150-
if symbol.name and symbol.name:lower():find(term, 1, true) then
151-
matched_terms = matched_terms + 1
152-
if symbol.name:lower() == term then
153-
matched_terms = matched_terms + 0.5 -- Bonus for exact symbol match
154-
end
155-
end
156-
if symbol.signature and symbol.signature:lower():find(term, 1, true) then
157-
matched_terms = matched_terms + 0.5 -- Partial credit for signature matches
158-
end
170+
for _, entry in ipairs(data) do
171+
local score = 0
172+
local basename = vim.fn.fnamemodify(entry.filename, ':t'):gsub('%..*$', '')
173+
174+
-- Get trigrams for basename and compound version
175+
local file_trigrams = get_trigrams(basename)
176+
local compound_trigrams = get_trigrams(basename:gsub('[^%w]', ''))
177+
178+
-- Calculate similarities
179+
local name_sim = trigram_similarity(query_trigrams, file_trigrams)
180+
local compound_sim = trigram_similarity(query_trigrams, compound_trigrams)
181+
182+
-- Take best match
183+
score = math.max(name_sim, compound_sim)
184+
185+
-- Add symbol matches
186+
if entry.symbols then
187+
local symbol_score = 0
188+
for _, symbol in ipairs(entry.symbols) do
189+
if symbol.name then
190+
local symbol_trigrams = get_trigrams(symbol.name)
191+
local sym_sim = trigram_similarity(query_trigrams, symbol_trigrams)
192+
symbol_score = math.max(symbol_score, sym_sim)
159193
end
160194
end
195+
score = score + (symbol_score * 0.5) -- Weight symbol matches less
161196
end
162197

163-
-- Calculate similarity score (0 to 1 range)
164-
local similarity = matched_terms / (total_terms * 2) -- Denominator accounts for potential bonuses
198+
if score > 0 then
199+
max_score = math.max(max_score, score)
200+
table.insert(results, vim.tbl_extend('force', entry, { score = score }))
201+
end
202+
end
165203

166-
-- Only include results above similarity threshold
167-
if similarity >= min_similarity then
168-
table.insert(results, vim.tbl_extend('force', entry, { score = similarity }))
204+
-- Normalize and filter results
205+
local filtered_results = {}
206+
for _, result in ipairs(results) do
207+
result.score = result.score / max_score
208+
if result.score >= min_similarity then
209+
table.insert(filtered_results, result)
169210
end
170211
end
171212

172-
table.sort(results, function(a, b)
213+
table.sort(filtered_results, function(a, b)
173214
return a.score > b.score
174215
end)
175216

176-
return results
217+
return filtered_results
177218
end
178219

179220
--- Get the full signature of a declaration
@@ -603,6 +644,8 @@ function M.filter_embeddings(prompt, model, embeddings)
603644
return embeddings
604645
end
605646

647+
notify.publish(notify.STATUS, 'Ranking embeddings')
648+
606649
-- Rank embeddings by symbols
607650
embeddings = data_ranked_by_symbols(prompt, embeddings, MIN_SYMBOL_SIMILARITY)
608651
log.debug('Ranked data:', #embeddings)

lua/CopilotChat/init.lua

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -659,21 +659,22 @@ function M.ask(prompt, config)
659659
return
660660
end
661661

662-
local ask_ok, response, token_count, token_max_count = pcall(client.ask, client, prompt, {
663-
history = history,
664-
selection = selection,
665-
embeddings = filtered_embeddings,
666-
system_prompt = system_prompt,
667-
model = selected_model,
668-
agent = selected_agent,
669-
temperature = config.temperature,
670-
on_progress = vim.schedule_wrap(function(token)
671-
if not config.headless then
672-
state.chat:append(token)
673-
end
674-
has_output = true
675-
end),
676-
})
662+
local ask_ok, response, references, token_count, token_max_count =
663+
pcall(client.ask, client, prompt, {
664+
history = history,
665+
selection = selection,
666+
embeddings = filtered_embeddings,
667+
system_prompt = system_prompt,
668+
model = selected_model,
669+
agent = selected_agent,
670+
temperature = config.temperature,
671+
on_progress = vim.schedule_wrap(function(token)
672+
if not config.headless then
673+
state.chat:append(token)
674+
end
675+
has_output = true
676+
end),
677+
})
677678

678679
async.util.scheduler()
679680

@@ -691,6 +692,7 @@ function M.ask(prompt, config)
691692

692693
if not config.headless then
693694
state.last_response = response
695+
state.chat.references = references
694696
state.chat.token_count = token_count
695697
state.chat.token_max_count = token_max_count
696698
end

lua/CopilotChat/ui/chat.lua

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ end
6565
---@field spinner CopilotChat.ui.Spinner
6666
---@field sections table<CopilotChat.ui.Chat.Section>
6767
---@field config CopilotChat.config.shared
68+
---@field references table
6869
---@field token_count number?
6970
---@field token_max_count number?
7071
local Chat = class(function(self, question_header, answer_header, separator, help, on_buf_create)
@@ -82,6 +83,7 @@ local Chat = class(function(self, question_header, answer_header, separator, hel
8283

8384
-- Variables
8485
self.config = {}
86+
self.references = {}
8587
self.token_count = nil
8688
self.token_max_count = nil
8789
end, Overlay)
@@ -250,6 +252,16 @@ function Chat:render()
250252
msg = msg .. self.token_count .. '/' .. self.token_max_count .. ' tokens used'
251253
end
252254

255+
if self.references and #self.references > 0 then
256+
if msg ~= '' then
257+
msg = msg .. '\n'
258+
end
259+
msg = msg .. '\nReferences:\n'
260+
for _, ref in ipairs(self.references) do
261+
msg = msg .. ' ' .. ref.name .. '\n'
262+
end
263+
end
264+
253265
self:show_help(msg, last_section.start_line - last_section.end_line - 1)
254266
else
255267
self:clear_help()
@@ -438,6 +450,7 @@ end
438450

439451
function Chat:clear()
440452
self:validate()
453+
self.references = {}
441454
self.token_count = nil
442455
self.token_max_count = nil
443456
vim.bo[self.bufnr].modifiable = true

0 commit comments

Comments
 (0)