Skip to content

Commit d13b509

Browse files
fix: move class var ThreadPoolExecutors to container (#1066)
1 parent f3b7b68 commit d13b509

43 files changed

Lines changed: 850 additions & 472 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

aws_advanced_python_wrapper/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from logging import DEBUG, getLogger
1616

17+
from .cleanup import release_resources
1718
from .utils.utils import LogUtils
1819
from .wrapper import AwsWrapperConnection
1920

@@ -23,6 +24,17 @@
2324
threadsafety = 2
2425
paramstyle = "pyformat"
2526

27+
# Public API
28+
__all__ = [
29+
'connect',
30+
'AwsWrapperConnection',
31+
'release_resources',
32+
'set_logger',
33+
'apilevel',
34+
'threadsafety',
35+
'paramstyle'
36+
]
37+
2638

2739
def set_logger(name='aws_advanced_python_wrapper', level=DEBUG, format_string=None):
2840
LogUtils.setup_logger(getLogger(name), level, format_string)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from aws_advanced_python_wrapper.host_monitoring_plugin import \
16+
MonitoringThreadContainer
17+
from aws_advanced_python_wrapper.thread_pool_container import \
18+
ThreadPoolContainer
19+
20+
21+
def release_resources() -> None:
22+
"""Release all global resources used by the wrapper."""
23+
MonitoringThreadContainer.clean_up()
24+
ThreadPoolContainer.release_resources()

aws_advanced_python_wrapper/custom_endpoint_plugin.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ def _run(self):
169169
len(endpoints),
170170
endpoint_hostnames)
171171

172-
sleep(self._refresh_rate_ns / 1_000_000_000)
172+
if self._stop_event.wait(self._refresh_rate_ns / 1_000_000_000):
173+
break
173174
continue
174175

175176
endpoint_info = CustomEndpointInfo.from_db_cluster_endpoint(endpoints[0])
@@ -178,7 +179,8 @@ def _run(self):
178179
if cached_info is not None and cached_info == endpoint_info:
179180
elapsed_time = perf_counter_ns() - start_ns
180181
sleep_duration = max(0, self._refresh_rate_ns - elapsed_time)
181-
sleep(sleep_duration / 1_000_000_000)
182+
if self._stop_event.wait(sleep_duration / 1_000_000_000):
183+
break
182184
continue
183185

