Skip to content

Commit 6f195d2

Browse files
authored
chore: performance optimization for auth plugins by caching clients and sessions (#1084)
1 parent 70f6931 commit 6f195d2

18 files changed

Lines changed: 1312 additions & 303 deletions

.github/workflows/main.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ jobs:
2323
markdown-link-check:
2424
runs-on: ubuntu-latest
2525
steps:
26-
- uses: actions/checkout@master
27-
- uses: gaurav-nelson/github-action-markdown-link-check@v1
26+
- uses: actions/checkout@v4
27+
- uses: tcort/github-action-markdown-link-check@v1
2828
with:
2929
use-quiet-mode: 'yes'
3030
folder-path: 'docs'
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from threading import Lock
18+
from typing import TYPE_CHECKING, Any, Optional
19+
20+
from boto3 import Session
21+
22+
from aws_advanced_python_wrapper.utils.properties import WrapperProperties
23+
24+
if TYPE_CHECKING:
25+
from aws_advanced_python_wrapper.hostinfo import HostInfo
26+
from aws_advanced_python_wrapper.utils.properties import Properties
27+
28+
29+
class AwsCredentialsManager:
30+
_lock = Lock()
31+
_sessions: dict[str, Session] = {}
32+
_clients: dict[str, Any] = {}
33+
34+
@staticmethod
35+
def get_session(host_info: HostInfo, props: Properties, region: str) -> Session:
36+
profile_name = WrapperProperties.AWS_PROFILE.get(props)
37+
host_key = f'{host_info.as_alias()}{region}{profile_name}'
38+
39+
with AwsCredentialsManager._lock:
40+
if host_key in AwsCredentialsManager._sessions:
41+
return AwsCredentialsManager._sessions[host_key]
42+
43+
# Initialize session outside of lock.
44+
session = Session(profile_name=profile_name, region_name=region) if profile_name else Session(region_name=region)
45+
46+
with AwsCredentialsManager._lock:
47+
if host_key not in AwsCredentialsManager._sessions:
48+
AwsCredentialsManager._sessions[host_key] = session
49+
return AwsCredentialsManager._sessions[host_key]
50+
51+
@staticmethod
52+
def get_client(service_name: str, session: Session, host: Optional[str], region: Optional[str], endpoint_url: Optional[str] = None):
53+
key = f'{host}{region}{service_name}{endpoint_url}'
54+
55+
with AwsCredentialsManager._lock:
56+
if key in AwsCredentialsManager._clients:
57+
return AwsCredentialsManager._clients[key]
58+
59+
# Initialize client outside of lock.
60+
if endpoint_url:
61+
client = session.client(service_name=service_name, endpoint_url=endpoint_url) # type: ignore[call-overload]
62+
else:
63+
client = session.client(service_name=service_name) # type: ignore[call-overload]
64+
65+
with AwsCredentialsManager._lock:
66+
if key not in AwsCredentialsManager._clients:
67+
AwsCredentialsManager._clients[key] = client
68+
return AwsCredentialsManager._clients[key]
69+
70+
@staticmethod
71+
def release_resources() -> None:
72+
with AwsCredentialsManager._lock:
73+
for key, client in AwsCredentialsManager._clients.items():
74+
client.close()
75+
AwsCredentialsManager._clients.clear()
76+
AwsCredentialsManager._sessions.clear()
77+
return None

aws_advanced_python_wrapper/aws_secrets_manager_plugin.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919
from types import SimpleNamespace
2020
from typing import TYPE_CHECKING, Callable, Optional, Set, Tuple
2121

22-
import boto3
2322
from botocore.exceptions import ClientError, EndpointConnectionError
2423

24+
from aws_advanced_python_wrapper.aws_credentials_manager import \
25+
AwsCredentialsManager
2526
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
2627

2728
if TYPE_CHECKING:
@@ -86,7 +87,7 @@ def connect(
8687
props: Properties,
8788
is_initial_connection: bool,
8889
connect_func: Callable) -> Connection:
89-
return self._connect(props, connect_func)
90+
return self._connect(host_info, props, connect_func)
9091

9192
def force_connect(
9293
self,
@@ -96,16 +97,16 @@ def force_connect(
9697
props: Properties,
9798
is_initial_connection: bool,
9899
force_connect_func: Callable) -> Connection:
99-
return self._connect(props, force_connect_func)
100+
return self._connect(host_info, props, force_connect_func)
100101

101-
def _connect(self, props: Properties, connect_func: Callable) -> Connection:
102+
def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callable) -> Connection:
102103
token_expiration_sec: int = WrapperProperties.SECRETS_MANAGER_EXPIRATION.get_int(props)
103104
# if value is less than 0, default to one year
104105
if token_expiration_sec < 0:
105106
token_expiration_sec = AwsSecretsManagerPlugin._ONE_YEAR_IN_SECONDS
106107
token_expiration_ns = token_expiration_sec * 1_000_000_000
107108

108-
secret_fetched: bool = self._update_secret(token_expiration_ns=token_expiration_ns)
109+
secret_fetched: bool = self._update_secret(host_info, props, token_expiration_ns=token_expiration_ns)
109110

110111
try:
111112
self._apply_secret_to_properties(props)
@@ -116,7 +117,7 @@ def _connect(self, props: Properties, connect_func: Callable) -> Connection:
116117
raise AwsWrapperError(
117118
Messages.get_formatted("AwsSecretsManagerPlugin.ConnectException", e)) from e
118119

119-
secret_fetched = self._update_secret(token_expiration_ns=token_expiration_ns, force_refetch=True)
120+
secret_fetched = self._update_secret(host_info, props, token_expiration_ns=token_expiration_ns, force_refetch=True)
120121

121122
if secret_fetched:
122123
try:
@@ -128,7 +129,7 @@ def _connect(self, props: Properties, connect_func: Callable) -> Connection:
128129
unhandled_error)) from unhandled_error
129130
raise AwsWrapperError(Messages.get_formatted("AwsSecretsManagerPlugin.FailedLogin", e)) from e
130131

