1414
1515from __future__ import annotations
1616
17+ import threading
1718from threading import Thread
18- from typing import (TYPE_CHECKING , Any , Callable , Dict , FrozenSet , Optional ,
19- Set , Tuple )
19+ from time import perf_counter_ns
20+ from typing import (TYPE_CHECKING , Any , Callable , ClassVar , Dict , FrozenSet ,
21+ Optional , Set )
22+
23+ from aws_advanced_python_wrapper .utils .notifications import HostEvent
24+ from aws_advanced_python_wrapper .utils .utils import Utils
2025
2126if TYPE_CHECKING :
2227 from aws_advanced_python_wrapper .driver_dialect import DriverDialect
2328 from aws_advanced_python_wrapper .plugin_service import PluginService
29+ from aws_advanced_python_wrapper .hostinfo import HostInfo
2430 from aws_advanced_python_wrapper .pep249 import Connection
2531
2632 from aws_advanced_python_wrapper .utils .rds_url_type import RdsUrlType
2935from _weakrefset import WeakSet
3036
3137from aws_advanced_python_wrapper .errors import FailoverError
32- from aws_advanced_python_wrapper .hostinfo import HostInfo , HostRole
3338from aws_advanced_python_wrapper .pep249_methods import DbApiMethod
3439from aws_advanced_python_wrapper .plugin import Plugin , PluginFactory
3540from aws_advanced_python_wrapper .utils .log import Logger
3944
4045
4146class OpenedConnectionTracker :
42- _opened_connections : Dict [str , WeakSet ] = {}
43- _rds_utils = RdsUtils ()
47+ _opened_connections : ClassVar [Dict [str , WeakSet ]] = {}
48+ _lock : ClassVar [threading .Lock ] = threading .Lock ()
49+ _rds_utils : ClassVar [RdsUtils ] = RdsUtils ()
50+ _prune_thread : ClassVar [Optional [Thread ]] = None
51+ _prune_thread_started : ClassVar [bool ] = False
52+ _shutdown_event : ClassVar [threading .Event ] = threading .Event ()
53+ _safe_to_check_closed_classes : ClassVar [Set [str ]] = {"psycopg" }
54+ _default_sleep_time : ClassVar [int ] = 30
55+
56+ @classmethod
57+ def _start_prune_thread (cls ):
58+ with cls ._lock :
59+ if not cls ._prune_thread_started :
60+ cls ._prune_thread_started = True
61+ cls ._prune_thread = Thread (daemon = True , target = cls ._prune_connections_loop )
62+ cls ._prune_thread .start ()
63+
64+ @classmethod
65+ def release_resources (cls ):
66+ cls ._shutdown_event .set ()
67+ with cls ._lock :
68+ thread_to_join = cls ._prune_thread
69+ if thread_to_join is not None :
70+ thread_to_join .join ()
71+ with cls ._lock :
72+ cls ._opened_connections .clear ()
73+
74+ @classmethod
75+ def _prune_connections_loop (cls ):
76+ while not cls ._shutdown_event .is_set ():
77+ try :
78+ cls ._prune_connections ()
79+ if cls ._shutdown_event .wait (timeout = cls ._default_sleep_time ):
80+ break
81+ except Exception :
82+ pass
83+
84+ @classmethod
85+ def _prune_connections (cls ):
86+ with cls ._lock :
87+ opened_connections = list (cls ._opened_connections .items ())
88+
89+ to_remove_by_host = {}
90+ for host , conn_set in opened_connections :
91+ to_remove = []
92+ for conn in list (conn_set ):
93+ if conn is None :
94+ to_remove .append (conn )
95+ else :
96+ try :
97+ # The following classes do not check connection validity via a DB server call
98+ # so it is safe to check whether connection is already closed.
99+ if any (safe_class in conn .__module__ for safe_class in cls ._safe_to_check_closed_classes ) and conn .is_closed ():
100+ to_remove .append (conn )
101+ except Exception :
102+ pass
103+
104+ if to_remove :
105+ to_remove_by_host [host ] = (conn_set , to_remove )
106+
107+ with cls ._lock :
108+ for host , (conn_set , to_remove ) in to_remove_by_host .items ():
109+ for conn in to_remove :
110+ conn_set .discard (conn )
111+
112+ # Remove empty connection sets
113+ if not conn_set and host in cls ._opened_connections :
114+ del cls ._opened_connections [host ]
44115
45116 def populate_opened_connection_set (self , host_info : HostInfo , conn : Connection ):
46117 """
@@ -56,8 +127,8 @@ def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection):
56127 self ._track_connection (host_info .as_alias (), conn )
57128 return
58129
59- instance_endpoint : Optional [str ] = next (( alias for alias in aliases if self . _rds_utils . is_rds_instance ( self . _rds_utils . remove_port ( alias ))),
60- None )
130+ instance_endpoint : Optional [str ] = next (
131+ ( alias for alias in aliases if self . _rds_utils . is_rds_instance ( self . _rds_utils . remove_port ( alias ))), None )
61132 if not instance_endpoint :
62133 logger .debug ("OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet" )
63134 return
@@ -73,7 +144,7 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host:
73144 """
74145
75146 if host_info :
76- self .invalidate_all_connections (host = frozenset (host_info .as_alias ()))
147+ self .invalidate_all_connections (host = frozenset ([ host_info .as_alias ()] ))
77148 self .invalidate_all_connections (host = host_info .as_aliases ())
78149 return
79150
@@ -89,27 +160,42 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host:
89160 if not instance_endpoint :
90161 return
91162
92- connection_set : Optional [WeakSet ] = self ._opened_connections .get (instance_endpoint )
93- if connection_set is not None :
163+ with self ._lock :
164+ connection_set : Optional [WeakSet ] = self ._opened_connections .get (instance_endpoint )
165+ connections_list = list (connection_set ) if connection_set is not None else None
166+
167+ if connections_list is not None :
94168 self ._log_connection_set (instance_endpoint , connection_set )
95- self ._invalidate_connections (connection_set )
169+ self ._invalidate_connections (connections_list )
96170
97- def _track_connection (self , instance_endpoint : str , conn : Connection ):
98- connection_set : Optional [ WeakSet ] = self . _opened_connections . get ( instance_endpoint )
99- if connection_set is None :
100- connection_set = WeakSet ()
101- connection_set . add ( conn )
102- self . _opened_connections [ instance_endpoint ] = connection_set
171+ def remove_connection_tracking (self , host_info : HostInfo , connection : Connection | None ):
172+ if not connection :
173+ return
174+
175+ if self . _rds_utils . is_rds_instance ( host_info . host ):
176+ host = host_info . as_alias ()
103177 else :
104- connection_set .add (conn )
178+ host = next ((alias for alias in host_info .as_aliases ()
179+ if self ._rds_utils .is_rds_instance (self ._rds_utils .remove_port (alias ))), "" )
180+
181+ if not host :
182+ return
183+
184+ with self ._lock :
185+ connection_set = self ._opened_connections .get (host )
186+ if connection_set :
187+ connection_set .discard (connection )
105188
189+ def _track_connection (self , instance_endpoint : str , conn : Connection ):
190+ with self ._lock :
191+ connection_set = self ._opened_connections .setdefault (instance_endpoint , WeakSet ())
192+ connection_set .add (conn )
193+ self ._start_prune_thread ()
106194 self .log_opened_connections ()
107195
108196 @staticmethod
109- def _task (connection_set : WeakSet ):
110- while connection_set is not None and len (connection_set ) > 0 :
111- conn_reference = connection_set .pop ()
112-
197+ def _task (connections_list : list ):
198+ for conn_reference in connections_list :
113199 if conn_reference is None :
114200 continue
115201
@@ -119,37 +205,38 @@ def _task(connection_set: WeakSet):
119205 # Swallow this exception, current connection should be useless anyway
120206 pass
121207
122- def _invalidate_connections (self , connection_set : WeakSet ):
208+ def _invalidate_connections (self , connections_list : list ):
123209 invalidate_connection_thread : Thread = Thread (daemon = True , target = self ._task ,
124- args = [connection_set ]) # type: ignore
210+ args = [connections_list ]) # type: ignore
125211 invalidate_connection_thread .start ()
126212
127213 def log_opened_connections (self ):
128- msg = ""
129- for key , conn_set in self ._opened_connections .items ():
130- conn = ""
131- for item in list (conn_set ):
132- conn += f"\n \t \t { item } "
214+ with self ._lock :
215+ opened_connections = [(key , list (conn_set )) for key , conn_set in self ._opened_connections .items ()]
133216
134- msg += f"\t [{ key } : { conn } ]"
217+ msg_parts = []
218+ for key , conn_list in opened_connections :
219+ conn_parts = [f"\n \t \t { item } " for item in conn_list ]
220+ conn = "" .join (conn_parts )
221+ msg_parts .append (f"\t [{ key } : { conn } ]" )
135222
223+ msg = "" .join (msg_parts )
136224 return logger .debug ("OpenedConnectionTracker.OpenedConnectionsTracked" , msg )
137225
138226 def _log_connection_set (self , host : str , conn_set : Optional [WeakSet ]):
139227 if conn_set is None or len (conn_set ) == 0 :
140228 return
141229
142- conn = ""
143- for item in list (conn_set ):
144- conn += f"\n \t \t { item } "
145-
230+ conn_parts = [f"\n \t \t { item } " for item in list (conn_set )]
231+ conn = "" .join (conn_parts )
146232 msg = host + f"[{ conn } \n ]"
147233 logger .debug ("OpenedConnectionTracker.InvalidatingConnections" , msg )
148234
149235
150236class AuroraConnectionTrackerPlugin (Plugin ):
151- _current_writer : Optional [HostInfo ] = None
152- _need_update_current_writer : bool = False
237+ _host_list_refresh_end_time_nano : ClassVar [int ] = 0
238+ _refresh_lock : ClassVar [threading .Lock ] = threading .Lock ()
239+ _TOPOLOGY_CHANGES_EXPECTED_TIME_NANO : ClassVar [int ] = 3 * 60 * 1_000_000_000 # 3 minutes
153240
154241 @property
155242 def subscribed_methods (self ) -> Set [str ]:
@@ -164,6 +251,8 @@ def __init__(self,
164251 self ._props = props
165252 self ._rds_utils = rds_utils
166253 self ._tracker = tracker
254+ self ._current_writer : Optional [HostInfo ] = None
255+ self ._need_update_current_writer : bool = False
167256 self ._subscribed_methods : Set [str ] = {DbApiMethod .CONNECT .method_name ,
168257 DbApiMethod .CONNECTION_CLOSE .method_name ,
169258 DbApiMethod .CONNECT .method_name ,
@@ -192,26 +281,67 @@ def connect(
192281 return conn
193282
194283 def execute (self , target : object , method_name : str , execute_func : Callable , * args : Any , ** kwargs : Any ) -> Any :
284+ current_host = self ._plugin_service .current_host_info
195285 if self ._current_writer is None or self ._need_update_current_writer :
196- self ._current_writer = self . _get_writer (self ._plugin_service .all_hosts )
286+ self ._current_writer = Utils . get_writer (self ._plugin_service .all_hosts )
197287 self ._need_update_current_writer = False
198288
199289 try :
200- return execute_func ()
290+ if not method_name == DbApiMethod .CONNECTION_CLOSE .method_name :
291+ need_refresh_host_lists = False
292+ with AuroraConnectionTrackerPlugin ._refresh_lock :
293+ local_host_list_refresh_end_time_nano = AuroraConnectionTrackerPlugin ._host_list_refresh_end_time_nano
294+ if local_host_list_refresh_end_time_nano > 0 :
295+ if local_host_list_refresh_end_time_nano > perf_counter_ns ():
296+ # The time specified in hostListRefreshThresholdTimeNano isn't yet reached.
297+ # Need to continue to refresh host list.
298+ need_refresh_host_lists = True
299+ else :
300+ # The time specified in hostListRefreshThresholdTimeNano is reached, and we can stop further refreshes
301+ # of host list.
302+ AuroraConnectionTrackerPlugin ._host_list_refresh_end_time_nano = 0
303+
304+ if self ._need_update_current_writer or need_refresh_host_lists :
305+ # Calling this method may effectively close/abort a current connection
306+ self ._check_writer_changed (need_refresh_host_lists )
307+
308+ result = execute_func ()
309+ if method_name == DbApiMethod .CONNECTION_CLOSE .method_name :
310+ self ._tracker .remove_connection_tracking (current_host , self ._plugin_service .current_connection )
311+ return result
201312
202313 except Exception as e :
203- # Check that e is a FailoverError and that the writer has changed
204- if isinstance (e , FailoverError ) and self ._get_writer (self ._plugin_service .all_hosts ) != self ._current_writer :
205- self ._tracker .invalidate_all_connections (host_info = self ._current_writer )
206- self ._tracker .log_opened_connections ()
207- self ._need_update_current_writer = True
208- raise e
314+ if isinstance (e , FailoverError ):
315+ with AuroraConnectionTrackerPlugin ._refresh_lock :
316+ AuroraConnectionTrackerPlugin ._host_list_refresh_end_time_nano = (
317+ perf_counter_ns () + AuroraConnectionTrackerPlugin ._TOPOLOGY_CHANGES_EXPECTED_TIME_NANO )
318+ # Calling this method may effectively close/abort a current connection
319+ self ._check_writer_changed (True )
320+ raise
321+
322+ def _check_writer_changed (self , need_refresh_host_lists : bool ):
323+ if need_refresh_host_lists :
324+ self ._plugin_service .refresh_host_list ()
325+
326+ host_info_after_failover = Utils .get_writer (self ._plugin_service .all_hosts )
327+ if host_info_after_failover is None :
328+ return
329+
330+ if self ._current_writer is None :
331+ self ._current_writer = host_info_after_failover
332+ self ._need_update_current_writer = False
333+ elif not self ._current_writer .get_host_and_port () == host_info_after_failover .get_host_and_port ():
334+ self ._tracker .invalidate_all_connections (host_info = self ._current_writer )
335+ self ._tracker .log_opened_connections ()
336+ self ._current_writer = host_info_after_failover
337+ self ._need_update_current_writer = False
209338
210- def _get_writer (self , hosts : Tuple [HostInfo , ...]) -> Optional [HostInfo ]:
211- for host in hosts :
212- if host .role == HostRole .WRITER :
213- return host
214- return None
339+ def notify_host_list_changed (self , changes : Dict [str , Set [HostEvent ]]):
340+ for node , node_changes in changes .items ():
341+ if HostEvent .CONVERTED_TO_READER in node_changes :
342+ self ._tracker .invalidate_all_connections (host = frozenset ([node ]))
343+ if HostEvent .CONVERTED_TO_WRITER in node_changes :
344+ self ._need_update_current_writer = True
215345
216346
217347class AuroraConnectionTrackerPluginFactory (PluginFactory ):
0 commit comments