11local async = require (' plenary.async' )
22local log = require (' plenary.log' )
33local context = require (' CopilotChat.context' )
4+ local client = require (' CopilotChat.client' )
45local utils = require (' CopilotChat.utils' )
56
67local M = {}
@@ -12,16 +13,13 @@ local WORD = '([^%s]+)'
1213--- @field winnr number
1314
1415--- @class CopilotChat.state
15- --- @field client CopilotChat.Client ?
1616--- @field source CopilotChat.source ?
1717--- @field last_prompt string ?
1818--- @field last_response string ?
1919--- @field chat CopilotChat.ui.Chat ?
2020--- @field diff CopilotChat.ui.Diff ?
2121--- @field overlay CopilotChat.ui.Overlay ?
2222local state = {
23- client = nil ,
24-
2523 -- Current state tracking
2624 source = nil ,
2725
221219
222220--- Resolve the embeddings from the prompt.
223221--- @param prompt string
222+ --- @param model string
224223--- @param config CopilotChat.config.shared
225224--- @return table<CopilotChat.context.embed> , string
226- function M .resolve_embeddings (prompt , config )
225+ function M .resolve_embeddings (prompt , model , config )
227226 local contexts = {}
228227 local function parse_context (prompt_context )
229228 local split = vim .split (prompt_context , ' :' )
@@ -262,7 +261,9 @@ function M.resolve_embeddings(prompt, config)
262261 local embeddings = utils .ordered_map ()
263262 for _ , context_data in ipairs (contexts ) do
264263 local context_value = M .config .contexts [context_data .name ]
265- for _ , embedding in ipairs (context_value .resolve (context_data .input , state .source or {})) do
264+ for _ , embedding in
265+ ipairs (context_value .resolve (context_data .input , state .source or {}, prompt , model ))
266+ do
266267 if embedding then
267268 embeddings :set (embedding .filename , embedding )
268269 end
276277--- @param prompt string
277278--- @param config CopilotChat.config.shared
278279function M .resolve_agent (prompt , config )
279- local agents = vim .tbl_keys (state . client :list_agents ())
280+ local agents = vim .tbl_keys (client :list_agents ())
280281 local selected_agent = config .agent
281282 prompt = prompt :gsub (' @' .. WORD , function (match )
282283 if vim .tbl_contains (agents , match ) then
295296function M .resolve_model (prompt , config )
296297 local models = vim .tbl_map (function (model )
297298 return model .id
298- end , state . client :list_models ())
299+ end , client :list_models ())
299300
300301 local selected_model = config .model
301302 prompt = prompt :gsub (' %$' .. WORD , function (match )
393394--- @param callback function (table )
394395function M .complete_items (callback )
395396 async .run (function ()
396- local models = state . client :list_models ()
397- local agents = state . client :list_agents ()
397+ local models = client :list_models ()
398+ local agents = client :list_agents ()
398399 local prompts_to_use = M .prompts ()
399400 local items = {}
400401
536537--- Select default Copilot GPT model.
537538function M .select_model ()
538539 async .run (function ()
539- local models = state . client :list_models ()
540+ local models = client :list_models ()
540541 local choices = vim .tbl_map (function (model )
541542 return {
542543 id = model .id ,
567568--- Select default Copilot agent.
568569function M .select_agent ()
569570 async .run (function ()
570- local agents = state . client :list_agents ()
571+ local agents = client :list_agents ()
571572 local choices = vim .tbl_map (function (agent )
572573 return {
573574 id = agent .id ,
@@ -613,7 +614,7 @@ function M.ask(prompt, config)
613614 if not config .headless then
614615 if config .clear_chat_on_new_prompt then
615616 M .stop (true )
616- elseif state . client :stop () then
617+ elseif client :stop () then
617618 finish ()
618619 end
619620
@@ -642,13 +643,13 @@ function M.ask(prompt, config)
642643 local selection = M .get_selection (config )
643644
644645 local ok , err = pcall (async .run , function ()
645- local embeddings , prompt = M .resolve_embeddings (prompt , config )
646646 local selected_agent , prompt = M .resolve_agent (prompt , config )
647647 local selected_model , prompt = M .resolve_model (prompt , config )
648+ local embeddings , prompt = M .resolve_embeddings (prompt , selected_model , config )
648649
649650 local has_output = false
650651 local query_ok , filtered_embeddings =
651- pcall (context .filter_embeddings , state . client , prompt , selected_model , embeddings )
652+ pcall (context .filter_embeddings , prompt , selected_model , embeddings )
652653
653654 if not query_ok then
654655 async .util .scheduler ()
@@ -659,22 +660,21 @@ function M.ask(prompt, config)
659660 return
660661 end
661662
662- local ask_ok , response , token_count , token_max_count =
663- pcall (state .client .ask , state .client , prompt , {
664- history = history ,
665- selection = selection ,
666- embeddings = filtered_embeddings ,
667- system_prompt = system_prompt ,
668- model = selected_model ,
669- agent = selected_agent ,
670- temperature = config .temperature ,
671- on_progress = vim .schedule_wrap (function (token )
672- if not config .headless then
673- state .chat :append (token )
674- end
675- has_output = true
676- end ),
677- })
663+ local ask_ok , response , token_count , token_max_count = pcall (client .ask , client , prompt , {
664+ history = history ,
665+ selection = selection ,
666+ embeddings = filtered_embeddings ,
667+ system_prompt = system_prompt ,
668+ model = selected_model ,
669+ agent = selected_agent ,
670+ temperature = config .temperature ,
671+ on_progress = vim .schedule_wrap (function (token )
672+ if not config .headless then
673+ state .chat :append (token )
674+ end
675+ has_output = true
676+ end ),
677+ })
678678
679679 async .util .scheduler ()
680680
@@ -716,12 +716,12 @@ end
716716--- @param reset boolean ?
717717function M .stop (reset )
718718 if reset then
719- state . client :reset ()
719+ client :reset ()
720720 state .chat :clear ()
721721 state .last_prompt = nil
722722 state .last_response = nil
723723 else
724- state . client :stop ()
724+ client :stop ()
725725 end
726726
727727 finish (reset )
@@ -791,7 +791,7 @@ function M.load(name, history_path)
791791 },
792792 })
793793
794- state . client :reset ()
794+ client :reset ()
795795 state .chat :clear ()
796796 state .chat :load_history (history )
797797 log .info (' Loaded history from ' .. history_path )
@@ -873,10 +873,9 @@ function M.setup(config)
873873 proxy = M .config .proxy ,
874874 })
875875
876- if state .client then
877- state .client :stop ()
878- end
879- state .client = require (' CopilotChat.client' )(M .config .providers )
876+ -- Load the providers
877+ client :stop ()
878+ client :load_providers (M .config .providers )
880879
881880 if M .config .debug then
882881 M .log_level (' debug' )
0 commit comments