Skip to content

Commit f8f5acf

Browse files
fix: failover_v2 segfaulting when using MySQL C (#1088)
1 parent b10a3a2 commit f8f5acf

3 files changed

Lines changed: 313 additions & 11 deletions

File tree

aws_advanced_python_wrapper/cluster_topology_monitor.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from aws_advanced_python_wrapper.utils.atomic import AtomicReference
2525
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
2626
from aws_advanced_python_wrapper.utils.messages import Messages
27+
from aws_advanced_python_wrapper.utils.thread_safe_connection_holder import \
28+
ThreadSafeConnectionHolder
2729
from aws_advanced_python_wrapper.utils.utils import LogUtils
2830

2931
if TYPE_CHECKING:
@@ -86,7 +88,7 @@ def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils,
8688
self._high_refresh_rate_nano = high_refresh_rate_nano
8789

8890
self._writer_host_info: AtomicReference[Optional[HostInfo]] = AtomicReference(None)
89-
self._monitoring_connection: AtomicReference[Optional[Connection]] = AtomicReference(None)
91+
self._monitoring_connection: ThreadSafeConnectionHolder = ThreadSafeConnectionHolder(None)
9092

9193
self._topology_updated = threading.Event()
9294
self._request_to_update_topology = threading.Event()
@@ -123,7 +125,7 @@ def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[H
123125
return current_hosts
124126

125127
if should_verify_writer:
126-
self._close_connection_from_ref(self._monitoring_connection)
128+
self._monitoring_connection.clear()
127129
self._is_verified_writer_connection = False
128130

129131
result = self._wait_till_topology_gets_updated(timeout_sec)
@@ -177,7 +179,7 @@ def close(self) -> None:
177179
self._monitor_thread.join(self.MONITOR_TERMINATION_TIMEOUT_SEC)
178180

179181
# Step 3: Now safe to close connections - no threads are using them
180-
self._close_connection_from_ref(self._monitoring_connection)
182+
self._monitoring_connection.clear()
181183
self._close_connection_from_ref(self._host_threads_writer_connection)
182184
self._close_connection_from_ref(self._host_threads_reader_connection)
183185

@@ -220,8 +222,8 @@ def _monitor(self) -> None:
220222
writer_connection = self._host_threads_writer_connection.get()
221223
if (writer_connection is not None and writer_host_info is not None):
222224
logger.debug("ClusterTopologyMonitorImpl.WriterPickedUpFromHostMonitors", self._cluster_id, writer_host_info.host)
223-
self._close_connection_from_ref(self._monitoring_connection)
224-
self._monitoring_connection.set(writer_connection)
225+
# Transfer the writer connection to monitoring connection
226+
self._monitoring_connection.set(writer_connection, close_previous=True)
225227
self._writer_host_info.set(writer_host_info)
226228
self._is_verified_writer_connection = True
227229
self._high_refresh_rate_end_time_nano = (
@@ -259,9 +261,9 @@ def _monitor(self) -> None:
259261
self._close_host_monitors()
260262
self._submitted_hosts.clear()
261263

262-
hosts = self._fetch_topology_and_update_cache(self._monitoring_connection.get())
264+
hosts = self._fetch_topology_and_update_cache_safe()
263265
if not hosts:
264-
self._close_connection_from_ref(self._monitoring_connection)
266+
self._monitoring_connection.clear()
265267
self._is_verified_writer_connection = False
266268
self._writer_host_info.set(None)
267269
continue
@@ -282,7 +284,7 @@ def _monitor(self) -> None:
282284
finally:
283285
self._stop.set()
284286
self._close_host_monitors()
285-
self._close_connection_from_ref(self._monitoring_connection)
287+
self._monitoring_connection.clear()
286288
logger.debug("ClusterTopologyMonitor.StopMonitoringThread", self._cluster_id, self._initial_host_info.host)
287289

288290
def _is_in_panic_mode(self) -> bool:
@@ -297,7 +299,7 @@ def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]:
297299
# Try to connect to the initial host first
298300
try:
299301
conn = self._plugin_service.force_connect(self._initial_host_info, self._monitoring_properties)
300-
self._monitoring_connection.set(conn)
302+
self._monitoring_connection.set(conn, close_previous=False)
301303
logger.debug("ClusterTopologyMonitorImpl.OpenedMonitoringConnection", self._cluster_id, self._initial_host_info.host)
302304

303305
try:
@@ -313,7 +315,7 @@ def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]:
313315
except Exception:
314316
return ()
315317

316-
hosts = self._fetch_topology_and_update_cache(self._monitoring_connection.get())
318+
hosts = self._fetch_topology_and_update_cache_safe()
317319
if writer_verified_by_this_thread:
318320
if self._ignore_new_topology_requests_end_time_nano == -1:
319321
self._ignore_new_topology_requests_end_time_nano = 0
@@ -322,7 +324,7 @@ def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]:
322324
time.time_ns() + self.IGNORE_TOPOLOGY_REQUEST_NANO)
323325