184186
logger.debug(
@@ -196,7 +198,8 @@ def _run(self):
196198

197199
elapsed_time = perf_counter_ns() - start_ns
198200
sleep_duration = max(0, self._refresh_rate_ns - elapsed_time)
199-
sleep(sleep_duration / 1_000_000_000)
201+
if self._stop_event.wait(sleep_duration / 1_000_000_000):
202+
break
200203
continue
201204
except InterruptedError as e:
202205
raise e

aws_advanced_python_wrapper/database_dialect.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from .exception_handling import ExceptionHandler
2929

3030
from abc import ABC, abstractmethod
31-
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
31+
from concurrent.futures import TimeoutError
3232
from contextlib import closing
3333
from enum import Enum, auto
3434

@@ -37,6 +37,8 @@
3737
from aws_advanced_python_wrapper.host_list_provider import (
3838
ConnectionStringHostListProvider, RdsHostListProvider)
3939
from aws_advanced_python_wrapper.hostinfo import HostInfo
40+
from aws_advanced_python_wrapper.thread_pool_container import \
41+
ThreadPoolContainer
4042
from aws_advanced_python_wrapper.utils.decorators import \
4143
preserve_transaction_status_with_timeout
4244
from aws_advanced_python_wrapper.utils.log import Logger
@@ -638,7 +640,7 @@ class DatabaseDialectManager(DatabaseDialectProvider):
638640
_ENDPOINT_CACHE_EXPIRATION_NS = 30 * 60_000_000_000 # 30 minutes
639641
_known_endpoint_dialects: CacheMap[str, DialectCode] = CacheMap()
640642
_custom_dialect: Optional[DatabaseDialect] = None
641-
_executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="DatabaseDialectManagerExecutor")
643+
_executor_name: ClassVar[str] = "DatabaseDialectManagerExecutor"
642644
_known_dialects_by_code: Dict[DialectCode, DatabaseDialect] = {
643645
DialectCode.MYSQL: MysqlDatabaseDialect(),
644646
DialectCode.RDS_MYSQL: RdsMysqlDialect(),
@@ -776,7 +778,7 @@ def query_for_dialect(self, url: str, host_info: Optional[HostInfo], conn: Conne
776778
timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get(self._props)
777779
try:
778780
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
779-
DatabaseDialectManager._executor,
781+
ThreadPoolContainer.get_thread_pool(DatabaseDialectManager._executor_name),
780782
timeout_sec,
781783
driver_dialect,
782784
conn)(dialect_candidate.is_dialect)

aws_advanced_python_wrapper/driver_dialect.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
from aws_advanced_python_wrapper.pep249 import Connection, Cursor
2222

2323
from abc import ABC
24-
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
24+
from concurrent.futures import TimeoutError
2525

2626
from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes
2727
from aws_advanced_python_wrapper.errors import (QueryTimeoutError,
2828
UnsupportedOperationError)
29+
from aws_advanced_python_wrapper.thread_pool_container import \
30+
ThreadPoolContainer
2931
from aws_advanced_python_wrapper.utils.decorators import timeout
3032
from aws_advanced_python_wrapper.utils.messages import Messages
3133
from aws_advanced_python_wrapper.utils.properties import (Properties,
@@ -40,7 +42,7 @@ class DriverDialect(ABC):
4042
_QUERY = "SELECT 1"
4143
_ALL_METHODS = "*"
4244

43-
_executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="DriverDialectExecutor")
45+
_executor_name: ClassVar[str] = "DriverDialectExecutor"
4446
_dialect_code: str = DriverDialectCodes.GENERIC
4547
_network_bound_methods: Set[str] = {_ALL_METHODS}
4648
_read_only: bool = False
@@ -136,7 +138,7 @@ def execute(
136138

137139
if exec_timeout > 0:
138140
try:
139-
execute_with_timeout = timeout(DriverDialect._executor, exec_timeout)(exec_func)
141+
execute_with_timeout = timeout(ThreadPoolContainer.get_thread_pool(DriverDialect._executor_name), exec_timeout)(exec_func)
140142
return execute_with_timeout()
141143
except TimeoutError as e:
142144
raise QueryTimeoutError(Messages.get_formatted("DriverDialect.ExecuteTimeout", method_name)) from e

aws_advanced_python_wrapper/host_list_provider.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import uuid
1818
from abc import ABC, abstractmethod
19-
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
19+
from concurrent.futures import TimeoutError
2020
from contextlib import closing
2121
from dataclasses import dataclass
2222
from datetime import datetime
@@ -39,6 +39,8 @@
3939
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
4040
from aws_advanced_python_wrapper.pep249 import (Connection, Cursor,
4141
ProgrammingError)
42+
from aws_advanced_python_wrapper.thread_pool_container import \
43+
ThreadPoolContainer
4244
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
4345
from aws_advanced_python_wrapper.utils.log import Logger
4446
from aws_advanced_python_wrapper.utils.messages import Messages
@@ -148,8 +150,6 @@ class RdsHostListProvider(DynamicHostListProvider, HostListProvider):
148150
# cluster IDs so that connections to the same clusters can share topology info.
149151
_cluster_ids_to_update: CacheMap[str, str] = CacheMap()
150152

151-
_executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="RdsHostListProviderExecutor")
152-
153153
def __init__(self, host_list_provider_service: HostListProviderService, props: Properties, topology_utils: TopologyUtils):
154154
self._host_list_provider_service: HostListProviderService = host_list_provider_service
155155
self._props: Properties = props
@@ -425,6 +425,8 @@ class TopologyUtils(ABC):
425425
to various database engine deployments (e.g. Aurora, Multi-AZ, etc.).
426426
"""
427427

428+
_executor_name: ClassVar[str] = "TopologyUtils"
429+
428430
def __init__(self, dialect: db_dialect.TopologyAwareDatabaseDialect, props: Properties):
429431
self._dialect: db_dialect.TopologyAwareDatabaseDialect = dialect
430432
self._rds_utils = RdsUtils()
@@ -487,7 +489,7 @@ def query_for_topology(
487489
an empty tuple will be returned.
488490
"""
489491
query_for_topology_func_with_timeout = preserve_transaction_status_with_timeout(
490-
RdsHostListProvider._executor, self._max_timeout, driver_dialect, conn)(self._query_for_topology)
492+
ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout, driver_dialect, conn)(self._query_for_topology)
491493
return query_for_topology_func_with_timeout(conn)
492494

493495
@abstractmethod
@@ -549,7 +551,7 @@ def create_host(
549551
def get_host_role(self, connection: Connection, driver_dialect: DriverDialect) -> HostRole:
550552
try:
551553
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
552-
RdsHostListProvider._executor, self._max_timeout, driver_dialect, connection)(self._get_host_role)
554+
ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout, driver_dialect, connection)(self._get_host_role)
553555
result = cursor_execute_func_with_timeout(connection)
554556
if result is not None:
555557
is_reader = result[0]
@@ -572,7 +574,7 @@ def get_host_id(self, connection: Connection, driver_dialect: DriverDialect) ->
572574
"""
573575

574576
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
575-
RdsHostListProvider._executor, self._max_timeout, driver_dialect, connection)(self._get_host_id)
577+
ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout, driver_dialect, connection)(self._get_host_id)
576578
result = cursor_execute_func_with_timeout(connection)
577579
if result:
578580
host_id: str = result[0]
@@ -586,6 +588,9 @@ def _get_host_id(self, conn: Connection):
586588

587589

588590
class AuroraTopologyUtils(TopologyUtils):
591+
592+
_executor_name: ClassVar[str] = "AuroraTopologyUtils"
593+
589594
def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]]:
590595
"""
591596
Query the database for topology information.
@@ -636,6 +641,9 @@ def _process_query_results(self, cursor: Cursor) -> Tuple[HostInfo, ...]:
636641

