Skip to content

Commit 8eca0a0

Browse files
authored
fix: update aurora connection tracker and fix writer host comparison (#1081)
1 parent 216b956 commit 8eca0a0

22 files changed

Lines changed: 267 additions & 134 deletions

aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py

Lines changed: 179 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,19 @@
1414

1515
from __future__ import annotations
1616

17+
import threading
1718
from threading import Thread
18-
from typing import (TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Optional,
19-
Set, Tuple)
19+
from time import perf_counter_ns
20+
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, FrozenSet,
21+
Optional, Set)
22+
23+
from aws_advanced_python_wrapper.utils.notifications import HostEvent
24+
from aws_advanced_python_wrapper.utils.utils import Utils
2025

2126
if TYPE_CHECKING:
2227
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
2328
from aws_advanced_python_wrapper.plugin_service import PluginService
29+
from aws_advanced_python_wrapper.hostinfo import HostInfo
2430
from aws_advanced_python_wrapper.pep249 import Connection
2531

2632
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
@@ -29,7 +35,6 @@
2935
from _weakrefset import WeakSet
3036

3137
from aws_advanced_python_wrapper.errors import FailoverError
32-
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
3338
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
3439
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
3540
from aws_advanced_python_wrapper.utils.log import Logger
@@ -39,8 +44,74 @@
3944

4045

4146
class OpenedConnectionTracker:
42-
_opened_connections: Dict[str, WeakSet] = {}
43-
_rds_utils = RdsUtils()
47+
_opened_connections: ClassVar[Dict[str, WeakSet]] = {}
48+
_lock: ClassVar[threading.Lock] = threading.Lock()
49+
_rds_utils: ClassVar[RdsUtils] = RdsUtils()
50+
_prune_thread: ClassVar[Optional[Thread]] = None
51+
_prune_thread_started: ClassVar[bool] = False
52+
_shutdown_event: ClassVar[threading.Event] = threading.Event()
53+
_safe_to_check_closed_classes: ClassVar[Set[str]] = {"psycopg"}
54+
_default_sleep_time: ClassVar[int] = 30
55+
56+
@classmethod
57+
def _start_prune_thread(cls):
58+
with cls._lock:
59+
if not cls._prune_thread_started:
60+
cls._prune_thread_started = True
61+
cls._prune_thread = Thread(daemon=True, target=cls._prune_connections_loop)
62+
cls._prune_thread.start()
63+
64+
@classmethod
65+
def release_resources(cls):
66+
cls._shutdown_event.set()
67+
with cls._lock:
68+
thread_to_join = cls._prune_thread
69+
if thread_to_join is not None:
70+
thread_to_join.join()
71+
with cls._lock:
72+
cls._opened_connections.clear()
73+
74+
@classmethod
75+
def _prune_connections_loop(cls):
76+
while not cls._shutdown_event.is_set():
77+
try:
78+
cls._prune_connections()
79+
if cls._shutdown_event.wait(timeout=cls._default_sleep_time):
80+
break
81+
except Exception:
82+
pass
83+
84+
@classmethod
85+
def _prune_connections(cls):
86+
with cls._lock:
87+
opened_connections = list(cls._opened_connections.items())
88+
89+
to_remove_by_host = {}
90+
for host, conn_set in opened_connections:
91+
to_remove = []
92+
for conn in list(conn_set):
93+
if conn is None:
94+
to_remove.append(conn)
95+
else:
96+
try:
97+
# The driver modules listed in _safe_to_check_closed_classes do not check connection validity via a DB server call,
98+
# so it is safe to ask them whether the connection is already closed.
99+
if any(safe_class in conn.__module__ for safe_class in cls._safe_to_check_closed_classes) and conn.is_closed():
100+
to_remove.append(conn)
101+
except Exception:
102+
pass
103+
104+
if to_remove:
105+
to_remove_by_host[host] = (conn_set, to_remove)
106+
107+
with cls._lock:
108+
for host, (conn_set, to_remove) in to_remove_by_host.items():
109+
for conn in to_remove:
110+
conn_set.discard(conn)
111+
112+
# Remove empty connection sets
113+
if not conn_set and host in cls._opened_connections:
114+
del cls._opened_connections[host]
44115

