Skip to content

Commit 6d5216d

Browse files
authored
fix: properly handle nested errors (#1092)
1 parent c1e33f9 commit 6d5216d

29 files changed

Lines changed: 435 additions & 100 deletions

aws_advanced_python_wrapper/aws_secrets_manager_plugin.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from aws_advanced_python_wrapper.pep249 import Connection
3333
from aws_advanced_python_wrapper.plugin_service import PluginService
3434

35-
from aws_advanced_python_wrapper.errors import AwsWrapperError
35+
from aws_advanced_python_wrapper.errors import AwsConnectError, AwsWrapperError
3636
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
3737
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
3838
from aws_advanced_python_wrapper.utils.log import Logger
@@ -113,9 +113,12 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
113113
return connect_func()
114114

115115
except Exception as e:
116+
if self._plugin_service.is_network_exception(error=e):
117+
raise AwsConnectError(Messages.get_formatted("AwsSecretsManagerPlugin.ConnectException", e)) from e
118+
116119
if not self._plugin_service.is_login_exception(error=e) or secret_fetched:
117120
raise AwsWrapperError(
118-
Messages.get_formatted("AwsSecretsManagerPlugin.ConnectException", e)) from e
121+
Messages.get_formatted("AwsSecretsManagerPlugin.ConnectException", e), e) from e
119122

120123
secret_fetched = self._update_secret(host_info, props, token_expiration_ns=token_expiration_ns, force_refetch=True)
121124

@@ -126,8 +129,8 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
126129
except Exception as unhandled_error:
127130
raise AwsWrapperError(
128131
Messages.get_formatted("AwsSecretsManagerPlugin.UnhandledException",
129-
unhandled_error)) from unhandled_error
130-
raise AwsWrapperError(Messages.get_formatted("AwsSecretsManagerPlugin.FailedLogin", e)) from e
132+
unhandled_error), unhandled_error) from unhandled_error
133+
raise AwsWrapperError(Messages.get_formatted("AwsSecretsManagerPlugin.FailedLogin", e), e) from e
131134