637642

638643
class MultiAzTopologyUtils(TopologyUtils):
644+
645+
_executor_name: ClassVar[str] = "MultiAzTopologyUtils"
646+
639647
def __init__(
640648
self,
641649
dialect: db_dialect.TopologyAwareDatabaseDialect,

aws_advanced_python_wrapper/host_monitoring_plugin.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,11 @@
2222
from aws_advanced_python_wrapper.pep249 import Connection
2323
from aws_advanced_python_wrapper.plugin_service import PluginService
2424

25-
from concurrent.futures import (Executor, Future, ThreadPoolExecutor,
26-
TimeoutError)
25+
from concurrent.futures import Future, TimeoutError
2726
from dataclasses import dataclass
2827
from queue import Queue
2928
from threading import Event, Lock, RLock
30-
from time import perf_counter_ns, sleep
29+
from time import perf_counter_ns
3130
from typing import Any, Callable, ClassVar, Dict, FrozenSet, Optional, Set
3231

3332
from _weakref import ReferenceType, ref
@@ -36,6 +35,8 @@
3635
from aws_advanced_python_wrapper.host_availability import HostAvailability
3736
from aws_advanced_python_wrapper.plugin import (CanReleaseResources, Plugin,
3837
PluginFactory)
38+
from aws_advanced_python_wrapper.thread_pool_container import \
39+
ThreadPoolContainer
3940
from aws_advanced_python_wrapper.utils.concurrent import ConcurrentDict
4041
from aws_advanced_python_wrapper.utils.log import Logger
4142
from aws_advanced_python_wrapper.utils.messages import Messages
@@ -548,9 +549,8 @@ def _execute_conn_check(self, conn: Connection, timeout_sec: float):
548549
driver_dialect.execute("Cursor.execute", lambda: cursor.execute(query), query, exec_timeout=timeout_sec)
549550
cursor.fetchone()
550551

551-
# Used to help with testing
552552
def sleep(self, duration: int):
553-
sleep(duration)
553+
self._is_stopped.wait(duration)
554554

555555

556556
class MonitoringThreadContainer:
@@ -565,7 +565,7 @@ class MonitoringThreadContainer:
565565

566566
_monitor_map: ConcurrentDict[str, Monitor] = ConcurrentDict()
567567
_tasks_map: ConcurrentDict[Monitor, Future] = ConcurrentDict()
568-
_executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="MonitoringThreadContainerExecutor")
568+
_executor_name: ClassVar[str] = "MonitoringThreadContainerExecutor"
569569

570570
# This logic ensures that this class is a Singleton
571571
def __new__(cls, *args, **kwargs):
@@ -593,7 +593,9 @@ def _get_or_create_monitor(_) -> Monitor:
593593
if supplied_monitor is None:
594594
raise AwsWrapperError(Messages.get("MonitoringThreadContainer.SupplierMonitorNone"))
595595
self._tasks_map.compute_if_absent(
596-
supplied_monitor, lambda _: MonitoringThreadContainer._executor.submit(supplied_monitor.run))
596+
supplied_monitor,
597+
lambda _: ThreadPoolContainer.get_thread_pool(MonitoringThreadContainer._executor_name)
598+
.submit(supplied_monitor.run))
597599
return supplied_monitor
598600

599601
if monitor is None:
@@ -648,12 +650,9 @@ def _release_resources(self):
648650
for monitor, _ in self._tasks_map.items():
649651
monitor.stop()
650652

653+
ThreadPoolContainer.release_pool(MonitoringThreadContainer._executor_name, wait=False)
651654
self._tasks_map.clear()
652655

653-
# Reset the executor.
654-
self._executor.shutdown(wait=False)
655-
MonitoringThreadContainer._executor = ThreadPoolExecutor(thread_name_prefix="MonitoringThreadContainerExecutor")
656-
657656

658657
class MonitorService:
659658
def __init__(self, plugin_service: PluginService):

aws_advanced_python_wrapper/mysql_driver_dialect.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020
from aws_advanced_python_wrapper.hostinfo import HostInfo
2121
from aws_advanced_python_wrapper.pep249 import Connection
2222

23-
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
23+
from concurrent.futures import TimeoutError
2424
from inspect import signature
2525

