Skip to content

Commit 2dde388

Browse files
committed
refactor: make client module stateless
The client module has been refactored to be stateless by making it a singleton and moving provider initialization to a separate method. This simplifies the overall architecture and makes the code more maintainable.

Additional changes:
- Improved reference handling and generation
- Updated header naming convention for Copilot requests
- Fixed embedding context resolution to pass proper params
- Removed redundant client state from main module
1 parent 49aae4d commit 2dde388

5 files changed

Lines changed: 90 additions & 65 deletions

File tree

lua/CopilotChat/client.lua

Lines changed: 41 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -187,6 +187,21 @@ local function generate_embedding_request(inputs, threshold)
187187
end, inputs)
188188
end
189189

190+
local function generate_references(references)
191+
local out = ''
192+
193+
for _, reference in ipairs(references) do
194+
out = out .. '\n[' .. reference.name .. '](' .. reference.url .. ')'
195+
end
196+
197+
if out == '' then
198+
return out
199+
end
200+
201+
out = '\n\n**`References:`**' .. out
202+
return out
203+
end
204+
190205
---@class CopilotChat.Client : Class
191206
---@field providers table<string, CopilotChat.Provider>
192207
---@field provider_cache table<string, table>
@@ -198,17 +213,12 @@ end
198213
---@field token table?
199214
---@field sessionid string?
200215
---@field machineid string
201-
local Client = class(function(self, providers)
202-
self.providers = providers
216+
local Client = class(function(self)
217+
self.providers = {}
203218
self.embedding_cache = {}
204219
self.models = nil
205220
self.agents = nil
206-
207221
self.provider_cache = {}
208-
for provider_name, _ in pairs(providers) do
209-
self.provider_cache[provider_name] = {}
210-
end
211-
212222
self.current_job = nil
213223
self.expires_at = nil
214224
self.headers = nil
@@ -359,6 +369,14 @@ function Client:ask(prompt, opts)
359369
log.debug('Tokenizer: ', tokenizer)
360370
tiktoken.load(tokenizer)
361371

372+
local references = {}
373+
for _, embed in ipairs(embeddings) do
374+
table.insert(references, {
375+
name = embed.filename,
376+
url = embed.filename,
377+
})
378+
end
379+
362380
local generated_messages = {}
363381
local selection_messages = generate_selection_messages(selection)
364382
local embeddings_messages = generate_embeddings_messages(embeddings)
@@ -408,7 +426,6 @@ function Client:ask(prompt, opts)
408426
local errored = false
409427
local finished = false
410428
local full_response = ''
411-
local full_references = ''
412429

413430
local function finish_stream(err, job)
414431
if err then
@@ -450,14 +467,10 @@ function Client:ask(prompt, opts)
450467
for _, reference in ipairs(content.copilot_references) do
451468
local metadata = reference.metadata
452469
if metadata and metadata.display_name and metadata.display_url then
453-
full_references = full_references
454-
.. '\n'
455-
.. '['
456-
.. metadata.display_name
457-
.. ']'
458-
.. '('
459-
.. metadata.display_url
460-
.. ')'
470+
table.insert(references, {
471+
name = metadata.display_name,
472+
url = metadata.display_url,
473+
})
461474
end
462475
end
463476
end
@@ -615,8 +628,9 @@ function Client:ask(prompt, opts)
615628
return
616629
end
617630

631+
local full_references = generate_references(references)
632+
log.info('References: ', full_references)
618633
if full_references ~= '' then
619-
full_references = '\n\n**`References:`**' .. full_references
620634
full_response = full_response .. full_references
621635
if on_progress then
622636
on_progress(full_references)
@@ -833,4 +847,13 @@ function Client:running()
833847
return self.current_job ~= nil
834848
end
835849

836-
return Client
850+
--- Load providers to client
851+
function Client:load_providers(providers)
852+
self.providers = providers
853+
for provider_name, _ in pairs(providers) do
854+
self.provider_cache[provider_name] = {}
855+
end
856+
end
857+
858+
--- @type CopilotChat.Client
859+
return Client()

lua/CopilotChat/config/contexts.lua

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ local utils = require('CopilotChat.utils')
55
---@class CopilotChat.config.context
66
---@field description string?
77
---@field input fun(callback: fun(input: string?), source: CopilotChat.source)?
8-
---@field resolve fun(input: string?, source: CopilotChat.source):table<CopilotChat.context.embed>
8+
---@field resolve fun(input: string?, source: CopilotChat.source, prompt: string, model: string):table<CopilotChat.context.embed>
99

1010
---@type table<string, CopilotChat.config.context>
1111
return {

lua/CopilotChat/config/providers.lua

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -31,9 +31,10 @@ local EDITOR_VERSION = 'Neovim/'
3131
.. '.'
3232
.. vim.version().patch
3333

34-
--- Get the github oauth cached token
35-
---@return string|nil
3634
local cached_github_token = nil
35+
36+
--- Get the github copilot oauth cached token (gu_ token)
37+
---@return string
3738
local function get_github_token()
3839
if cached_github_token then
3940
return cached_github_token
@@ -84,10 +85,11 @@ M.copilot = {
8485

8586
get_headers = function(token)
8687
return {
87-
['Authorization'] = 'Bearer ' .. token,
88-
['Editor-Version'] = EDITOR_VERSION,
89-
['Copilot-Integration-Id'] = 'vscode-chat',
90-
['Content-Type'] = 'application/json',
88+
['authorization'] = 'Bearer ' .. token,
89+
['editor-version'] = EDITOR_VERSION,
90+
['editor-plugin-version'] = 'CopilotChat.nvim/*',
91+
['copilot-integration-id'] = 'vscode-chat',
92+
['content-type'] = 'application/json',
9193
}
9294
end,
9395

lua/CopilotChat/context.lua

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,9 +14,11 @@
1414
---@field outline string?
1515
---@field symbols table<string, CopilotChat.context.symbol>?
1616
---@field embedding table<number>?
17+
---@field score?
1718

1819
local async = require('plenary.async')
1920
local log = require('plenary.log')
21+
local client = require('CopilotChat.client')
2022
local notify = require('CopilotChat.notify')
2123
local utils = require('CopilotChat.utils')
2224
local file_cache = {}
@@ -100,7 +102,7 @@ local function data_ranked_by_relatedness(query, data, top_n)
100102
return vim.tbl_extend(
101103
'force',
102104
item,
103-
{ score = spatial_distance_cosine(item.embedding, query.embedding) }
105+
{ score = item.score or spatial_distance_cosine(item.embedding, query.embedding) }
104106
)
105107
end, data)
106108

@@ -583,12 +585,11 @@ function M.quickfix()
583585
end
584586

585587
--- Filter embeddings based on the query
586-
---@param client CopilotChat.Client
587588
---@param prompt string
588589
---@param model string
589590
---@param embeddings table<CopilotChat.context.embed>
590591
---@return table<CopilotChat.context.embed>
591-
function M.filter_embeddings(client, prompt, model, embeddings)
592+
function M.filter_embeddings(prompt, model, embeddings)
592593
-- If we dont need to embed anything, just return directly
593594
if #embeddings < MULTI_FILE_THRESHOLD then
594595
return embeddings

lua/CopilotChat/init.lua

Lines changed: 36 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
local async = require('plenary.async')
22
local log = require('plenary.log')
33
local context = require('CopilotChat.context')
4+
local client = require('CopilotChat.client')
45
local utils = require('CopilotChat.utils')
56

67
local M = {}
@@ -12,16 +13,13 @@ local WORD = '([^%s]+)'
1213
--- @field winnr number
1314

1415
--- @class CopilotChat.state
15-
--- @field client CopilotChat.Client?
1616
--- @field source CopilotChat.source?
1717
--- @field last_prompt string?
1818
--- @field last_response string?
1919
--- @field chat CopilotChat.ui.Chat?
2020
--- @field diff CopilotChat.ui.Diff?
2121
--- @field overlay CopilotChat.ui.Overlay?
2222
local state = {
23-
client = nil,
24-
2523
-- Current state tracking
2624
source = nil,
2725

@@ -221,9 +219,10 @@ end
221219

222220
--- Resolve the embeddings from the prompt.
223221
---@param prompt string
222+
---@param model string
224223
---@param config CopilotChat.config.shared
225224
---@return table<CopilotChat.context.embed>, string
226-
function M.resolve_embeddings(prompt, config)
225+
function M.resolve_embeddings(prompt, model, config)
227226
local contexts = {}
228227
local function parse_context(prompt_context)
229228
local split = vim.split(prompt_context, ':')
@@ -262,7 +261,9 @@ function M.resolve_embeddings(prompt, config)
262261
local embeddings = utils.ordered_map()
263262
for _, context_data in ipairs(contexts) do
264263
local context_value = M.config.contexts[context_data.name]
265-
for _, embedding in ipairs(context_value.resolve(context_data.input, state.source or {})) do
264+
for _, embedding in
265+
ipairs(context_value.resolve(context_data.input, state.source or {}, prompt, model))
266+
do
266267
if embedding then
267268
embeddings:set(embedding.filename, embedding)
268269
end
@@ -276,7 +277,7 @@ end
276277
---@param prompt string
277278
---@param config CopilotChat.config.shared
278279
function M.resolve_agent(prompt, config)
279-
local agents = vim.tbl_keys(state.client:list_agents())
280+
local agents = vim.tbl_keys(client:list_agents())
280281
local selected_agent = config.agent
281282
prompt = prompt:gsub('@' .. WORD, function(match)
282283
if vim.tbl_contains(agents, match) then
@@ -295,7 +296,7 @@ end
295296
function M.resolve_model(prompt, config)
296297
local models = vim.tbl_map(function(model)
297298
return model.id
298-
end, state.client:list_models())
299+
end, client:list_models())
299300

300301
local selected_model = config.model
301302
prompt = prompt:gsub('%$' .. WORD, function(match)
@@ -393,8 +394,8 @@ end
393394
---@param callback function(table)
394395
function M.complete_items(callback)
395396
async.run(function()
396-
local models = state.client:list_models()
397-
local agents = state.client:list_agents()
397+
local models = client:list_models()
398+
local agents = client:list_agents()
398399
local prompts_to_use = M.prompts()
399400
local items = {}
400401

@@ -536,7 +537,7 @@ end
536537
--- Select default Copilot GPT model.
537538
function M.select_model()
538539
async.run(function()
539-
local models = state.client:list_models()
540+
local models = client:list_models()
540541
local choices = vim.tbl_map(function(model)
541542
return {
542543
id = model.id,
@@ -567,7 +568,7 @@ end
567568
--- Select default Copilot agent.
568569
function M.select_agent()
569570
async.run(function()
570-
local agents = state.client:list_agents()
571+
local agents = client:list_agents()
571572
local choices = vim.tbl_map(function(agent)
572573
return {
573574
id = agent.id,
@@ -613,7 +614,7 @@ function M.ask(prompt, config)
613614
if not config.headless then
614615
if config.clear_chat_on_new_prompt then
615616
M.stop(true)
616-
elseif state.client:stop() then
617+
elseif client:stop() then
617618
finish()
618619
end
619620

@@ -642,13 +643,13 @@ function M.ask(prompt, config)
642643
local selection = M.get_selection(config)
643644

644645
local ok, err = pcall(async.run, function()
645-
local embeddings, prompt = M.resolve_embeddings(prompt, config)
646646
local selected_agent, prompt = M.resolve_agent(prompt, config)
647647
local selected_model, prompt = M.resolve_model(prompt, config)
648+
local embeddings, prompt = M.resolve_embeddings(prompt, selected_model, config)
648649

649650
local has_output = false
650651
local query_ok, filtered_embeddings =
651-
pcall(context.filter_embeddings, state.client, prompt, selected_model, embeddings)
652+
pcall(context.filter_embeddings, prompt, selected_model, embeddings)
652653

653654
if not query_ok then
654655
async.util.scheduler()
@@ -659,22 +660,21 @@ function M.ask(prompt, config)
659660
return
660661
end
661662

662-
local ask_ok, response, token_count, token_max_count =
663-
pcall(state.client.ask, state.client, prompt, {
664-
history = history,
665-
selection = selection,
666-
embeddings = filtered_embeddings,
667-
system_prompt = system_prompt,
668-
model = selected_model,
669-
agent = selected_agent,
670-
temperature = config.temperature,
671-
on_progress = vim.schedule_wrap(function(token)
672-
if not config.headless then
673-
state.chat:append(token)
674-
end
675-
has_output = true
676-
end),
677-
})
663+
local ask_ok, response, token_count, token_max_count = pcall(client.ask, client, prompt, {
664+
history = history,
665+
selection = selection,
666+
embeddings = filtered_embeddings,
667+
system_prompt = system_prompt,
668+
model = selected_model,
669+
agent = selected_agent,
670+
temperature = config.temperature,
671+
on_progress = vim.schedule_wrap(function(token)
672+
if not config.headless then
673+
state.chat:append(token)
674+
end
675+
has_output = true
676+
end),
677+
})
678678

679679
async.util.scheduler()
680680

@@ -716,12 +716,12 @@ end
716716
---@param reset boolean?
717717
function M.stop(reset)
718718
if reset then
719-
state.client:reset()
719+
client:reset()
720720
state.chat:clear()
721721
state.last_prompt = nil
722722
state.last_response = nil
723723
else
724-
state.client:stop()
724+
client:stop()
725725
end
726726

727727
finish(reset)
@@ -791,7 +791,7 @@ function M.load(name, history_path)
791791
},
792792
})
793793

794-
state.client:reset()
794+
client:reset()
795795
state.chat:clear()
796796
state.chat:load_history(history)
797797
log.info('Loaded history from ' .. history_path)
@@ -873,10 +873,9 @@ function M.setup(config)
873873
proxy = M.config.proxy,
874874
})
875875

876-
if state.client then
877-
state.client:stop()
878-
end
879-
state.client = require('CopilotChat.client')(M.config.providers)
876+
-- Load the providers
877+
client:stop()
878+
client:load_providers(M.config.providers)
880879

881880
if M.config.debug then
882881
M.log_level('debug')

0 commit comments

Comments (0)