3030 preserve_transaction_status_with_timeout
3131from aws_advanced_python_wrapper .utils .sliding_expiration_cache_container import \
3232 SlidingExpirationCacheContainer
33+ from aws_advanced_python_wrapper .utils .storage .storage_service import (
34+ StorageService , Topology )
3335
3436if TYPE_CHECKING :
3537 from aws_advanced_python_wrapper .driver_dialect import DriverDialect
5961
6062
6163class HostListProvider (Protocol ):
62- def refresh (self , connection : Optional [Connection ] = None ) -> Tuple [ HostInfo , ...] :
64+ def refresh (self , connection : Optional [Connection ] = None ) -> Topology :
6365 ...
6466
65- def force_refresh (self , connection : Optional [Connection ] = None ) -> Tuple [ HostInfo , ...] :
67+ def force_refresh (self , connection : Optional [Connection ] = None ) -> Topology :
6668 ...
6769
68- def force_monitoring_refresh (self , should_verify_writer : bool , timeout_sec : int ) -> Tuple [ HostInfo , ...] :
70+ def force_monitoring_refresh (self , should_verify_writer : bool , timeout_sec : int ) -> Topology :
6971 ...
7072
7173 def get_host_role (self , connection : Connection ) -> HostRole :
@@ -149,7 +151,6 @@ def is_static_host_list_provider(self) -> bool:
149151
150152
151153class RdsHostListProvider (DynamicHostListProvider , HostListProvider ):
152- _topology_cache : CacheMap [str , Tuple [HostInfo , ...]] = CacheMap ()
153154 # Maps cluster IDs to a boolean representing whether they are a primary cluster ID or not. A primary cluster ID is a
154155 # cluster ID that is equivalent to a cluster URL. Topology info is shared between RdsHostListProviders that have
155156 # the same cluster ID.
@@ -164,9 +165,9 @@ def __init__(self, host_list_provider_service: HostListProviderService, props: P
164165 self ._topology_utils = topology_utils
165166
166167 self ._rds_utils : RdsUtils = RdsUtils ()
167- self ._hosts : Tuple [ HostInfo , ...] = ()
168+ self ._hosts : Topology = ()
168169 self ._cluster_id : str = str (uuid .uuid4 ())
169- self ._initial_hosts : Tuple [ HostInfo , ...] = ()
170+ self ._initial_hosts : Topology = ()
170171 self ._rds_url_type : Optional [RdsUrlType ] = None
171172
172173 self ._is_primary_cluster_id : bool = False
@@ -182,7 +183,7 @@ def _initialize(self):
182183 if self ._is_initialized :
183184 return
184185
185- self ._initial_hosts : Tuple [ HostInfo , ...] = (self ._topology_utils .initial_host_info ,)
186+ self ._initial_hosts : Topology = (self ._topology_utils .initial_host_info ,)
186187 self ._host_list_provider_service .initial_connection_host_info = self ._topology_utils .initial_host_info
187188
188189 self ._rds_url_type : RdsUrlType = self ._rds_utils .identify_rds_type (self ._topology_utils .initial_host_info .host )
@@ -210,7 +211,10 @@ def _initialize(self):
210211 self ._is_initialized = True
211212
212213 def _get_suggested_cluster_id (self , url : str ) -> Optional [ClusterIdSuggestion ]:
213- for key , hosts in RdsHostListProvider ._topology_cache .get_dict ().items ():
214+ topology_cache = StorageService .get_all (Topology )
215+ if topology_cache is None :
216+ return None
217+ for key , hosts in topology_cache .get_dict ().items ():
214218 is_primary_cluster_id = \
215219 RdsHostListProvider ._is_primary_cluster_id_cache .get_with_default (
216220 key , False , self ._suggested_cluster_id_refresh_ns )
@@ -244,7 +248,7 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False)
244248 self ._cluster_id = suggested_primary_cluster_id
245249 self ._is_primary_cluster_id = True
246250
247- cached_hosts = RdsHostListProvider . _topology_cache . get (self ._cluster_id )
251+ cached_hosts = StorageService . get (Topology , self ._cluster_id )
248252 if not cached_hosts or force_update :
249253 if not conn :
250254 # Cannot fetch topology without a connection
@@ -255,7 +259,7 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False)
255259 driver_dialect = self ._host_list_provider_service .driver_dialect
256260 hosts = self .query_for_topology (conn , driver_dialect )
257261 if hosts is not None and len (hosts ) > 0 :
258- RdsHostListProvider . _topology_cache . put (self ._cluster_id , hosts , self . _refresh_rate_ns )
262+ StorageService . set (self ._cluster_id , hosts , Topology )
259263 if self ._is_primary_cluster_id and cached_hosts is None :
260264 # This cluster_id is primary and a new entry was just created in the cache. When this happens,
261265 # we check for non-primary cluster IDs associated with the same cluster so that the topology
@@ -270,14 +274,18 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False)
270274 else :
271275 return RdsHostListProvider .FetchTopologyResult (self ._initial_hosts , False )
272276
273- def query_for_topology (self , conn , driver_dialect ) -> Optional [Tuple [ HostInfo , ...] ]:
277+ def query_for_topology (self , conn , driver_dialect ) -> Optional [Topology ]:
274278 return self ._topology_utils .query_for_topology (conn , driver_dialect )
275279
276- def _suggest_cluster_id (self , primary_cluster_id_hosts : Tuple [ HostInfo , ...] ):
280+ def _suggest_cluster_id (self , primary_cluster_id_hosts : Topology ):
277281 if not primary_cluster_id_hosts :
278- return
282+ return None
279283
280- for cluster_id , hosts in RdsHostListProvider ._topology_cache .get_dict ().items ():
284+ topology_cache = StorageService .get_all (Topology )
285+ if topology_cache is None :
286+ return None
287+
288+ for cluster_id , hosts in topology_cache .get_dict ().items ():
281289 is_primary_cluster = RdsHostListProvider ._is_primary_cluster_id_cache .get_with_default (
282290 cluster_id , False , self ._suggested_cluster_id_refresh_ns )
283291 suggested_primary_cluster_id = RdsHostListProvider ._cluster_ids_to_update .get (cluster_id )
@@ -293,8 +301,9 @@ def _suggest_cluster_id(self, primary_cluster_id_hosts: Tuple[HostInfo, ...]):
293301 RdsHostListProvider ._cluster_ids_to_update .put (
294302 cluster_id , self ._cluster_id , self ._suggested_cluster_id_refresh_ns )
295303 break
304+ return None
296305
297- def refresh (self , connection : Optional [Connection ] = None ) -> Tuple [ HostInfo , ...] :
306+ def refresh (self , connection : Optional [Connection ] = None ) -> Topology :
298307 """
299308 Get topology information for the database cluster.
300309 This method executes a database query if there is no information for the cluster in the cache, or if the cached topology is outdated.
@@ -311,7 +320,7 @@ def refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ..
311320 self ._hosts = topology .hosts
312321 return tuple (self ._hosts )
313322
314- def force_refresh (self , connection : Optional [Connection ] = None ) -> Tuple [ HostInfo , ...] :
323+ def force_refresh (self , connection : Optional [Connection ] = None ) -> Topology :
315324 """
316325 Execute a database query to retrieve information for the current cluster topology. Any cached topology information will be ignored.
317326
@@ -327,7 +336,7 @@ def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostIn
327336 self ._hosts = topology .hosts
328337 return tuple (self ._hosts )
329338
330- def force_monitoring_refresh (self , should_verify_writer : bool , timeout_sec : int ) -> Tuple [ HostInfo , ...] :
339+ def force_monitoring_refresh (self , should_verify_writer : bool , timeout_sec : int ) -> Topology :
331340 raise AwsWrapperError (
332341 Messages .get_formatted ("HostListProvider.ForceMonitoringRefreshUnsupported" , "RdsHostListProvider" ))
333342
@@ -385,7 +394,7 @@ class ClusterIdSuggestion:
385394
386395 @dataclass ()
387396 class FetchTopologyResult :
388- hosts : Tuple [ HostInfo , ...]
397+ hosts : Topology
389398 is_cached_data : bool
390399
391400
@@ -394,7 +403,7 @@ class ConnectionStringHostListProvider(StaticHostListProvider):
394403 def __init__ (self , host_list_provider_service : HostListProviderService , props : Properties ):
395404 self ._host_list_provider_service : HostListProviderService = host_list_provider_service
396405 self ._props : Properties = props
397- self ._hosts : Tuple [ HostInfo , ...] = ()
406+ self ._hosts : Topology = ()
398407 self ._is_initialized : bool = False
399408 self ._initial_host_info : Optional [HostInfo ] = None
400409
@@ -412,15 +421,15 @@ def _initialize(self):
412421 self ._host_list_provider_service .initial_connection_host_info = self ._initial_host_info
413422 self ._is_initialized = True
414423
415- def refresh (self , connection : Optional [Connection ] = None ) -> Tuple [ HostInfo , ...] :
424+ def refresh (self , connection : Optional [Connection ] = None ) -> Topology :
416425 self ._initialize ()
417426 return tuple (self ._hosts )
418427
419- def force_refresh (self , connection : Optional [Connection ] = None ) -> Tuple [ HostInfo , ...] :
428+ def force_refresh (self , connection : Optional [Connection ] = None ) -> Topology :
420429 self ._initialize ()
421430 return tuple (self ._hosts )
422431
423- def force_monitoring_refresh (self , should_verify_writer : bool , timeout_sec : int ) -> Tuple [ HostInfo , ...] :
432+ def force_monitoring_refresh (self , should_verify_writer : bool , timeout_sec : int ) -> Topology :
424433 raise AwsWrapperError (
425434 Messages .get_formatted ("HostListProvider.ForceMonitoringRefreshUnsupported" , "ConnectionStringHostListProvider" ))
426435
@@ -499,7 +508,7 @@ def query_for_topology(
499508 self ,
500509 conn : Connection ,
501510 driver_dialect : DriverDialect ,
502- ) -> Optional [Tuple [ HostInfo , ...] ]:
511+ ) -> Optional [Topology ]:
503512 """
504513 Query the database for topology information.
505514
@@ -513,7 +522,7 @@ def query_for_topology(
513522 return x
514523
515524 @abstractmethod
516- def _query_for_topology (self , conn : Connection ) -> Optional [Tuple [ HostInfo , ...] ]:
525+ def _query_for_topology (self , conn : Connection ) -> Optional [Topology ]:
517526 pass
518527
519528 def _create_host (self , record : Tuple ) -> HostInfo :
@@ -628,7 +637,7 @@ class AuroraTopologyUtils(TopologyUtils):
628637
629638 _executor_name : ClassVar [str ] = "AuroraTopologyUtils"
630639
631- def _query_for_topology (self , conn : Connection ) -> Optional [Tuple [ HostInfo , ...] ]:
640+ def _query_for_topology (self , conn : Connection ) -> Optional [Topology ]:
632641 """
633642 Query the database for topology information.
634643
@@ -643,7 +652,7 @@ def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]
643652 except ProgrammingError as e :
644653 raise AwsWrapperError (Messages .get ("RdsHostListProvider.InvalidQuery" ), e ) from e
645654
646- def _process_query_results (self , cursor : Cursor ) -> Tuple [ HostInfo , ...] :
655+ def _process_query_results (self , cursor : Cursor ) -> Topology :
647656 """
648657 Form a list of hosts from the results of the topology query.
649658 :param cursor: The Cursor object containing a reference to the results of the topology query.
@@ -692,7 +701,7 @@ def __init__(
692701 self ._writer_host_query = writer_host_query
693702 self ._writer_host_column_index = writer_host_column_index
694703
695- def _query_for_topology (self , conn : Connection ) -> Optional [Tuple [ HostInfo , ...] ]:
704+ def _query_for_topology (self , conn : Connection ) -> Optional [Topology ]:
696705 try :
697706 with closing (conn .cursor ()) as cursor :
698707 cursor .execute (self ._writer_host_query )
@@ -709,7 +718,7 @@ def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]
709718 except ProgrammingError as e :
710719 raise AwsWrapperError (Messages .get ("RdsHostListProvider.InvalidQuery" ), e ) from e
711720
712- def _process_multi_az_query_results (self , cursor : Cursor , writer_id : str ) -> Tuple [ HostInfo , ...] :
721+ def _process_multi_az_query_results (self , cursor : Cursor , writer_id : str ) -> Topology :
713722 hosts_dict = {}
714723 for record in cursor :
715724 host : HostInfo = self ._create_multi_az_host (record , writer_id )
@@ -789,7 +798,7 @@ def _get_monitor(self) -> Optional[ClusterTopologyMonitor]:
789798 self ._high_refresh_rate_ns
790799 ), MonitoringRdsHostListProvider ._MONITOR_CLEANUP_NANO )
791800
792- def query_for_topology (self , connection : Connection , driver_dialect ) -> Optional [Tuple [ HostInfo , ...] ]:
801+ def query_for_topology (self , connection : Connection , driver_dialect ) -> Optional [Topology ]:
793802 monitor = self ._get_monitor ()
794803
795804 if monitor is None :
@@ -800,7 +809,7 @@ def query_for_topology(self, connection: Connection, driver_dialect) -> Optional
800809 except TimeoutError :
801810 return None
802811
803- def force_monitoring_refresh (self , should_verify_writer : bool , timeout_sec : int ) -> Tuple [ HostInfo , ...] :
812+ def force_monitoring_refresh (self , should_verify_writer : bool , timeout_sec : int ) -> Topology :
804813 monitor = self ._get_monitor ()
805814
806815 if monitor is None :
0 commit comments