Skip to content

Commit 5dc424b

Browse files
sophia-bqkarenc-bq
andauthored
feat: srw (#1048)
Co-authored-by: Karen <64801825+karenc-bq@users.noreply.github.com>
1 parent d1c428a commit 5dc424b

10 files changed

Lines changed: 2274 additions & 651 deletions

aws_advanced_python_wrapper/plugin_service.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@
8282
from aws_advanced_python_wrapper.plugin import CanReleaseResources
8383
from aws_advanced_python_wrapper.read_write_splitting_plugin import \
8484
ReadWriteSplittingPluginFactory
85+
from aws_advanced_python_wrapper.simple_read_write_splitting_plugin import \
86+
SimpleReadWriteSplittingPluginFactory
8587
from aws_advanced_python_wrapper.stale_dns_plugin import StaleDnsPluginFactory
8688
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
8789
from aws_advanced_python_wrapper.utils.decorators import \
@@ -759,6 +761,7 @@ class PluginManager(CanReleaseResources):
759761
"host_monitoring_v2": HostMonitoringV2PluginFactory,
760762
"failover": FailoverPluginFactory,
761763
"read_write_splitting": ReadWriteSplittingPluginFactory,
764+
"srw": SimpleReadWriteSplittingPluginFactory,
762765
"fastest_response_strategy": FastestResponseStrategyPluginFactory,
763766
"stale_dns": StaleDnsPluginFactory,
764767
"custom_endpoint": CustomEndpointPluginFactory,
@@ -783,6 +786,7 @@ class PluginManager(CanReleaseResources):
783786
AuroraConnectionTrackerPluginFactory: 100,
784787
StaleDnsPluginFactory: 200,
785788
ReadWriteSplittingPluginFactory: 300,
789+
SimpleReadWriteSplittingPluginFactory: 310,
786790
FailoverPluginFactory: 400,
787791
HostMonitoringPluginFactory: 500,
788792
HostMonitoringV2PluginFactory: 510,

aws_advanced_python_wrapper/read_write_splitting_plugin.py

Lines changed: 488 additions & 162 deletions
Large diffs are not rendered by default.

aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ OpenTelemetryFactory.WrongParameterType="[OpenTelemetryFactory] Wrong parameter
286286

287287
Plugin.UnsupportedMethod=[Plugin] '{}' is not supported by this plugin.
288288

289-
PluginManager.ConfigurationProfileNotFound=PluginManager] Configuration profile '{}' not found.
289+
PluginManager.ConfigurationProfileNotFound=[PluginManager] Configuration profile '{}' not found.
290290
PluginManager.InvalidPlugin=[PluginManager] Invalid plugin requested: '{}'.
291291
PluginManager.MethodInvokedAgainstOldConnection = [PluginManager] The internal connection has changed since '{}' was created. This is likely due to failover or read-write splitting functionality. To ensure you are using the updated connection, please re-create Cursor objects after failover and/or setting readonly.
292292
PluginManager.PipelineNone=[PluginManager] A pipeline was requested but the created pipeline evaluated to None.
@@ -357,8 +357,9 @@ ReadWriteSplittingPlugin.ErrorVerifyingInitialHostSpecRole=[ReadWriteSplittingPl
357357
ReadWriteSplittingPlugin.ExceptionWhileExecutingCommand=[ReadWriteSplittingPlugin] Detected an exception while executing a command: '{}'
358358
ReadWriteSplittingPlugin.ExecutingAgainstOldConnection=[ReadWriteSplittingPlugin] Executing method against old connection: '{}'
359359
ReadWriteSplittingPlugin.FailedToConnectToReader=[ReadWriteSplittingPlugin] Failed to connect to reader host: '{}'
360+
ReadWriteSplittingPlugin.FailedToConnectToWriter=[ReadWriteSplittingPlugin] Failed to connect to writer host: '{}'
360361
ReadWriteSplittingPlugin.FailoverExceptionWhileExecutingCommand=[ReadWriteSplittingPlugin] Detected a failover exception while executing a command: '{}'
361-
ReadWriteSplittingPlugin.FallbackToWriter=[ReadWriteSplittingPlugin] Failed to switch to a reader; the current writer will be used as a fallback: '{}'
362+
ReadWriteSplittingPlugin.FallbackToCurrentConnection=[ReadWriteSplittingPlugin] Failed to switch to a reader; the current connection will be used as a fallback: '{}'
362363
ReadWriteSplittingPlugin.NoReadersAvailable=[ReadWriteSplittingPlugin] The plugin was unable to establish a reader connection to any reader instance.
363364
ReadWriteSplittingPlugin.NoReadersFound=[ReadWriteSplittingPlugin] A reader instance was requested via set_read_only, but there are no readers in the host list. The current writer will be used as a fallback: '{}'
364365
ReadWriteSplittingPlugin.NoWriterFound=[ReadWriteSplittingPlugin] No writer was found in the current host list. This may occur if the writer is not in the list of allowed hosts.
@@ -382,6 +383,9 @@ RoundRobinHostSelector.RoundRobinInvalidHostWeightPairs= [RoundRobinHostSelector
382383
WeightedRandomHostSelector.WeightedRandomInvalidHostWeightPairs= [WeightedRandomHostSelector] The provided host weight pairs have not been configured correctly. Please ensure the provided host weight pairs is a comma separated list of pairs, each pair in the format of <host>:<weight>. Weight values must be an integer greater than or equal to the default weight value of 1. Weight pair: '{}'
383384
WeightedRandomHostSelector.WeightedRandomInvalidDefaultWeight=[WeightedRandomHostSelector] The provided default weight value is not valid. Weight values must be an integer greater than or equal to 1.
384385

386+
SimpleReadWriteSplittingPlugin.MissingRequiredConfigParameter=[SimpleReadWriteSplittingPlugin] Configuration parameter {} is required.
387+
SimpleReadWriteSplittingPlugin.IncorrectConfiguration=[SimpleReadWriteSplittingPlugin] Unable to verify connections with this current configuration. Ensure a correct value is provided to the configuration parameter {}.
388+
385389
SqlAlchemyPooledConnectionProvider.PoolNone=[SqlAlchemyPooledConnectionProvider] Attempted to find or create a pool for '{}' but the result of the attempt evaluated to None.
386390
SqlAlchemyPooledConnectionProvider.UnableToCreateDefaultKey=[SqlAlchemyPooledConnectionProvider] Unable to create a default key for internal connection pools. By default, the user parameter is used, but the given user evaluated to None or the empty string (""). Please ensure you have passed a valid user in the connection properties.
387391

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from time import perf_counter_ns, sleep
18+
from typing import TYPE_CHECKING, Callable, Optional, Type, TypeVar
19+
20+
from aws_advanced_python_wrapper.host_availability import HostAvailability
21+
from aws_advanced_python_wrapper.read_write_splitting_plugin import (
22+
ReadWriteConnectionHandler, ReadWriteSplittingConnectionManager)
23+
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
24+
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
25+
26+
if TYPE_CHECKING:
27+
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
28+
from aws_advanced_python_wrapper.host_list_provider import HostListProviderService
29+
from aws_advanced_python_wrapper.pep249 import Connection
30+
from aws_advanced_python_wrapper.plugin_service import PluginService
31+
from aws_advanced_python_wrapper.utils.properties import Properties, WrapperProperty
32+
33+
from aws_advanced_python_wrapper.errors import AwsWrapperError
34+
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
35+
from aws_advanced_python_wrapper.plugin import PluginFactory
36+
from aws_advanced_python_wrapper.utils.messages import Messages
37+
from aws_advanced_python_wrapper.utils.properties import WrapperProperties
38+
39+
40+
class EndpointBasedConnectionHandler(ReadWriteConnectionHandler):
41+
"""Endpoint based implementation of connection handling logic."""
42+
43+
def __init__(self, plugin_service: PluginService, props: Properties):
44+
read_endpoint: str = EndpointBasedConnectionHandler._verify_parameter(
45+
WrapperProperties.SRW_READ_ENDPOINT, props, str, required=True
46+
)
47+
write_endpoint: str = EndpointBasedConnectionHandler._verify_parameter(
48+
WrapperProperties.SRW_WRITE_ENDPOINT, props, str, required=True
49+
)
50+
51+
self._verify_new_connections: bool = EndpointBasedConnectionHandler._verify_parameter(
52+
WrapperProperties.SRW_VERIFY_NEW_CONNECTIONS, props, bool
53+
)
54+
55+
if self._verify_new_connections:
56+
self._connect_retry_timeout_ms: int = EndpointBasedConnectionHandler._verify_parameter(
57+
WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS, props, int, lambda x: x > 0
58+
)
59+
self._connect_retry_interval_ms: int = EndpointBasedConnectionHandler._verify_parameter(
60+
WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS, props, int, lambda x: x > 0
61+
)
62+
63+
self._verify_initial_connection_type: Optional[HostRole] = (
64+
EndpointBasedConnectionHandler._parse_role(
65+
WrapperProperties.SRW_VERIFY_INITIAL_CONNECTION_TYPE.get(props)
66+
)
67+
)
68+
69+
self._plugin_service: PluginService = plugin_service
70+
self._rds_utils: RdsUtils = RdsUtils()
71+
self._host_list_provider_service: Optional[HostListProviderService] = None
72+
self._write_endpoint_host_info: HostInfo = self._create_host_info(write_endpoint, HostRole.WRITER)
73+
self._read_endpoint_host_info: HostInfo = self._create_host_info(read_endpoint, HostRole.READER)
74+
self._write_endpoint = write_endpoint.casefold()
75+
self._read_endpoint = read_endpoint.casefold()
76+
77+
@property
78+
def host_list_provider_service(self) -> Optional[HostListProviderService]:
79+
return self._host_list_provider_service
80+
81+
@host_list_provider_service.setter
82+
def host_list_provider_service(self, new_value: HostListProviderService) -> None:
83+
self._host_list_provider_service = new_value
84+
85+
def open_new_writer_connection(
86+
self,
87+
plugin_service_connect_func: Callable[[HostInfo], Connection],
88+
) -> tuple[Optional[Connection], Optional[HostInfo]]:
89+
if self._verify_new_connections:
90+
return self._get_verified_connection(self._write_endpoint_host_info, HostRole.WRITER, plugin_service_connect_func), \
91+
self._write_endpoint_host_info
92+
93+
return plugin_service_connect_func(self._write_endpoint_host_info), self._write_endpoint_host_info
94+
95+
def open_new_reader_connection(
96+
self,
97+
plugin_service_connect_func: Callable[[HostInfo], Connection],
98+
) -> tuple[Optional[Connection], Optional[HostInfo]]:
99+
if self._verify_new_connections:
100+
return self._get_verified_connection(self._read_endpoint_host_info, HostRole.READER, plugin_service_connect_func), \
101+
self._read_endpoint_host_info
102+
103+
return plugin_service_connect_func(self._read_endpoint_host_info), self._read_endpoint_host_info
104+
105+
def get_verified_initial_connection(
106+
self,
107+
host_info: HostInfo,
108+
is_initial_connection: bool,
109+
plugin_service_connect_func: Callable[[HostInfo], Connection],
110+
connect_func: Callable,
111+
) -> Connection:
112+
if not is_initial_connection or not self._verify_new_connections:
113+
return connect_func()
114+
115+
url_type: RdsUrlType = self._rds_utils.identify_rds_type(host_info.host)
116+
117+
conn: Optional[Connection] = None
118+
119+
if (
120+
url_type == RdsUrlType.RDS_WRITER_CLUSTER
121+
or self._verify_initial_connection_type == HostRole.WRITER
122+
):
123+
conn = self._get_verified_connection(host_info, HostRole.WRITER, plugin_service_connect_func, connect_func)
124+
elif (
125+
url_type == RdsUrlType.RDS_READER_CLUSTER
126+
or self._verify_initial_connection_type == HostRole.READER
127+
):
128+
conn = self._get_verified_connection(host_info, HostRole.READER, plugin_service_connect_func, connect_func)
129+
130+
if conn is None:
131+
conn = connect_func()
132+
133+
self._set_initial_connection_host_info(host_info)
134+
return conn
135+
136+
def _set_initial_connection_host_info(self, host_info: HostInfo):
137+
if self._host_list_provider_service is None:
138+
return
139+
140+
self._host_list_provider_service.initial_connection_host_info = host_info
141+
142+
def _get_verified_connection(
143+
self,
144+
host_info: HostInfo,
145+
role: HostRole,
146+
plugin_service_connect_func: Callable[[HostInfo], Connection],
147+
connect_func: Optional[Callable] = None,
148+
) -> Optional[Connection]:
149+
end_time_nano = perf_counter_ns() + (self._connect_retry_timeout_ms * 1000000)
150+
151+
candidate_conn: Optional[Connection]
152+
153+
while perf_counter_ns() < end_time_nano:
154+
candidate_conn = None
155+
156+
try:
157+
if connect_func is not None:
158+
candidate_conn = connect_func()
159+
else:
160+
candidate_conn = plugin_service_connect_func(host_info)
161+
162+
if candidate_conn is None:
163+
self._delay()
164+
continue
165+
166+
actual_role = self._plugin_service.get_host_role(candidate_conn)
167+
168+
if actual_role != role:
169+
ReadWriteSplittingConnectionManager.close_connection(candidate_conn, self._plugin_service.driver_dialect)
170+
self._delay()
171+
continue
172+
173+
return candidate_conn
174+
175+
except Exception:
176+
ReadWriteSplittingConnectionManager.close_connection(candidate_conn, self._plugin_service.driver_dialect)
177+
self._delay()
178+
179+
return None
180+
181+
def can_host_be_used(self, host_info: HostInfo) -> bool:
182+
# Assume that the host can always be used, no topology-based information to check.
183+
return True
184+
185+
def has_no_readers(self) -> bool:
186+
# SetReadOnly(true) will always connect to the read_endpoint, regardless of number of readers.
187+
return False
188+
189+
def refresh_and_store_host_list(
190+
self, current_conn: Optional[Connection], driver_dialect: DriverDialect
191+
):
192+
# Endpoint based connections do not require a host list.
193+
return
194+
195+
def should_update_writer_with_current_conn(
196+
self, current_conn: Connection, current_host: HostInfo, writer_conn: Connection
197+
) -> bool:
198+
return (
199+
self.is_writer_host(current_host)
200+
and current_conn != writer_conn
201+
and (
202+
not self._verify_new_connections
203+
or self._plugin_service.get_host_role(current_conn) == HostRole.WRITER
204+
)
205+
)
206+
207+
def should_update_reader_with_current_conn(
208+
self, current_conn: Connection, current_host: HostInfo, reader_conn: Connection
209+
) -> bool:
210+
return (
211+
self.is_reader_host(current_host)
212+
and current_conn != reader_conn
213+
and (
214+
not self._verify_new_connections
215+
or self._plugin_service.get_host_role(current_conn) == HostRole.READER
216+
)
217+
)
218+
219+
def is_writer_host(self, current_host: HostInfo) -> bool:
220+
return (
221+
current_host.host.casefold() == self._write_endpoint
222+
or current_host.url.casefold() == self._write_endpoint
223+
)
224+
225+
def is_reader_host(self, current_host: HostInfo) -> bool:
226+
return (
227+
current_host.host.casefold() == self._read_endpoint
228+
or current_host.url.casefold() == self._read_endpoint
229+
)
230+
231+
def _create_host_info(self, endpoint: str, role: HostRole) -> HostInfo:
232+
endpoint = endpoint.strip()
233+
host = endpoint
234+
try:
235+
port = self._plugin_service.database_dialect.default_port if not self._plugin_service.current_host_info.is_port_specified() \
236+
else self._plugin_service.current_host_info.port
237+
except AwsWrapperError: # if current_host_info cannot be determined fallback to default port
238+
port = self._plugin_service.database_dialect.default_port
239+
colon_index = endpoint.rfind(":")
240+
241+
if colon_index != -1:
242+
host = endpoint[:colon_index]
243+
port_str = endpoint[colon_index + 1:]
244+
if port_str.isdigit():
245+
port = int(port_str)
246+
247+
return HostInfo(
248+
host=host, port=port, role=role, availability=HostAvailability.AVAILABLE
249+
)
250+
251+
T = TypeVar('T')
252+
253+
@staticmethod
254+
def _verify_parameter(prop: WrapperProperty, props: Properties, expected_type: Type[T], validator=None, required=False):
255+
value = prop.get_type(props, expected_type)
256+
if required:
257+
if value is None:
258+
raise AwsWrapperError(
259+
Messages.get_formatted(
260+
"SimpleReadWriteSplittingPlugin.MissingRequiredConfigParameter",
261+
prop.name,
262+
)
263+
)
264+
265+
if validator and not validator(value):
266+
raise ValueError(
267+
Messages.get_formatted(
268+
"SimpleReadWriteSplittingPlugin.IncorrectConfiguration",
269+
prop.name,
270+
)
271+
)
272+
return value
273+
274+
def _delay(self):
275+
sleep(self._connect_retry_interval_ms / 1000)
276+
277+
@staticmethod
278+
def _parse_role(role_str: Optional[str]) -> HostRole:
279+
if not role_str:
280+
return HostRole.UNKNOWN
281+
282+
phase_lower = role_str.lower()
283+
if phase_lower == "reader":
284+
return HostRole.READER
285+
elif phase_lower == "writer":
286+
return HostRole.WRITER
287+
else:
288+
raise ValueError(
289+
Messages.get_formatted(
290+
"SimpleReadWriteSplittingPlugin.IncorrectConfiguration",
291+
WrapperProperties.SRW_VERIFY_INITIAL_CONNECTION_TYPE.name,
292+
)
293+
)
294+
295+
296+
class SimpleReadWriteSplittingPlugin(ReadWriteSplittingConnectionManager):
297+
def __init__(self, plugin_service: PluginService, props: Properties):
298+
# The simple read/write splitting plugin handles connections based on configuration parameter endpoints.
299+
connection_handler = EndpointBasedConnectionHandler(plugin_service, props)
300+
301+
super().__init__(plugin_service, props, connection_handler)
302+
303+
304+
class SimpleReadWriteSplittingPluginFactory(PluginFactory):
305+
def get_instance(self, plugin_service, props: Properties):
306+
return SimpleReadWriteSplittingPlugin(plugin_service, props)

aws_advanced_python_wrapper/sql_alchemy_connection_provider.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,11 @@ def _create_sql_alchemy_pool(self, **kwargs):
173173

174174
def release_resources(self):
175175
for _, cache_item in SqlAlchemyPooledConnectionProvider._database_pools.items():
176-
cache_item.item.dispose()
176+
try:
177+
cache_item.item.dispose()
178+
except Exception:
179+
# Swallow exception, connections may already be dead
180+
pass
177181
SqlAlchemyPooledConnectionProvider._database_pools.clear()
178182

179183

0 commit comments

Comments
 (0)