Skip to content

Commit 7f80922

Browse files
authored
chore: simple storage service (#1097)
1 parent 6d5216d commit 7f80922

10 files changed

Lines changed: 172 additions & 72 deletions

File tree

aws_advanced_python_wrapper/cluster_topology_monitor.py

Lines changed: 17 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -18,14 +18,15 @@
1818
import time
1919
from abc import ABC, abstractmethod
2020
from concurrent.futures import ThreadPoolExecutor
21-
from typing import TYPE_CHECKING, Dict, Optional, Tuple
21+
from typing import TYPE_CHECKING, Dict, Optional
2222

2323
from aws_advanced_python_wrapper.host_availability import HostAvailability
2424
from aws_advanced_python_wrapper.hostinfo import HostInfo
2525
from aws_advanced_python_wrapper.utils.atomic import AtomicReference
26-
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
2726
from aws_advanced_python_wrapper.utils.messages import Messages
2827
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
28+
from aws_advanced_python_wrapper.utils.storage.storage_service import (
29+
StorageService, Topology)
2930
from aws_advanced_python_wrapper.utils.thread_safe_connection_holder import \
3031
ThreadSafeConnectionHolder
3132
from aws_advanced_python_wrapper.utils.utils import LogUtils
@@ -46,11 +47,11 @@
4647

4748
class ClusterTopologyMonitor(ABC):
4849
@abstractmethod
49-
def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]:
50+
def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology:
5051
pass
5152

5253
@abstractmethod
53-
def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Tuple[HostInfo, ...]:
54+
def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Topology:
5455
pass
5556

5657
@abstractmethod
@@ -75,8 +76,6 @@ class ClusterTopologyMonitorImpl(ClusterTopologyMonitor):
7576
INITIAL_BACKOFF_MS = 100
7677
MAX_BACKOFF_MS = 10000
7778

78-
_topology_map: CacheMap[str, Tuple[HostInfo, ...]] = CacheMap()
79-
8079
def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils, cluster_id: str,
8180
initial_host_info: HostInfo, properties: Properties, instance_template: HostInfo,
8281
refresh_rate_nano: int, high_refresh_rate_nano: int):
@@ -103,7 +102,7 @@ def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils,
103102
self._host_threads_writer_connection: AtomicReference[Optional[Connection]] = AtomicReference(None)
104103
self._host_threads_writer_host_info: AtomicReference[Optional[HostInfo]] = AtomicReference(None)
105104
self._host_threads_reader_connection: AtomicReference[Optional[Connection]] = AtomicReference(None)
106-
self._host_threads_latest_topology: AtomicReference[Optional[Tuple[HostInfo, ...]]] = AtomicReference(None)
105+
self._host_threads_latest_topology: AtomicReference[Optional[Topology]] = AtomicReference(None)
107106

108107
self._is_verified_writer_connection = False
109108
self._high_refresh_rate_end_time_nano = 0
@@ -118,7 +117,7 @@ def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils,
118117

119118
self._start_monitoring()
120119

121-
def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]:
120+
def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology:
122121
current_time_nano = time.time_ns()
123122
if (self._ignore_new_topology_requests_end_time_nano > 0 and
124123
current_time_nano < self._ignore_new_topology_requests_end_time_nano):
@@ -134,12 +133,12 @@ def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[H
134133
result = self._wait_till_topology_gets_updated(timeout_sec)
135134
return result
136135

137-
def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Tuple[HostInfo, ...]:
136+
def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Topology:
138137
if self._is_verified_writer_connection:
139138
return self._wait_till_topology_gets_updated(timeout_sec)
140139
return self._fetch_topology_and_update_cache(connection)
141140

142-
def _wait_till_topology_gets_updated(self, timeout_sec: int) -> Tuple[HostInfo, ...]:
141+
def _wait_till_topology_gets_updated(self, timeout_sec: int) -> Topology:
143142
current_hosts = self._get_stored_hosts()
144143

145144
self._request_to_update_topology.set()
@@ -162,8 +161,8 @@ def _wait_till_topology_gets_updated(self, timeout_sec: int) -> Tuple[HostInfo,
162161
"ClusterTopologyMonitorImpl.TopologyNotUpdated",
163162
self._cluster_id, timeout_sec * 1000))
164163

165-
def _get_stored_hosts(self) -> Tuple[HostInfo, ...]:
166-
hosts = ClusterTopologyMonitorImpl._topology_map.get(self._cluster_id)
164+
def _get_stored_hosts(self) -> Topology:
165+
hosts = StorageService.get(Topology, self._cluster_id)
167166
if hosts is None:
168167
return ()
169168
return hosts
@@ -296,7 +295,7 @@ def _is_in_panic_mode(self) -> bool:
296295
def _get_host_monitor(self, host_info: HostInfo, writer_host_info: Optional[HostInfo]):
297296
return HostMonitor(self, host_info, writer_host_info)
298297