132135
def _update_secret(self, host_info: HostInfo, props: Properties, token_expiration_ns: int, force_refetch: bool = False) -> bool:
133136
"""
@@ -154,19 +157,19 @@ def _update_secret(self, host_info: HostInfo, props: Properties, token_expiratio
154157
except (ClientError, AttributeError) as e:
155158
logger.debug("AwsSecretsManagerPlugin.FailedToFetchDbCredentials", e)
156159
raise AwsWrapperError(
157-
Messages.get_formatted("AwsSecretsManagerPlugin.FailedToFetchDbCredentials", e)) from e
160+
Messages.get_formatted("AwsSecretsManagerPlugin.FailedToFetchDbCredentials", e), e) from e
158161
except JSONDecodeError as e:
159162
logger.debug("AwsSecretsManagerPlugin.JsonDecodeError", e)
160163
raise AwsWrapperError(
161-
Messages.get_formatted("AwsSecretsManagerPlugin.JsonDecodeError", e))
162-
except EndpointConnectionError:
164+
Messages.get_formatted("AwsSecretsManagerPlugin.JsonDecodeError", e), e) from e
165+
except EndpointConnectionError as e:
163166
logger.debug("AwsSecretsManagerPlugin.EndpointOverrideInvalidConnection", endpoint)
164167
raise AwsWrapperError(
165-
Messages.get_formatted("AwsSecretsManagerPlugin.EndpointOverrideInvalidConnection", endpoint))
166-
except ValueError:
168+
Messages.get_formatted("AwsSecretsManagerPlugin.EndpointOverrideInvalidConnection", endpoint), e) from e
169+
except ValueError as e:
167170
logger.debug("AwsSecretsManagerPlugin.EndpointOverrideMisconfigured", endpoint)
168171
raise AwsWrapperError(
169-
Messages.get_formatted("AwsSecretsManagerPlugin.EndpointOverrideMisconfigured", endpoint))
172+
Messages.get_formatted("AwsSecretsManagerPlugin.EndpointOverrideMisconfigured", endpoint), e) from e
170173

171174
return fetched
172175
except Exception as ex:

aws_advanced_python_wrapper/custom_endpoint_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,8 @@ def _wait_for_info(self, monitor: CustomEndpointMonitor):
330330
while not has_info and perf_counter_ns() < wait_for_info_timeout_ns:
331331
sleep(0.1)
332332
has_info = monitor.has_custom_endpoint_info()
333-
except InterruptedError:
334-
raise AwsWrapperError(Messages.get_formatted("CustomEndpointPlugin.InterruptedThread", hostname))
333+
except InterruptedError as e:
334+
raise AwsWrapperError(Messages.get_formatted("CustomEndpointPlugin.InterruptedThread", hostname), e) from e
335335

336336
if not has_info:
337337
raise AwsWrapperError(

aws_advanced_python_wrapper/database_dialect.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ class DialectCode(Enum):
7575
def from_string(value: str) -> DialectCode:
7676
try:
7777
return DialectCode(value)
78-
except ValueError:
79-
raise AwsWrapperError(Messages.get_formatted("DialectCode.InvalidStringValue", value))
78+
except ValueError as e:
79+
raise AwsWrapperError(Messages.get_formatted("DialectCode.InvalidStringValue", value), e) from e
8080

8181

8282
class TargetDriverType(Enum):

aws_advanced_python_wrapper/errors.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,22 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Optional
16+
1517
from .pep249 import Error
1618

1719

1820
class AwsWrapperError(Error):
1921
__module__ = "aws_advanced_python_wrapper"
22+
driver_error: Optional[Exception]
23+
24+
def __init__(self, message: str = "", original_error: Optional[Exception] = None):
25+
super().__init__(message)
26+
# If wrapping another AwsWrapperError, preserve the original driver exception
27+
if isinstance(original_error, AwsWrapperError) and original_error.driver_error is not None:
28+
self.driver_error = original_error.driver_error
29+
else:
30+
self.driver_error = original_error
2031

2132

2233
class UnsupportedOperationError(AwsWrapperError):
@@ -45,3 +56,7 @@ class FailoverSuccessError(FailoverError):
4556

4657
class ReadWriteSplittingError(AwsWrapperError):
4758
__module__ = "aws_advanced_python_wrapper"
59+
60+
61+
class AwsConnectError(AwsWrapperError):
62+
__module__ = "aws_advanced_python_wrapper"

aws_advanced_python_wrapper/failover_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def execute(self, target: type, method_name: str, execute_func: Callable, *args:
175175

176176
self._pick_new_connection()
177177
self._last_exception = ex
178-
raise AwsWrapperError(Messages.get_formatted("FailoverPlugin.DetectedException", str(ex))) from ex
178+
raise AwsWrapperError(Messages.get_formatted("FailoverPlugin.DetectedException", str(ex)), ex) from ex
179179

180180
def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]):
181181
if not self._enable_failover_setting:

aws_advanced_python_wrapper/failover_v2_plugin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def _deal_with_original_exception(self, original_exception: Exception) -> None:
204204
self._pick_new_connection()
205205
self._last_exception_dealt_with = original_exception
206206

207-
raise AwsWrapperError(Messages.get_formatted("FailoverPlugin.DetectedException", str(original_exception))) \
207+
raise AwsWrapperError(Messages.get_formatted("FailoverPlugin.DetectedException", str(original_exception)), original_exception) \
208208
from original_exception
209209

210210
def _failover(self) -> None:
@@ -339,6 +339,7 @@ def _failover_writer(self) -> None:
339339
LogUtils.log_topology(updated_hosts)))
340340

341341
logger.info("FailoverPlugin.FoundWriterCandidate", writer_candidate)
342+
342343
allowed_hosts = self._plugin_service.hosts
343344
if not any(host.host == writer_candidate.host and host.port == writer_candidate.port
344345
for host in allowed_hosts):

aws_advanced_python_wrapper/federated_plugin.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
import requests
4141

42-
from aws_advanced_python_wrapper.errors import AwsWrapperError
42+
from aws_advanced_python_wrapper.errors import AwsConnectError, AwsWrapperError
4343
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
4444
from aws_advanced_python_wrapper.utils.log import Logger
4545
from aws_advanced_python_wrapper.utils.messages import Messages
@@ -114,8 +114,11 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
114114
try:
115115
return connect_func()
116116
except Exception as e:
117+
if self._plugin_service.is_network_exception(e):
118+
raise AwsConnectError(Messages.get_formatted("FederatedAuthPlugin.ConnectException", e)) from e
119+
117120
if token_info is None or token_info.is_expired() or not self._plugin_service.is_login_exception(e):
118-
raise e
121+
raise AwsWrapperError(Messages.get_formatted("FederatedAuthPlugin.ConnectException", e), e) from e
119122

120123
self._update_authentication_token(token_host_info, props, user, region, cache_key)
121124

@@ -124,7 +127,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
124127
except Exception as e:
125128
error_message = "FederatedAuthPlugin.UnhandledException"
126129
logger.debug(error_message, e)
127-
raise AwsWrapperError(Messages.get_formatted(error_message, e)) from e
130+
raise AwsWrapperError(Messages.get_formatted(error_message, e), e) from e
128131

129132
def force_connect(
130133
self,
@@ -211,7 +214,7 @@ def get_saml_assertion(self, props: Properties):
211214
except IOError as e:
212215
error_message = "FederatedAuthPlugin.UnhandledException"
213216
logger.debug(error_message, e)
214-
raise AwsWrapperError(Messages.get_formatted(error_message, e))
217+
raise AwsWrapperError(Messages.get_formatted(error_message, e), e) from e
215218

216219
def _get_sign_in_page_body(self, url: str, props: Properties) -> str:
217220
logger.debug("AdfsCredentialsProviderFactory.SignOnPageUrl", url)

aws_advanced_python_wrapper/host_list_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]
641641
cursor.execute(self._dialect.topology_query)
642642
return self._process_query_results(cursor)
643643
except ProgrammingError as e:
644-
raise AwsWrapperError(Messages.get("RdsHostListProvider.InvalidQuery")) from e
644+
raise AwsWrapperError(Messages.get("RdsHostListProvider.InvalidQuery"), e) from e
645645

646646
def _process_query_results(self, cursor: Cursor) -> Tuple[HostInfo, ...]:
647647
"""
@@ -707,7 +707,7 @@ def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]
707707
cursor.execute(self._dialect.topology_query)
708708
return self._process_multi_az_query_results(cursor, writer_id)
709709
except ProgrammingError as e:
710-
raise AwsWrapperError(Messages.get("RdsHostListProvider.InvalidQuery")) from e
710+
raise AwsWrapperError(Messages.get("RdsHostListProvider.InvalidQuery"), e) from e
711711