131-
def _update_secret(self, token_expiration_ns: int, force_refetch: bool = False) -> bool:
132+
def _update_secret(self, host_info: HostInfo, props: Properties, token_expiration_ns: int, force_refetch: bool = False) -> bool:
132133
"""
133134
Called to update credentials from the cache, or from the AWS Secrets Manager service.
134135
:param token_expiration_ns: Expiration time in nanoseconds for secret stored in cache.
@@ -146,7 +147,7 @@ def _update_secret(self, token_expiration_ns: int, force_refetch: bool = False)
146147
endpoint = self._secret_key[2]
147148
if not self._secret or force_refetch:
148149
try:
149-
self._secret = self._fetch_latest_credentials()
150+
self._secret = self._fetch_latest_credentials(host_info, props)
150151
if self._secret:
151152
AwsSecretsManagerPlugin._secrets_cache.put(self._secret_key, self._secret, token_expiration_ns)
152153
fetched = True
@@ -177,26 +178,19 @@ def _update_secret(self, token_expiration_ns: int, force_refetch: bool = False)
177178
if context is not None:
178179
context.close_context()
179180

180-
def _fetch_latest_credentials(self):
181+
def _fetch_latest_credentials(self, host_info: HostInfo, props: Properties):
181182
"""
182183
Fetches the current credentials from AWS Secrets Manager service.
183184
184185
:return: a Secret object containing the credentials fetched from the AWS Secrets Manager service.
185186
"""
186-
session = self._session if self._session else boto3.Session()
187-
188-
client = session.client(
189-
'secretsmanager',
190-
region_name=self._secret_key[1],
191-
endpoint_url=self._secret_key[2],
192-
)
187+
session = AwsCredentialsManager.get_session(host_info, props, self._secret_key[1])
188+
client = AwsCredentialsManager.get_client("secretsmanager", session, host_info.host, self._secret_key[1], self._secret_key[2])
193189

194190
secret = client.get_secret_value(
195191
SecretId=self._secret_key[0],
196192
)
197193

198-
client.close()
199-
200194
return loads(secret.get("SecretString"), object_hook=lambda d: SimpleNamespace(**d))
201195

202196
def _apply_secret_to_properties(self, properties: Properties):

aws_advanced_python_wrapper/cleanup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from aws_advanced_python_wrapper.aurora_connection_tracker_plugin import \
1616
OpenedConnectionTracker
17+
from aws_advanced_python_wrapper.aws_credentials_manager import \
18+
AwsCredentialsManager
1719
from aws_advanced_python_wrapper.host_monitoring_plugin import \
1820
MonitoringThreadContainer
1921
from aws_advanced_python_wrapper.thread_pool_container import \
@@ -26,5 +28,6 @@ def release_resources() -> None:
2628
"""Release all global resources used by the wrapper."""
2729
MonitoringThreadContainer.clean_up()
2830
ThreadPoolContainer.release_resources()
31+
AwsCredentialsManager.release_resources()
2932
OpenedConnectionTracker.release_resources()
3033
SlidingExpirationCacheContainer.release_resources()

aws_advanced_python_wrapper/credentials_provider_factory.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,29 @@
1616

1717
from typing import TYPE_CHECKING, Dict, Optional, Protocol
1818

19-
import boto3
20-
2119
if TYPE_CHECKING:
20+
from aws_advanced_python_wrapper.hostinfo import HostInfo
2221
from aws_advanced_python_wrapper.utils.properties import Properties
2322

2423
from abc import abstractmethod
2524

25+
from aws_advanced_python_wrapper.aws_credentials_manager import \
26+
AwsCredentialsManager
2627
from aws_advanced_python_wrapper.utils.properties import WrapperProperties
2728

2829

2930
class CredentialsProviderFactory(Protocol):
3031
@abstractmethod
31-
def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
32+
def get_aws_credentials(self, region: str, props: Properties, host_info: HostInfo) -> Optional[Dict[str, str]]:
3233
...
3334

3435

3536
class SamlCredentialsProviderFactory(CredentialsProviderFactory):
3637

37-
def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
38+
def get_aws_credentials(self, region: str, props: Properties, host_info: HostInfo) -> Optional[Dict[str, str]]:
3839
saml_assertion: str = self.get_saml_assertion(props)
39-
session = boto3.Session()
40-
41-
sts_client = session.client(
42-
'sts',
43-
region_name=region
44-
)
40+
session = AwsCredentialsManager.get_session(host_info, props, region)
41+
sts_client = AwsCredentialsManager.get_client("sts", session, host_info.host, region)
4542

4643
response: Dict[str, Dict[str, str]] = sts_client.assume_role_with_saml(
4744
RoleArn=WrapperProperties.IAM_ROLE_ARN.get(props),

aws_advanced_python_wrapper/federated_plugin.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,21 @@
1414

1515
from __future__ import annotations
1616

17+
from copy import deepcopy
1718
from html import unescape
1819
from re import DOTALL, findall, search
1920
from typing import TYPE_CHECKING, List
2021
from urllib.parse import urlencode
2122

23+
from aws_advanced_python_wrapper.aws_credentials_manager import \
24+
AwsCredentialsManager
2225
from aws_advanced_python_wrapper.credentials_provider_factory import (
2326
CredentialsProviderFactory, SamlCredentialsProviderFactory)
2427
from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo
2528
from aws_advanced_python_wrapper.utils.region_utils import RegionUtils
2629
from aws_advanced_python_wrapper.utils.saml_utils import SamlUtils
2730

2831
if TYPE_CHECKING:
29-
from boto3 import Session
3032
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
3133
from aws_advanced_python_wrapper.hostinfo import HostInfo
3234
from aws_advanced_python_wrapper.pep249 import Connection
@@ -55,10 +57,9 @@ class FederatedAuthPlugin(Plugin):
5557
_rds_utils: RdsUtils = RdsUtils()
5658
_token_cache: Dict[str, TokenInfo] = {}
5759

58-
def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None):
60+
def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory):
5961
self._plugin_service = plugin_service
6062
self._credentials_provider_factory = credentials_provider_factory
61-
self._session = session
6263

6364
self._region_utils = RegionUtils()
6465
telemetry_factory = self._plugin_service.get_telemetry_factory()
@@ -100,11 +101,13 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
100101

101102
token_info: Optional[TokenInfo] = FederatedAuthPlugin._token_cache.get(cache_key)
102103

104+
token_host_info = deepcopy(host_info)
105+
token_host_info.host = host
103106
if token_info is not None and not token_info.is_expired():
104107
logger.debug("FederatedAuthPlugin.UseCachedToken", token_info.token)
105108
self._plugin_service.driver_dialect.set_password(props, token_info.token)
106109
else:
107-
self._update_authentication_token(host_info, props, user, region, cache_key)
110+
self._update_authentication_token(token_host_info, props, user, region, cache_key)
108111

109112
WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props))
110113

@@ -114,7 +117,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
114117
if token_info is None or token_info.is_expired() or not self._plugin_service.is_login_exception(e):
115118
raise e
116119

117-
self._update_authentication_token(host_info, props, user, region, cache_key)
120+
self._update_authentication_token(token_host_info, props, user, region, cache_key)
118121

119122
try:
120123
return connect_func()
@@ -142,18 +145,19 @@ def _update_authentication_token(self,
142145
token_expiration_sec: int = WrapperProperties.IAM_TOKEN_EXPIRATION.get_int(props)
143146
token_expiry: datetime = datetime.now() + timedelta(seconds=token_expiration_sec)
144147
port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port)
145-
credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props)
148+
credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props, host_info)
146149

147150
if self._fetch_token_counter is not None:
148151
self._fetch_token_counter.inc()
152+
session = AwsCredentialsManager.get_session(host_info, props, region)
149153
token: str = IamAuthUtils.generate_authentication_token(
150154
self._plugin_service,
151155
user,
152156
host_info.host,
153157
port,
154158
region,
155-
credentials,
156-
self._session)
159+
session,
160+
credentials)
157161
WrapperProperties.PASSWORD.set(props, token)
158162
FederatedAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)
159163

aws_advanced_python_wrapper/iam_plugin.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,22 @@
1414

1515
from __future__ import annotations
1616

17+
from copy import deepcopy
1718
from typing import TYPE_CHECKING
1819

20+
from aws_advanced_python_wrapper.aws_credentials_manager import \
21+
AwsCredentialsManager
1922
from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo
2023
from aws_advanced_python_wrapper.utils.region_utils import RegionUtils
2124

2225
if TYPE_CHECKING:
23-
from boto3 import Session
2426
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
2527
from aws_advanced_python_wrapper.hostinfo import HostInfo
2628
from aws_advanced_python_wrapper.pep249 import Connection
2729
from aws_advanced_python_wrapper.plugin_service import PluginService
2830

2931
from datetime import datetime, timedelta
30-
from typing import Callable, Dict, Optional, Set
32+
from typing import Callable, Dict, Set
3133

3234
from aws_advanced_python_wrapper.errors import AwsWrapperError
3335
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
@@ -49,9 +51,8 @@ class IamAuthPlugin(Plugin):
4951
_rds_utils: RdsUtils = RdsUtils()
5052
_token_cache: Dict[str, TokenInfo] = {}
5153

52-
def __init__(self, plugin_service: PluginService, session: Optional[Session] = None):
54+
def __init__(self, plugin_service: PluginService):
5355
self._plugin_service = plugin_service
54-
self._session = session
5556

5657
self._region_utils = RegionUtils()
5758
telemetry_factory = self._plugin_service.get_telemetry_factory()
@@ -104,7 +105,17 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
104105
token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec)
105106
if self._fetch_token_counter is not None:
106107
self._fetch_token_counter.inc()
107-
token: str = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session)
108+
109+
session_host_info = deepcopy(host_info)
110+
session_host_info.host = host
111+
session = AwsCredentialsManager.get_session(host_info, props, region)
112+
token: str = IamAuthUtils.generate_authentication_token(
113+
self._plugin_service,
114+
user,
115+
host,
116+
port,
117+
region,
118+
session)
108119
self._plugin_service.driver_dialect.set_password(props, token)
109120
IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)
110121

@@ -123,7 +134,11 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
123134
token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec)
124135
if self._fetch_token_counter is not None:
125136
self._fetch_token_counter.inc()
126-
token = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session)
137+
138+
session_host_info = deepcopy(host_info)
139+
session_host_info.host = host
140+
session = AwsCredentialsManager.get_session(session_host_info, props, region)
141+
token = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, session)
127142
self._plugin_service.driver_dialect.set_password(props, token)
128143
IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)
129144

0 commit comments

Comments
 (0)