299-
def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]:
298+
def _open_any_connection_and_update_topology(self) -> Topology:
300299
writer_verified_by_this_thread = False
301300
if self._monitoring_connection.get() is None:
302301
# Try to connect to the initial host first
@@ -409,7 +408,7 @@ def _delay(self, use_high_refresh_rate: bool) -> None:
409408
while not self._request_to_update_topology.is_set() and time.time() < end_time and not self._stop.is_set():
410409
time.sleep(0.05)
411410

412-
def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) -> Tuple[HostInfo, ...]:
411+
def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) -> Topology:
413412
if connection is None:
414413
return ()
415414

@@ -423,7 +422,7 @@ def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) ->
423422
logger.debug("ClusterTopologyMonitorImpl.ErrorFetchingTopology", self._cluster_id, ex)
424423
return ()
425424

426-
def _fetch_topology_and_update_cache_safe(self) -> Tuple[HostInfo, ...]:
425+
def _fetch_topology_and_update_cache_safe(self) -> Topology:
427426
"""
428427
Safely fetch topology using ThreadSafeConnectionHolder to prevent race conditions.
429428
The lock is held during the entire query operation.
@@ -433,16 +432,14 @@ def _fetch_topology_and_update_cache_safe(self) -> Tuple[HostInfo, ...]:
433432
)
434433
return result if result is not None else ()
435434

436-
def _query_for_topology(self, connection: Connection) -> Tuple[HostInfo, ...]:
435+
def _query_for_topology(self, connection: Connection) -> Topology:
437436
hosts = self._topology_utils.query_for_topology(connection, self._plugin_service.driver_dialect)
438437
if hosts is not None:
439438
return hosts
440439
return ()
441440

442-
def _update_topology_cache(self, hosts: Tuple[HostInfo, ...]) -> None:
443-
ClusterTopologyMonitorImpl._topology_map.put(
444-
self._cluster_id, hosts, ClusterTopologyMonitorImpl.TOPOLOGY_CACHE_EXPIRATION_NANO)
445-
441+
def _update_topology_cache(self, hosts: Topology) -> None:
442+
StorageService.set(self._cluster_id, hosts, Topology)
446443
# Notify waiting threads
447444
self._request_to_update_topology.clear()
448445
self._topology_updated.set()

aws_advanced_python_wrapper/host_list_provider.py

Lines changed: 39 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,8 @@
3030
preserve_transaction_status_with_timeout
3131
from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
3232
SlidingExpirationCacheContainer
33+
from aws_advanced_python_wrapper.utils.storage.storage_service import (
34+
StorageService, Topology)
3335

3436
if TYPE_CHECKING:
3537
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
@@ -59,13 +61,13 @@
5961

6062

6163
class HostListProvider(Protocol):
62-
def refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]:
64+
def refresh(self, connection: Optional[Connection] = None) -> Topology:
6365
...
6466

65-
def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]:
67+
def force_refresh(self, connection: Optional[Connection] = None) -> Topology:
6668
...
6769

68-
def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]:
70+
def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology:
6971
...
7072

7173
def get_host_role(self, connection: Connection) -> HostRole:
@@ -149,7 +151,6 @@ def is_static_host_list_provider(self) -> bool:
149151

150152

151153
class RdsHostListProvider(DynamicHostListProvider, HostListProvider):
152-
_topology_cache: CacheMap[str, Tuple[HostInfo, ...]] = CacheMap()
153154
# Maps cluster IDs to a boolean representing whether they are a primary cluster ID or not. A primary cluster ID is a
154155
# cluster ID that is equivalent to a cluster URL. Topology info is shared between RdsHostListProviders that have
155156
# the same cluster ID.
@@ -164,9 +165,9 @@ def __init__(self, host_list_provider_service: HostListProviderService, props: P
164165
self._topology_utils = topology_utils
165166

166167
self._rds_utils: RdsUtils = RdsUtils()
167-
self._hosts: Tuple[HostInfo, ...] = ()
168+
self._hosts: Topology = ()
168169
self._cluster_id: str = str(uuid.uuid4())
169-
self._initial_hosts: Tuple[HostInfo, ...] = ()
170+
self._initial_hosts: Topology = ()
170171
self._rds_url_type: Optional[RdsUrlType] = None
171172

172173
self._is_primary_cluster_id: bool = False
@@ -182,7 +183,7 @@ def _initialize(self):
182183
if self._is_initialized:
183184
return
184185

185-
self._initial_hosts: Tuple[HostInfo, ...] = (self._topology_utils.initial_host_info,)
186+
self._initial_hosts: Topology = (self._topology_utils.initial_host_info,)
186187
self._host_list_provider_service.initial_connection_host_info = self._topology_utils.initial_host_info
187188

188189
self._rds_url_type: RdsUrlType = self._rds_utils.identify_rds_type(self._topology_utils.initial_host_info.host)
@@ -210,7 +211,10 @@ def _initialize(self):
210211
self._is_initialized = True
211212

212213
def _get_suggested_cluster_id(self, url: str) -> Optional[ClusterIdSuggestion]:
213-
for key, hosts in RdsHostListProvider._topology_cache.get_dict().items():
214+
topology_cache = StorageService.get_all(Topology)
215+
if topology_cache is None:
216+
return None
217+
for key, hosts in topology_cache.get_dict().items():
214218
is_primary_cluster_id = \
215219
RdsHostListProvider._is_primary_cluster_id_cache.get_with_default(
216220
key, False, self._suggested_cluster_id_refresh_ns)
@@ -244,7 +248,7 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False)
244248
self._cluster_id = suggested_primary_cluster_id
245249
self._is_primary_cluster_id = True
246250

247-
cached_hosts = RdsHostListProvider._topology_cache.get(self._cluster_id)
251+
cached_hosts = StorageService.get(Topology, self._cluster_id)
248252
if not cached_hosts or force_update:
249253
if not conn:
250254
# Cannot fetch topology without a connection
@@ -255,7 +259,7 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False)
255259
driver_dialect = self._host_list_provider_service.driver_dialect
256260
hosts = self.query_for_topology(conn, driver_dialect)
257261
if hosts is not None and len(hosts) > 0:
258-
RdsHostListProvider._topology_cache.put(self._cluster_id, hosts, self._refresh_rate_ns)
262+
StorageService.set(self._cluster_id, hosts, Topology)
259263
if self._is_primary_cluster_id and cached_hosts is None:
260264
# This cluster_id is primary and a new entry was just created in the cache. When this happens,
261265
# we check for non-primary cluster IDs associated with the same cluster so that the topology
@@ -270,14 +274,18 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False)
270274
else:
271275
return RdsHostListProvider.FetchTopologyResult(self._initial_hosts, False)
272276

273-
def query_for_topology(self, conn, driver_dialect) -> Optional[Tuple[HostInfo, ...]]:
277+
def query_for_topology(self, conn, driver_dialect) -> Optional[Topology]:
274278
return self._topology_utils.query_for_topology(conn, driver_dialect)
275279

276-
def _suggest_cluster_id(self, primary_cluster_id_hosts: Tuple[HostInfo, ...]):
280+
def _suggest_cluster_id(self, primary_cluster_id_hosts: Topology):
277281
if not primary_cluster_id_hosts:
278-
return
282+
return None
279283

280-
for cluster_id, hosts in RdsHostListProvider._topology_cache.get_dict().items():
284+
topology_cache = StorageService.get_all(Topology)
285+
if topology_cache is None:
286+
return None
287+
288+
for cluster_id, hosts in topology_cache.get_dict().items():
281289
is_primary_cluster = RdsHostListProvider._is_primary_cluster_id_cache.get_with_default(
282290
cluster_id, False, self._suggested_cluster_id_refresh_ns)
283291
suggested_primary_cluster_id = RdsHostListProvider._cluster_ids_to_update.get(cluster_id)
@@ -293,8 +301,9 @@ def _suggest_cluster_id(self, primary_cluster_id_hosts: Tuple[HostInfo, ...]):
293301
RdsHostListProvider._cluster_ids_to_update.put(
294302
cluster_id, self._cluster_id, self._suggested_cluster_id_refresh_ns)
295303
break
304+
return None
296305

297-
def refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]:
306+
def refresh(self, connection: Optional[Connection] = None) -> Topology:
298307
"""
299308
Get topology information for the database cluster.
300309
This method executes a database query if there is no information for the cluster in the cache, or if the cached topology is outdated.
@@ -311,7 +320,7 @@ def refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ..
311320
self._hosts = topology.hosts
312321
return tuple(self._hosts)
313322