712712
def _process_multi_az_query_results(self, cursor: Cursor, writer_id: str) -> Tuple[HostInfo, ...]:
713713
hosts_dict = {}

aws_advanced_python_wrapper/host_selector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,9 @@ def _update_host_weight_map_from_string(self, props: Optional[Properties] = None
257257
raise AwsWrapperError(Messages.get_formatted(message, pair))
258258

259259
self._host_weight_map[host_name] = weight
260-
except ValueError:
260+
except ValueError as e:
261261
logger.error(message, pair)
262-
raise AwsWrapperError(Messages.get_formatted(message, pair))
262+
raise AwsWrapperError(Messages.get_formatted(message, pair), e) from e
263263

264264

265265
class HighestWeightHostSelector(HostSelector):

aws_advanced_python_wrapper/iam_plugin.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from datetime import datetime, timedelta
3232
from typing import Callable, Dict, Set
3333

34-
from aws_advanced_python_wrapper.errors import AwsWrapperError
34+
from aws_advanced_python_wrapper.errors import AwsConnectError, AwsWrapperError
3535
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
3636
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
3737
from aws_advanced_python_wrapper.utils.log import Logger
@@ -125,10 +125,12 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
125125
except Exception as e:
126126
logger.debug("IamAuthPlugin.ConnectException", e)
127127

128+
if self._plugin_service.is_network_exception(error=e):
129+
raise AwsConnectError(Messages.get_formatted("IamAuthPlugin.ConnectException", e)) from e
130+
128131
is_cached_token = (token_info is not None and not token_info.is_expired())
129132
if not self._plugin_service.is_login_exception(error=e) or not is_cached_token:
130-
raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.ConnectException", e)) from e
131-
133+
raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.ConnectException", e), e) from e
132134
# Login unsuccessful with cached token
133135
# Try to generate a new token and try to connect again
134136
token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec)
@@ -145,7 +147,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
145147
try:
146148
return connect_func()
147149
except Exception as e:
148-
raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.UnhandledException", e)) from e
150+
raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.UnhandledException", e), e) from e
149151

150152
def force_connect(
151153
self,

0 commit comments

Comments
 (0)