Skip to content

Commit 8447fc9

Browse files
fix: class var threads (#1090)
1 parent 8eca0a0 commit 8447fc9

24 files changed

Lines changed: 571 additions & 176 deletions

aws_advanced_python_wrapper/cleanup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@
1818
MonitoringThreadContainer
1919
from aws_advanced_python_wrapper.thread_pool_container import \
2020
ThreadPoolContainer
21+
from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
22+
SlidingExpirationCacheContainer
2123

2224

2325
def release_resources() -> None:
2426
"""Release all global resources used by the wrapper."""
2527
MonitoringThreadContainer.clean_up()
2628
ThreadPoolContainer.release_resources()
2729
OpenedConnectionTracker.release_resources()
30+
SlidingExpirationCacheContainer.release_resources()

aws_advanced_python_wrapper/connection_provider.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import TYPE_CHECKING, Callable, Dict, Optional, Protocol, Tuple
17+
from typing import (TYPE_CHECKING, Callable, ClassVar, Dict, Optional,
18+
Protocol, Tuple)
1819

1920
if TYPE_CHECKING:
2021
from aws_advanced_python_wrapper.database_dialect import DatabaseDialect
@@ -131,8 +132,8 @@ def connect(
131132

132133

133134
class ConnectionProviderManager:
134-
_lock: Lock = Lock()
135-
_conn_provider: Optional[ConnectionProvider] = None
135+
_lock: ClassVar[Lock] = Lock()
136+
_conn_provider: ClassVar[Optional[ConnectionProvider]] = None
136137

137138
def __init__(self, default_provider: ConnectionProvider = DriverConnectionProvider()):
138139
self._default_provider: ConnectionProvider = default_provider

aws_advanced_python_wrapper/custom_endpoint_plugin.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@
4242
from aws_advanced_python_wrapper.utils.log import Logger
4343
from aws_advanced_python_wrapper.utils.properties import WrapperProperties
4444
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
45-
from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
46-
SlidingExpirationCacheWithCleanupThread
45+
from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
46+
SlidingExpirationCacheContainer
4747
from aws_advanced_python_wrapper.utils.telemetry.telemetry import (
4848
TelemetryCounter, TelemetryFactory)
4949

@@ -232,11 +232,8 @@ class CustomEndpointPlugin(Plugin):
232232
or removing an instance in the custom endpoint.
233233
"""
234234
_SUBSCRIBED_METHODS: ClassVar[Set[str]] = {DbApiMethod.CONNECT.method_name}
235-
_CACHE_CLEANUP_RATE_NS: ClassVar[int] = 6 * 10 ^ 10 # 1 minute
236-
_monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, CustomEndpointMonitor]] = \
237-
SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_RATE_NS,
238-
should_dispose_func=lambda _: True,
239-
item_disposal_func=lambda monitor: monitor.close())
235+
_CACHE_CLEANUP_RATE_NS: ClassVar[int] = 60_000_000_000 # 1 minute
236+
_MONITOR_CACHE_NAME: ClassVar[str] = "custom_endpoint_monitors"
240237

241238
def __init__(self, plugin_service: PluginService, props: Properties):
242239
self._plugin_service = plugin_service
@@ -255,6 +252,13 @@ def __init__(self, plugin_service: PluginService, props: Properties):
255252
telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory()
256253
self._wait_for_info_counter: TelemetryCounter | None = telemetry_factory.create_counter("customEndpoint.waitForInfo.counter")
257254

255+
self._monitors = SlidingExpirationCacheContainer.get_or_create_cache(
256+
name=CustomEndpointPlugin._MONITOR_CACHE_NAME,
257+
cleanup_interval_ns=CustomEndpointPlugin._CACHE_CLEANUP_RATE_NS,
258+
should_dispose_func=lambda _: True,
259+
item_disposal_func=lambda monitor: monitor.close()
260+
)
261+
258262
CustomEndpointPlugin._SUBSCRIBED_METHODS.update(self._plugin_service.network_bound_methods)
259263

260264
@property
@@ -298,7 +302,7 @@ def _create_monitor_if_absent(self, props: Properties) -> CustomEndpointMonitor:
298302
host_info = cast('HostInfo', self._custom_endpoint_host_info)
299303
endpoint_id = cast('str', self._custom_endpoint_id)
300304
region = cast('str', self._region)
301-
monitor = CustomEndpointPlugin._monitors.compute_if_absent(
305+
monitor = self._monitors.compute_if_absent(
302306
host_info.host,
303307
lambda key: CustomEndpointMonitor(
304308
self._plugin_service,

aws_advanced_python_wrapper/database_dialect.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,7 @@ def __init__(self, props: Properties, rds_helper: Optional[RdsUtils] = None):
695695
self._can_update: bool = False
696696
self._dialect: DatabaseDialect = UnknownDatabaseDialect()
697697
self._dialect_code: DialectCode = DialectCode.UNKNOWN
698+
self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name)
698699

699700
@staticmethod
700701
def get_custom_dialect():
@@ -814,7 +815,7 @@ def query_for_dialect(self, url: str, host_info: Optional[HostInfo], conn: Conne
814815
timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get(self._props)
815816
try:
816817
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
817-
ThreadPoolContainer.get_thread_pool(DatabaseDialectManager._executor_name),
818+
self._thread_pool,
818819
timeout_sec,
819820
driver_dialect,
820821
conn)(dialect_candidate.is_dialect)

aws_advanced_python_wrapper/driver_dialect.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class DriverDialect(ABC):
5151

5252
def __init__(self, props: Properties):
5353
self._props = props
54+
self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name)
5455

5556
@property
5657
def driver_name(self):
@@ -138,7 +139,7 @@ def execute(
138139

139140
if exec_timeout > 0:
140141
try:
141-
execute_with_timeout = timeout(ThreadPoolContainer.get_thread_pool(DriverDialect._executor_name), exec_timeout)(exec_func)
142+
execute_with_timeout = timeout(self._thread_pool, exec_timeout)(exec_func)
142143
return execute_with_timeout()
143144
except TimeoutError as e:
144145
raise QueryTimeoutError(Messages.get_formatted("DriverDialect.ExecuteTimeout", method_name)) from e

aws_advanced_python_wrapper/fastest_response_strategy_plugin.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
from aws_advanced_python_wrapper.utils.messages import Messages
3131
from aws_advanced_python_wrapper.utils.properties import (Properties,
3232
WrapperProperties)
33-
from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
34-
SlidingExpirationCacheWithCleanupThread
33+
from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
34+
SlidingExpirationCacheContainer
3535
from aws_advanced_python_wrapper.utils.telemetry.telemetry import (
3636
TelemetryContext, TelemetryFactory, TelemetryGauge, TelemetryTraceLevel)
3737

@@ -59,7 +59,7 @@ def __init__(self, plugin_service: PluginService, props: Properties):
5959
self._properties = props
6060
self._host_response_time_service: HostResponseTimeService = \
6161
HostResponseTimeService(plugin_service, props, WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props))
62-
self._cache_expiration_nanos = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props) * 10 ^ 6
62+
self._cache_expiration_nanos = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props) * 1_000_000
6363
self._random_host_selector = RandomHostSelector()
6464
self._cached_fastest_response_host_by_role: CacheMap[str, HostInfo] = CacheMap()
6565
self._hosts: Tuple[HostInfo, ...] = ()
@@ -278,21 +278,29 @@ def _open_connection(self):
278278

279279

280280
class HostResponseTimeService:
281-
_CACHE_EXPIRATION_NS: int = 6 * 10 ^ 11 # 10 minutes
282-
_CACHE_CLEANUP_NS: int = 6 * 10 ^ 10 # 1 minute
283-
_lock: Lock = Lock()
284-
_monitoring_hosts: ClassVar[SlidingExpirationCacheWithCleanupThread[str, HostResponseTimeMonitor]] = \
285-
SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NS,
286-
should_dispose_func=lambda monitor: True,
287-
item_disposal_func=lambda monitor: HostResponseTimeService._monitor_close(monitor))
281+
_CACHE_EXPIRATION_NS: ClassVar[int] = 10 * 60_000_000_000 # 10 minutes
282+
_CACHE_CLEANUP_NS: ClassVar[int] = 60_000_000_000 # 1 minute
283+
_CACHE_NAME: ClassVar[str] = "host_response_time_monitors"
284+
_lock: ClassVar[Lock] = Lock()
288285

289286
def __init__(self, plugin_service: PluginService, props: Properties, interval_ms: int):
290287
self._plugin_service = plugin_service
291288
self._properties = props
292289
self._interval_ms = interval_ms
293290
self._hosts: Tuple[HostInfo, ...] = ()
294291
self._telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory()
295-
self._host_count_gauge: TelemetryGauge | None = self._telemetry_factory.create_gauge("frt.hosts.count", lambda: len(self._monitoring_hosts))
292+
293+
self._monitoring_hosts = SlidingExpirationCacheContainer.get_or_create_cache(
294+
name=HostResponseTimeService._CACHE_NAME,
295+
cleanup_interval_ns=HostResponseTimeService._CACHE_CLEANUP_NS,
296+
should_dispose_func=lambda monitor: True,
297+
item_disposal_func=lambda monitor: HostResponseTimeService._monitor_close(monitor)
298+
)
299+
300+
self._host_count_gauge: TelemetryGauge | None = self._telemetry_factory.create_gauge(
301+
"frt.hosts.count",
302+
lambda: len(self._monitoring_hosts)
303+
)
296304

297305
@property
298306
def hosts(self) -> Tuple[HostInfo, ...]:
@@ -310,7 +318,7 @@ def _monitor_close(monitor: HostResponseTimeMonitor):
310318
pass
311319

312320
def get_response_time(self, host_info: HostInfo) -> int:
313-
monitor: Optional[HostResponseTimeMonitor] = HostResponseTimeService._monitoring_hosts.get(host_info.url)
321+
monitor: Optional[HostResponseTimeMonitor] = self._monitoring_hosts.get(host_info.url)
314322
if monitor is None:
315323
return MAX_VALUE
316324
return monitor.response_time
@@ -327,4 +335,4 @@ def set_hosts(self, new_hosts: Tuple[HostInfo, ...]) -> None:
327335
self._plugin_service,
328336
host,
329337
self._properties,
330-
self._interval_ms), self._CACHE_EXPIRATION_NS)
338+
self._interval_ms), HostResponseTimeService._CACHE_EXPIRATION_NS)

aws_advanced_python_wrapper/host_list_provider.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
ClusterTopologyMonitor, ClusterTopologyMonitorImpl)
2929
from aws_advanced_python_wrapper.utils.decorators import \
3030
preserve_transaction_status_with_timeout
31-
from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
32-
SlidingExpirationCacheWithCleanupThread
31+
from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
32+
SlidingExpirationCacheContainer
3333

3434
if TYPE_CHECKING:
3535
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
@@ -476,6 +476,7 @@ def __init__(self, dialect: db_dialect.TopologyAwareDatabaseDialect, props: Prop
476476

477477
self.instance_template: HostInfo = instance_template
478478
self._max_timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get_int(props)
479+
self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name)
479480

480481
def _validate_host_pattern(self, host: str):
481482
if not self._rds_utils.is_dns_pattern_valid(host):
@@ -507,7 +508,7 @@ def query_for_topology(
507508
an empty tuple will be returned.
508509
"""
509510
query_for_topology_func_with_timeout = preserve_transaction_status_with_timeout(
510-
ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, conn)(self._query_for_topology)
511+
self._thread_pool, self._max_timeout_sec, driver_dialect, conn)(self._query_for_topology)
511512
x = query_for_topology_func_with_timeout(conn)
512513
return x
513514

@@ -570,7 +571,7 @@ def create_host(
570571
def get_host_role(self, connection: Connection, driver_dialect: DriverDialect) -> HostRole:
571572
try:
572573
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
573-
ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, connection)(self._get_host_role)
574+
self._thread_pool, self._max_timeout_sec, driver_dialect, connection)(self._get_host_role)
574575
result = cursor_execute_func_with_timeout(connection)
575576
if result is not None:
576577
is_reader = result[0]
@@ -593,7 +594,7 @@ def get_host_id(self, connection: Connection, driver_dialect: DriverDialect) ->
593594
"""
594595

595596
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
596-
ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, connection)(self._get_host_id)
597+
self._thread_pool, self._max_timeout_sec, driver_dialect, connection)(self._get_host_id)
597598
result = cursor_execute_func_with_timeout(connection)
598599
if result:
599600
host_id: str = result[0]
@@ -608,7 +609,7 @@ def _get_host_id(self, conn: Connection):
608609
def get_writer_host_if_connected(self, connection: Connection, driver_dialect: DriverDialect) -> Optional[str]:
609610
try:
610611
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
611-
ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, connection)(self._get_writer_id)
612+
self._thread_pool, self._max_timeout_sec, driver_dialect, connection)(self._get_writer_id)
612613
result = cursor_execute_func_with_timeout(connection)
613614
if result:
614615
host_id: str = result[0]
@@ -752,13 +753,9 @@ def _create_multi_az_host(self, record: Tuple, writer_id: str) -> HostInfo:
752753

753754

754755
class MonitoringRdsHostListProvider(RdsHostListProvider):
755-
_CACHE_CLEANUP_NANO = 1 * 60 * 1_000_000_000 # 1 minute
756-
_MONITOR_CLEANUP_NANO = 15 * 60 * 1_000_000_000 # 15 minutes
757-
758-
_monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, ClusterTopologyMonitor]] = \
759-
SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NANO,
760-
should_dispose_func=lambda monitor: monitor.can_dispose(),
761-
item_disposal_func=lambda monitor: monitor.close())
756+
_CACHE_CLEANUP_NANO: ClassVar[int] = 1 * 60 * 1_000_000_000 # 1 minute
757+
_MONITOR_CLEANUP_NANO: ClassVar[int] = 15 * 60 * 1_000_000_000 # 15 minutes
758+
_MONITOR_CACHE_NAME: ClassVar[str] = "cluster_topology_monitors"
762759

763760
def __init__(
764761
self,
@@ -772,6 +769,13 @@ def __init__(
772769
self._high_refresh_rate_ns = (
773770
WrapperProperties.CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.get_int(self._props) * 1_000_000)
774771

772+
self._monitors = SlidingExpirationCacheContainer.get_or_create_cache(
773+
name=MonitoringRdsHostListProvider._MONITOR_CACHE_NAME,
774+
cleanup_interval_ns=MonitoringRdsHostListProvider._CACHE_CLEANUP_NANO,
775+
should_dispose_func=lambda monitor: monitor.can_dispose(),
776+
item_disposal_func=lambda monitor: monitor.close()
777+
)
778+
775779
def _get_monitor(self) -> Optional[ClusterTopologyMonitor]:
776780
return self._monitors.compute_if_absent_with_disposal(self.get_cluster_id(),
777781
lambda k: ClusterTopologyMonitorImpl(
@@ -803,7 +807,3 @@ def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int)
803807
return ()
804808

805809
return monitor.force_refresh(should_verify_writer, timeout_sec)
806-
807-
@staticmethod
808-
def release_resources():
809-
MonitoringRdsHostListProvider._monitors.clear()

aws_advanced_python_wrapper/host_monitoring_plugin.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,9 @@ class MonitoringThreadContainer:
578578
_tasks_map: ConcurrentDict[Monitor, Future] = ConcurrentDict()
579579
_executor_name: ClassVar[str] = "MonitoringThreadContainerExecutor"
580580

581+
def __init__(self):
582+
self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name)
583+
581584
# This logic ensures that this class is a Singleton
582585
def __new__(cls, *args, **kwargs):
583586
if cls._instance is None:
@@ -605,8 +608,7 @@ def _get_or_create_monitor(_) -> Monitor:
605608
raise AwsWrapperError(Messages.get("MonitoringThreadContainer.SupplierMonitorNone"))
606609
self._tasks_map.compute_if_absent(
607610
supplied_monitor,
608-
lambda _: ThreadPoolContainer.get_thread_pool(MonitoringThreadContainer._executor_name)
609-
.submit(supplied_monitor.run))
611+
lambda _: self._thread_pool.submit(supplied_monitor.run))
610612
return supplied_monitor
611613

612614
if monitor is None:

aws_advanced_python_wrapper/host_monitoring_v2_plugin.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636
PropertiesUtils,
3737
WrapperProperties)
3838
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
39-
from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
40-
SlidingExpirationCacheWithCleanupThread
39+
from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
40+
SlidingExpirationCacheContainer
4141
from aws_advanced_python_wrapper.utils.telemetry.telemetry import (
4242
TelemetryCounter, TelemetryFactory, TelemetryTraceLevel)
4343

@@ -450,19 +450,22 @@ def close(self) -> None:
450450

451451
class MonitorServiceV2:
452452
# 1 Minute to Nanoseconds
453-
_CACHE_CLEANUP_NANO = 1 * 60 * 1_000_000_000
454-
455-
_monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, HostMonitorV2]] = \
456-
SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NANO,
457-
should_dispose_func=lambda monitor: monitor.can_dispose(),
458-
item_disposal_func=lambda monitor: monitor.close())
453+
_CACHE_CLEANUP_NANO: ClassVar[int] = 1 * 60 * 1_000_000_000
454+
_MONITOR_CACHE_NAME: ClassVar[str] = "host_monitors_v2"
459455

460456
def __init__(self, plugin_service: PluginService):
461457
self._plugin_service: PluginService = plugin_service
462458

463459
telemetry_factory = self._plugin_service.get_telemetry_factory()
464460
self._aborted_connections_counter = telemetry_factory.create_counter("efm2.connections.aborted")
465461

462+
self._monitors = SlidingExpirationCacheContainer.get_or_create_cache(
463+
name=MonitorServiceV2._MONITOR_CACHE_NAME,
464+
cleanup_interval_ns=MonitorServiceV2._CACHE_CLEANUP_NANO,
465+
should_dispose_func=lambda monitor: monitor.can_dispose(),
466+
item_disposal_func=lambda monitor: monitor.close()
467+
)
468+
466469
def start_monitoring(
467470
self,
468471
conn: Connection,

0 commit comments

Comments
 (0)