2424from aws_advanced_python_wrapper .utils .atomic import AtomicReference
2525from aws_advanced_python_wrapper .utils .cache_map import CacheMap
2626from aws_advanced_python_wrapper .utils .messages import Messages
27+ from aws_advanced_python_wrapper .utils .thread_safe_connection_holder import \
28+ ThreadSafeConnectionHolder
2729from aws_advanced_python_wrapper .utils .utils import LogUtils
2830
2931if TYPE_CHECKING :
@@ -86,7 +88,7 @@ def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils,
8688 self ._high_refresh_rate_nano = high_refresh_rate_nano
8789
8890 self ._writer_host_info : AtomicReference [Optional [HostInfo ]] = AtomicReference (None )
89- self ._monitoring_connection : AtomicReference [ Optional [ Connection ]] = AtomicReference (None )
91+ self ._monitoring_connection : ThreadSafeConnectionHolder = ThreadSafeConnectionHolder (None )
9092
9193 self ._topology_updated = threading .Event ()
9294 self ._request_to_update_topology = threading .Event ()
@@ -123,7 +125,7 @@ def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[H
123125 return current_hosts
124126
125127 if should_verify_writer :
126- self ._close_connection_from_ref ( self . _monitoring_connection )
128+ self ._monitoring_connection . clear ( )
127129 self ._is_verified_writer_connection = False
128130
129131 result = self ._wait_till_topology_gets_updated (timeout_sec )
@@ -177,7 +179,7 @@ def close(self) -> None:
177179 self ._monitor_thread .join (self .MONITOR_TERMINATION_TIMEOUT_SEC )
178180
179181 # Step 3: Now safe to close connections - no threads are using them
180- self ._close_connection_from_ref ( self . _monitoring_connection )
182+ self ._monitoring_connection . clear ( )
181183 self ._close_connection_from_ref (self ._host_threads_writer_connection )
182184 self ._close_connection_from_ref (self ._host_threads_reader_connection )
183185
@@ -220,8 +222,8 @@ def _monitor(self) -> None:
220222 writer_connection = self ._host_threads_writer_connection .get ()
221223 if (writer_connection is not None and writer_host_info is not None ):
222224 logger .debug ("ClusterTopologyMonitorImpl.WriterPickedUpFromHostMonitors" , self ._cluster_id , writer_host_info .host )
223- self . _close_connection_from_ref ( self . _monitoring_connection )
224- self ._monitoring_connection .set (writer_connection )
225+ # Transfer the writer connection to monitoring connection
226+ self ._monitoring_connection .set (writer_connection , close_previous = True )
225227 self ._writer_host_info .set (writer_host_info )
226228 self ._is_verified_writer_connection = True
227229 self ._high_refresh_rate_end_time_nano = (
@@ -259,9 +261,9 @@ def _monitor(self) -> None:
259261 self ._close_host_monitors ()
260262 self ._submitted_hosts .clear ()
261263
262- hosts = self ._fetch_topology_and_update_cache ( self . _monitoring_connection . get () )
264+ hosts = self ._fetch_topology_and_update_cache_safe ( )
263265 if not hosts :
264- self ._close_connection_from_ref ( self . _monitoring_connection )
266+ self ._monitoring_connection . clear ( )
265267 self ._is_verified_writer_connection = False
266268 self ._writer_host_info .set (None )
267269 continue
@@ -282,7 +284,7 @@ def _monitor(self) -> None:
282284 finally :
283285 self ._stop .set ()
284286 self ._close_host_monitors ()
285- self ._close_connection_from_ref ( self . _monitoring_connection )
287+ self ._monitoring_connection . clear ( )
286288 logger .debug ("ClusterTopologyMonitor.StopMonitoringThread" , self ._cluster_id , self ._initial_host_info .host )
287289
288290 def _is_in_panic_mode (self ) -> bool :
@@ -297,7 +299,7 @@ def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]:
297299 # Try to connect to the initial host first
298300 try :
299301 conn = self ._plugin_service .force_connect (self ._initial_host_info , self ._monitoring_properties )
300- self ._monitoring_connection .set (conn )
302+ self ._monitoring_connection .set (conn , close_previous = False )
301303 logger .debug ("ClusterTopologyMonitorImpl.OpenedMonitoringConnection" , self ._cluster_id , self ._initial_host_info .host )
302304
303305 try :
@@ -313,7 +315,7 @@ def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]:
313315 except Exception :
314316 return ()
315317
316- hosts = self ._fetch_topology_and_update_cache ( self . _monitoring_connection . get () )
318+ hosts = self ._fetch_topology_and_update_cache_safe ( )
317319 if writer_verified_by_this_thread :
318320 if self ._ignore_new_topology_requests_end_time_nano == - 1 :
319321 self ._ignore_new_topology_requests_end_time_nano = 0
@@ -322,7 +324,7 @@ def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]:
322324 time .time_ns () + self .IGNORE_TOPOLOGY_REQUEST_NANO )
323325
324326 if len (hosts ) == 0 :
325- self ._close_connection_from_ref ( self . _monitoring_connection )
327+ self ._monitoring_connection . clear ( )
326328 self ._is_verified_writer_connection = False
327329 self ._writer_host_info .set (None )
328330
@@ -400,6 +402,16 @@ def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) ->
400402 logger .debug ("ClusterTopologyMonitorImpl.ErrorFetchingTopology" , self ._cluster_id , ex )
401403 return ()
402404
405+ def _fetch_topology_and_update_cache_safe (self ) -> Tuple [HostInfo , ...]:
406+ """
407+ Safely fetch topology using ThreadSafeConnectionHolder to prevent race conditions.
408+ The lock is held during the entire query operation.
409+ """
410+ result = self ._monitoring_connection .use_connection (
411+ lambda conn : self ._fetch_topology_and_update_cache (conn )
412+ )
413+ return result if result is not None else ()
414+
403415 def _query_for_topology (self , connection : Connection ) -> Tuple [HostInfo , ...]:
404416 hosts = self ._topology_utils .query_for_topology (connection , self ._plugin_service .driver_dialect )
405417 if hosts is not None :
0 commit comments