2626
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
2727
from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes
2828
from aws_advanced_python_wrapper.errors import UnsupportedOperationError
29+
from aws_advanced_python_wrapper.thread_pool_container import \
30+
ThreadPoolContainer
2931
from aws_advanced_python_wrapper.utils.decorators import timeout
3032
from aws_advanced_python_wrapper.utils.messages import Messages
3133
from aws_advanced_python_wrapper.utils.properties import (Properties,
@@ -55,7 +57,7 @@ class MySQLDriverDialect(DriverDialect):
5557
AUTH_METHOD = "mysql_clear_password"
5658
IS_CLOSED_TIMEOUT_SEC = 3
5759

58-
_executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="MySQLDriverDialectExecutor")
60+
_executor_name: ClassVar[str] = "MySQLDriverDialectExecutor"
5961

6062
_dialect_code: str = DriverDialectCodes.MYSQL_CONNECTOR_PYTHON
6163
_network_bound_methods: Set[str] = {
@@ -94,7 +96,8 @@ def is_closed(self, conn: Connection) -> bool:
9496
if self.can_execute_query(conn):
9597
socket_timeout = WrapperProperties.SOCKET_TIMEOUT_SEC.get_float(self._props)
9698
timeout_sec = socket_timeout if socket_timeout > 0 else MySQLDriverDialect.IS_CLOSED_TIMEOUT_SEC
97-
is_connected_with_timeout = timeout(MySQLDriverDialect._executor, timeout_sec)(conn.is_connected) # type: ignore
99+
is_connected_with_timeout = timeout(
100+
ThreadPoolContainer.get_thread_pool(MySQLDriverDialect._executor_name), timeout_sec)(conn.is_connected) # type: ignore
98101

99102
try:
100103
return not is_connected_with_timeout()

aws_advanced_python_wrapper/plugin_service.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
4242

4343
from abc import abstractmethod
44-
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
44+
from concurrent.futures import TimeoutError
4545
from contextlib import closing
4646
from typing import (Any, Callable, Dict, FrozenSet, Optional, Protocol, Set,
4747
Tuple)
@@ -85,6 +85,8 @@
8585
from aws_advanced_python_wrapper.simple_read_write_splitting_plugin import \
8686
SimpleReadWriteSplittingPluginFactory
8787
from aws_advanced_python_wrapper.stale_dns_plugin import StaleDnsPluginFactory
88+
from aws_advanced_python_wrapper.thread_pool_container import \
89+
ThreadPoolContainer
8890
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
8991
from aws_advanced_python_wrapper.utils.decorators import \
9092
preserve_transaction_status_with_timeout
@@ -314,7 +316,7 @@ class PluginServiceImpl(PluginService, HostListProviderService, CanReleaseResour
314316
_host_availability_expiring_cache: CacheMap[str, HostAvailability] = CacheMap()
315317
_status_cache: ClassVar[CacheMap[str, Any]] = CacheMap()
316318

317-
_executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="PluginServiceImplExecutor")
319+
_executor_name: ClassVar[str] = "PluginServiceImplExecutor"
318320

319321
def __init__(
320322
self,
@@ -611,7 +613,7 @@ def fill_aliases(self, connection: Optional[Connection] = None, host_info: Optio
611613
try:
612614
timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get(self._props)
613615
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
614-
PluginServiceImpl._executor, timeout_sec, driver_dialect, connection)(self._fill_aliases)
616+
ThreadPoolContainer.get_thread_pool(PluginServiceImpl._executor_name), timeout_sec, driver_dialect, connection)(self._fill_aliases)
615617
cursor_execute_func_with_timeout(connection, host_info)
616618
except TimeoutError as e:
617619
raise QueryTimeoutError(Messages.get("PluginServiceImpl.FillAliasesTimeout")) from e

aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,8 @@ HostResponseTimeMonitor.OpeningConnection=[HostResponseTimeMonitor] Opening a Re
274274
HostResponseTimeMonitor.ResponseTime=[HostResponseTimeMonitor] Response time for '{}': {} ms
275275
HostResponseTimeMonitor.Stopped=[HostResponseTimeMonitor] Stopped Response time thread for host '{}'.
276276

277+
ThreadPoolContainer.ErrorShuttingDownPool=[ThreadPoolContainer] Error shutting down pool '{}': '{}'.
278+
277279
OpenedConnectionTracker.OpenedConnectionsTracked=[OpenedConnectionTracker] Opened Connections Tracked: {}
278280
OpenedConnectionTracker.InvalidatingConnections=[OpenedConnectionTracker] Invalidating opened connections to host: {}
279281
OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet=[OpenedConnectionTracker] The driver is unable to track this opened connection because the instance endpoint is unknown.

0 commit comments

Comments
 (0)