314-
def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]:
323+
def force_refresh(self, connection: Optional[Connection] = None) -> Topology:
315324
"""
316325
Execute a database query to retrieve information for the current cluster topology. Any cached topology information will be ignored.
317326
@@ -327,7 +336,7 @@ def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostIn
327336
self._hosts = topology.hosts
328337
return tuple(self._hosts)
329338

330-
def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]:
339+
def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology:
331340
raise AwsWrapperError(
332341
Messages.get_formatted("HostListProvider.ForceMonitoringRefreshUnsupported", "RdsHostListProvider"))
333342

@@ -385,7 +394,7 @@ class ClusterIdSuggestion:
385394

386395
@dataclass()
387396
class FetchTopologyResult:
388-
hosts: Tuple[HostInfo, ...]
397+
hosts: Topology
389398
is_cached_data: bool
390399

391400

@@ -394,7 +403,7 @@ class ConnectionStringHostListProvider(StaticHostListProvider):
394403
def __init__(self, host_list_provider_service: HostListProviderService, props: Properties):
395404
self._host_list_provider_service: HostListProviderService = host_list_provider_service
396405
self._props: Properties = props
397-
self._hosts: Tuple[HostInfo, ...] = ()
406+
self._hosts: Topology = ()
398407
self._is_initialized: bool = False
399408
self._initial_host_info: Optional[HostInfo] = None
400409

@@ -412,15 +421,15 @@ def _initialize(self):
412421
self._host_list_provider_service.initial_connection_host_info = self._initial_host_info
413422
self._is_initialized = True
414423

415-
def refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]:
424+
def refresh(self, connection: Optional[Connection] = None) -> Topology:
416425
self._initialize()
417426
return tuple(self._hosts)
418427

419-
def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]:
428+
def force_refresh(self, connection: Optional[Connection] = None) -> Topology:
420429
self._initialize()
421430
return tuple(self._hosts)
422431

423-
def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]:
432+
def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology:
424433
raise AwsWrapperError(
425434
Messages.get_formatted("HostListProvider.ForceMonitoringRefreshUnsupported", "ConnectionStringHostListProvider"))
426435

@@ -499,7 +508,7 @@ def query_for_topology(
499508
self,
500509
conn: Connection,
501510
driver_dialect: DriverDialect,
502-
) -> Optional[Tuple[HostInfo, ...]]:
511+
) -> Optional[Topology]:
503512
"""
504513
Query the database for topology information.
505514
@@ -513,7 +522,7 @@ def query_for_topology(
513522
return x
514523

515524
@abstractmethod
516-
def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]]:
525+
def _query_for_topology(self, conn: Connection) -> Optional[Topology]:
517526
pass
518527