45116
def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection):
46117
"""
@@ -56,8 +127,8 @@ def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection):
56127
self._track_connection(host_info.as_alias(), conn)
57128
return
58129

59-
instance_endpoint: Optional[str] = next((alias for alias in aliases if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))),
60-
None)
130+
instance_endpoint: Optional[str] = next(
131+
(alias for alias in aliases if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))), None)
61132
if not instance_endpoint:
62133
logger.debug("OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet")
63134
return
@@ -73,7 +144,7 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host:
73144
"""
74145

75146
if host_info:
76-
self.invalidate_all_connections(host=frozenset(host_info.as_alias()))
147+
self.invalidate_all_connections(host=frozenset([host_info.as_alias()]))
77148
self.invalidate_all_connections(host=host_info.as_aliases())
78149
return
79150

@@ -89,27 +160,42 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host:
89160
if not instance_endpoint:
90161
return
91162

92-
connection_set: Optional[WeakSet] = self._opened_connections.get(instance_endpoint)
93-
if connection_set is not None:
163+
with self._lock:
164+
connection_set: Optional[WeakSet] = self._opened_connections.get(instance_endpoint)
165+
connections_list = list(connection_set) if connection_set is not None else None
166+
167+
if connections_list is not None:
94168
self._log_connection_set(instance_endpoint, connection_set)
95-
self._invalidate_connections(connection_set)
169+
self._invalidate_connections(connections_list)
96170

97-
def _track_connection(self, instance_endpoint: str, conn: Connection):
98-
connection_set: Optional[WeakSet] = self._opened_connections.get(instance_endpoint)
99-
if connection_set is None:
100-
connection_set = WeakSet()
101-
connection_set.add(conn)
102-
self._opened_connections[instance_endpoint] = connection_set
171+
def remove_connection_tracking(self, host_info: HostInfo, connection: Connection | None):
172+
if not connection:
173+
return
174+
175+
if self._rds_utils.is_rds_instance(host_info.host):
176+
host = host_info.as_alias()
103177
else:
104-
connection_set.add(conn)
178+
host = next((alias for alias in host_info.as_aliases()
179+
if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))), "")
180+
181+
if not host:
182+
return
183+
184+
with self._lock:
185+
connection_set = self._opened_connections.get(host)
186+
if connection_set:
187+
connection_set.discard(connection)
105188

189+
def _track_connection(self, instance_endpoint: str, conn: Connection):
190+
with self._lock:
191+
connection_set = self._opened_connections.setdefault(instance_endpoint, WeakSet())
192+
connection_set.add(conn)
193+
self._start_prune_thread()
106194
self.log_opened_connections()
107195

108196
@staticmethod
109-
def _task(connection_set: WeakSet):
110-
while connection_set is not None and len(connection_set) > 0:
111-
conn_reference = connection_set.pop()
112-
197+
def _task(connections_list: list):
198+
for conn_reference in connections_list:
113199
if conn_reference is None:
114200
continue
115201

@@ -119,37 +205,38 @@ def _task(connection_set: WeakSet):
119205
# Swallow this exception; the current connection should be useless anyway.
120206
pass
121207

122-
def _invalidate_connections(self, connection_set: WeakSet):
208+
def _invalidate_connections(self, connections_list: list):
123209
invalidate_connection_thread: Thread = Thread(daemon=True, target=self._task,
124-
args=[connection_set]) # type: ignore
210+
args=[connections_list]) # type: ignore
125211
invalidate_connection_thread.start()
126212

127213
def log_opened_connections(self):
128-
msg = ""
129-
for key, conn_set in self._opened_connections.items():
130-
conn = ""
131-
for item in list(conn_set):
132-
conn += f"\n\t\t{item}"
214+
with self._lock:
215+
opened_connections = [(key, list(conn_set)) for key, conn_set in self._opened_connections.items()]
133216

134-
msg += f"\t[{key} : {conn}]"
217+
msg_parts = []
218+
for key, conn_list in opened_connections:
219+
conn_parts = [f"\n\t\t{item}" for item in conn_list]
220+
conn = "".join(conn_parts)
221+
msg_parts.append(f"\t[{key} : {conn}]")
135222