324326
if len(hosts) == 0:
325-
self._close_connection_from_ref(self._monitoring_connection)
327+
self._monitoring_connection.clear()
326328
self._is_verified_writer_connection = False
327329
self._writer_host_info.set(None)
328330

@@ -400,6 +402,16 @@ def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) ->
400402
logger.debug("ClusterTopologyMonitorImpl.ErrorFetchingTopology", self._cluster_id, ex)
401403
return ()
402404

405+
def _fetch_topology_and_update_cache_safe(self) -> Tuple[HostInfo, ...]:
406+
"""
407+
Safely fetch topology using ThreadSafeConnectionHolder to prevent race conditions.
408+
The lock is held during the entire query operation.
409+
"""
410+
result = self._monitoring_connection.use_connection(
411+
lambda conn: self._fetch_topology_and_update_cache(conn)
412+
)
413+
return result if result is not None else ()
414+
403415
def _query_for_topology(self, connection: Connection) -> Tuple[HostInfo, ...]:
404416
hosts = self._topology_utils.query_for_topology(connection, self._plugin_service.driver_dialect)
405417
if hosts is not None:
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
14+
from __future__ import annotations
15+
16+
from threading import RLock
17+
from typing import TYPE_CHECKING, Optional
18+
19+
if TYPE_CHECKING:
20+
from aws_advanced_python_wrapper.pep249 import Connection
21+
22+
from aws_advanced_python_wrapper.utils.log import Logger
23+
24+
logger = Logger(__name__)
25+
26+
27+
class ThreadSafeConnectionHolder:
28+
"""
29+
Thread-safe connection container that ensures connections are properly
30+
closed when replaced or cleared. This class prevents race conditions where
31+
one thread might close a connection while another thread is using it.
32+
"""
33+
34+
def __init__(self, initial_connection: Optional[Connection] = None):
35+
self._connection: Optional[Connection] = initial_connection
36+
self._lock: RLock = RLock()
37+
38+
def get(self) -> Optional[Connection]:
39+
with self._lock:
40+
return self._connection
41+
42+
def set(self, new_connection: Optional[Connection], close_previous: bool = True) -> None:
43+
with self._lock:
44+
old_connection = self._connection
45+
self._connection = new_connection
46+
47+
if close_previous and old_connection is not None and old_connection != new_connection:
48+
self._close_connection(old_connection)
49+
50+
def get_and_set(self, new_connection: Optional[Connection], close_previous: bool = True) -> Optional[Connection]:
51+
with self._lock:
52+
old_connection = self._connection
53+
self._connection = new_connection
54+
55+
if close_previous and old_connection is not None and old_connection != new_connection:
56+
self._close_connection(old_connection)
57+
58+
return old_connection
59+
60+
def compare_and_set(
61+
self,
62+
expected_connection: Optional[Connection],
63+
new_connection: Optional[Connection],
64+
close_previous: bool = True
65+
) -> bool:
66+
with self._lock:
67+
if self._connection == expected_connection:
68+
old_connection = self._connection
69+
self._connection = new_connection
70+
71+
if close_previous and old_connection is not None and old_connection != new_connection:
72+
self._close_connection(old_connection)
73+
74+
return True
75+
return False
76+
77+
def clear(self) -> None:
78+
self.set(None, close_previous=True)
79+
80+
def use_connection(self, func, *args, **kwargs):
81+
"""
82+
Safely use the connection within a locked context.
83+
84+
This method ensures the connection cannot be closed by another thread
85+
while the provided function is executing.
86+
87+
:param func: Function to call with the connection as the first argument.
88+
:param args: Additional positional arguments to pass to func.
89+
:param kwargs: Additional keyword arguments to pass to func.
90+
:return: The result of calling func, or None if no connection is available.
91+
92+
Example:
93+
result = holder.use_connection(lambda conn: conn.cursor().execute("SELECT 1"))
94+
"""
95+
with self._lock:
96+
if self._connection is None:
97+
return None
98+
return func(self._connection, *args, **kwargs)
99+
100+
def _close_connection(self, connection: Connection) -> None:
101+
try:
102+
if connection is not None:
103+
connection.close()
104+
except Exception: # ignore
105+
pass
106+
107+
def __repr__(self) -> str:
108+
with self._lock:
109+
return f"ThreadSafeConnectionHolder(connection={self._connection})"

0 commit comments

Comments
 (0)