519528
def _create_host(self, record: Tuple) -> HostInfo:
@@ -628,7 +637,7 @@ class AuroraTopologyUtils(TopologyUtils):
628637

629638
_executor_name: ClassVar[str] = "AuroraTopologyUtils"
630639

631-
def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]]:
640+
def _query_for_topology(self, conn: Connection) -> Optional[Topology]:
632641
"""
633642
Query the database for topology information.
634643
@@ -643,7 +652,7 @@ def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]
643652
except ProgrammingError as e:
644653
raise AwsWrapperError(Messages.get("RdsHostListProvider.InvalidQuery"), e) from e
645654

646-
def _process_query_results(self, cursor: Cursor) -> Tuple[HostInfo, ...]:
655+
def _process_query_results(self, cursor: Cursor) -> Topology:
647656
"""
648657
Form a list of hosts from the results of the topology query.
649658
:param cursor: The Cursor object containing a reference to the results of the topology query.
@@ -692,7 +701,7 @@ def __init__(
692701
self._writer_host_query = writer_host_query
693702
self._writer_host_column_index = writer_host_column_index
694703

695-
def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]]:
704+
def _query_for_topology(self, conn: Connection) -> Optional[Topology]:
696705
try:
697706
with closing(conn.cursor()) as cursor:
698707
cursor.execute(self._writer_host_query)
@@ -709,7 +718,7 @@ def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]
709718
except ProgrammingError as e:
710719
raise AwsWrapperError(Messages.get("RdsHostListProvider.InvalidQuery"), e) from e
711720

712-
def _process_multi_az_query_results(self, cursor: Cursor, writer_id: str) -> Tuple[HostInfo, ...]:
721+
def _process_multi_az_query_results(self, cursor: Cursor, writer_id: str) -> Topology:
713722
hosts_dict = {}
714723
for record in cursor:
715724
host: HostInfo = self._create_multi_az_host(record, writer_id)
@@ -789,7 +798,7 @@ def _get_monitor(self) -> Optional[ClusterTopologyMonitor]:
789798
self._high_refresh_rate_ns
790799
), MonitoringRdsHostListProvider._MONITOR_CLEANUP_NANO)
791800

792-
def query_for_topology(self, connection: Connection, driver_dialect) -> Optional[Tuple[HostInfo, ...]]:
801+
def query_for_topology(self, connection: Connection, driver_dialect) -> Optional[Topology]:
793802
monitor = self._get_monitor()
794803

795804
if monitor is None:
@@ -800,7 +809,7 @@ def query_for_topology(self, connection: Connection, driver_dialect) -> Optional
800809
except TimeoutError:
801810
return None
802811

803-
def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]:
812+
def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology:
804813
monitor = self._get_monitor()
805814

806815
if monitor is None:

0 commit comments

Comments
 (0)