223+
msg = "".join(msg_parts)
136224
return logger.debug("OpenedConnectionTracker.OpenedConnectionsTracked", msg)
137225

138226
def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]):
139227
if conn_set is None or len(conn_set) == 0:
140228
return
141229

142-
conn = ""
143-
for item in list(conn_set):
144-
conn += f"\n\t\t{item}"
145-
230+
conn_parts = [f"\n\t\t{item}" for item in list(conn_set)]
231+
conn = "".join(conn_parts)
146232
msg = host + f"[{conn}\n]"
147233
logger.debug("OpenedConnectionTracker.InvalidatingConnections", msg)
148234

149235

150236
class AuroraConnectionTrackerPlugin(Plugin):
151-
_current_writer: Optional[HostInfo] = None
152-
_need_update_current_writer: bool = False
237+
_host_list_refresh_end_time_nano: ClassVar[int] = 0
238+
_refresh_lock: ClassVar[threading.Lock] = threading.Lock()
239+
_TOPOLOGY_CHANGES_EXPECTED_TIME_NANO: ClassVar[int] = 3 * 60 * 1_000_000_000 # 3 minutes
153240

154241
@property
155242
def subscribed_methods(self) -> Set[str]:
@@ -164,6 +251,8 @@ def __init__(self,
164251
self._props = props
165252
self._rds_utils = rds_utils
166253
self._tracker = tracker
254+
self._current_writer: Optional[HostInfo] = None
255+
self._need_update_current_writer: bool = False
167256
self._subscribed_methods: Set[str] = {DbApiMethod.CONNECT.method_name,
168257
DbApiMethod.CONNECTION_CLOSE.method_name,
169258
DbApiMethod.CONNECT.method_name,
@@ -192,26 +281,67 @@ def connect(
192281
return conn
193282

194283
def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any:
284+
current_host = self._plugin_service.current_host_info
195285
if self._current_writer is None or self._need_update_current_writer:
196-
self._current_writer = self._get_writer(self._plugin_service.all_hosts)
286+
self._current_writer = Utils.get_writer(self._plugin_service.all_hosts)
197287
self._need_update_current_writer = False
198288

199289
try:
200-
return execute_func()
290+
if not method_name == DbApiMethod.CONNECTION_CLOSE.method_name:
291+
need_refresh_host_lists = False
292+
with AuroraConnectionTrackerPlugin._refresh_lock:
293+
local_host_list_refresh_end_time_nano = AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano
294+
if local_host_list_refresh_end_time_nano > 0:
295+
if local_host_list_refresh_end_time_nano > perf_counter_ns():
296+
# The deadline stored in _host_list_refresh_end_time_nano has not been reached yet,
297+
# so we need to keep refreshing the host list.
298+
need_refresh_host_lists = True
299+
else:
300+
# The deadline stored in _host_list_refresh_end_time_nano has passed, so we can stop further refreshes
301+
# of the host list.
302+
AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano = 0
303+
304+
if self._need_update_current_writer or need_refresh_host_lists:
305+
# Calling this method may effectively close/abort the current connection
306+
self._check_writer_changed(need_refresh_host_lists)
307+
308+
result = execute_func()
309+
if method_name == DbApiMethod.CONNECTION_CLOSE.method_name:
310+
self._tracker.remove_connection_tracking(current_host, self._plugin_service.current_connection)
311+
return result
201312

202313
except Exception as e:
203-
# Check that e is a FailoverError and that the writer has changed
204-
if isinstance(e, FailoverError) and self._get_writer(self._plugin_service.all_hosts) != self._current_writer:
205-
self._tracker.invalidate_all_connections(host_info=self._current_writer)
206-
self._tracker.log_opened_connections()
207-
self._need_update_current_writer = True
208-
raise e
314+
if isinstance(e, FailoverError):
315+
with AuroraConnectionTrackerPlugin._refresh_lock:
316+
AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano = (
317+
perf_counter_ns() + AuroraConnectionTrackerPlugin._TOPOLOGY_CHANGES_EXPECTED_TIME_NANO)
318+
# Calling this method may effectively close/abort the current connection
319+
self._check_writer_changed(True)
320+
raise
321+
322+
def _check_writer_changed(self, need_refresh_host_lists: bool):
323+
if need_refresh_host_lists:
324+
self._plugin_service.refresh_host_list()
325+
326+
host_info_after_failover = Utils.get_writer(self._plugin_service.all_hosts)
327+
if host_info_after_failover is None:
328+
return
329+
330+
if self._current_writer is None:
331+
self._current_writer = host_info_after_failover
332+
self._need_update_current_writer = False
333+
elif not self._current_writer.get_host_and_port() == host_info_after_failover.get_host_and_port():
334+
self._tracker.invalidate_all_connections(host_info=self._current_writer)
335+
self._tracker.log_opened_connections()
336+
self._current_writer = host_info_after_failover
337+
self._need_update_current_writer = False
209338

210-
def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]:
211-
for host in hosts:
212-
if host.role == HostRole.WRITER:
213-
return host
214-
return None
339+
def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]):
340+
for node, node_changes in changes.items():
341+
if HostEvent.CONVERTED_TO_READER in node_changes:
342+
self._tracker.invalidate_all_connections(host=frozenset([node]))
343+
if HostEvent.CONVERTED_TO_WRITER in node_changes:
344+
self._need_update_current_writer = True
215345

216346

217347
class AuroraConnectionTrackerPluginFactory(PluginFactory):

aws_advanced_python_wrapper/cleanup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from aws_advanced_python_wrapper.aurora_connection_tracker_plugin import \
16+
OpenedConnectionTracker
1517
from aws_advanced_python_wrapper.host_monitoring_plugin import \
1618
MonitoringThreadContainer
1719
from aws_advanced_python_wrapper.thread_pool_container import \
@@ -22,3 +24,4 @@ def release_resources() -> None:
2224
"""Release all global resources used by the wrapper."""
2325
MonitoringThreadContainer.clean_up()
2426
ThreadPoolContainer.release_resources()
27+
OpenedConnectionTracker.release_resources()

aws_advanced_python_wrapper/failover_plugin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,12 +325,12 @@ def _failover_writer(self):
325325

326326
writer_host = self._get_writer(result.topology)
327327
allowed_hosts = self._plugin_service.hosts
328-
allowed_hostnames = [host.host for host in allowed_hosts]
329-
if writer_host.host not in allowed_hostnames:
328+
allowed_hostnames = [host.get_host_and_port() for host in allowed_hosts]
329+
if writer_host.get_host_and_port() not in allowed_hostnames:
330330
raise FailoverFailedError(
331331
Messages.get_formatted(
332332
"FailoverPlugin.NewWriterNotAllowed",
333-
"<null>" if writer_host is None else writer_host.host,
333+
"<null>" if writer_host is None else writer_host.get_host_and_port(),
334334
LogUtils.log_topology(allowed_hosts)))
335335

336336
self._plugin_service.set_current_connection(result.new_connection, writer_host)

aws_advanced_python_wrapper/fastest_response_strategy_plugin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
from __future__ import annotations
1616

17+
import time
1718
from copy import copy
1819
from dataclasses import dataclass
19-
from datetime import datetime
2020
from threading import Event, Lock, Thread
2121
from time import sleep
2222
from typing import (TYPE_CHECKING, Callable, ClassVar, Dict, List, Optional,
@@ -96,7 +96,7 @@ def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Op
9696

9797
# Found the fastest host. Let's find it in the latest topology.
9898
for host in self._plugin_service.hosts:
99-
if host == fastest_response_host:
99+
if host.get_host_and_port() == fastest_response_host.get_host_and_port():
100100
# found the fastest host in the topology
101101
return host
102102
# It seems that the fastest cached host isn't in the latest topology.
@@ -196,7 +196,7 @@ def close(self):
196196
logger.debug("HostResponseTimeMonitor.Stopped", self._host_info.host)
197197

198198
def _get_current_time(self):
199-
return datetime.now().microsecond / 1000 # milliseconds
199+
return time.perf_counter() * 1000 # milliseconds
200200

201201
def run(self):
202202
context: TelemetryContext = self._telemetry_factory.open_telemetry_context(

0 commit comments

Comments
 (0)