Skip to content

Commit 216b956

Browse files
fix: sliding expiration cache concurrent access exceptions (#1089)
1 parent 26cfe23 commit 216b956

3 files changed

Lines changed: 32 additions & 42 deletions

File tree

aws_advanced_python_wrapper/utils/concurrent.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,9 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import TYPE_CHECKING, Dict, Iterator, Set, Union, ValuesView
18-
19-
if TYPE_CHECKING:
20-
from typing import ItemsView
21-
2217
from threading import Condition, Lock, RLock
23-
from typing import Callable, Generic, KeysView, List, Optional, TypeVar
18+
from typing import (Callable, Dict, Generic, Iterator, List, Optional, Set,
19+
TypeVar, Union)
2420

2521
K = TypeVar('K')
2622
V = TypeVar('V')
@@ -111,14 +107,20 @@ def apply_if(self, predicate: Callable, apply: Callable):
111107
if predicate(key, value):
112108
apply(key, value)
113109

114-
def keys(self) -> KeysView:
115-
return self._dict.keys()
110+
def keys(self) -> List[K]:
111+
"""Returns a thread-safe snapshot of keys."""
112+
with self._lock:
113+
return list(self._dict.keys())
116114

117-
def values(self) -> ValuesView:
118-
return self._dict.values()
115+
def values(self) -> List[V]:
116+
"""Returns a thread-safe snapshot of values."""
117+
with self._lock:
118+
return list(self._dict.values())
119119

120-
def items(self) -> ItemsView:
121-
return self._dict.items()
120+
def items(self) -> List[tuple[K, V]]:
121+
"""Returns a thread-safe snapshot of items."""
122+
with self._lock:
123+
return list(self._dict.items())
122124

123125

124126
class ConcurrentSet(Generic[V]):

aws_advanced_python_wrapper/utils/sliding_expiration_cache.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from threading import Thread
1818
from time import perf_counter_ns, sleep
19-
from typing import Callable, Generic, ItemsView, KeysView, Optional, TypeVar
19+
from typing import Callable, Generic, List, Optional, Tuple, TypeVar
2020

2121
from aws_advanced_python_wrapper.utils.atomic import AtomicInt
2222
from aws_advanced_python_wrapper.utils.concurrent import ConcurrentDict
@@ -46,10 +46,10 @@ def __len__(self):
4646
def set_cleanup_interval_ns(self, interval_ns):
4747
self._cleanup_interval_ns = interval_ns
4848

49-
def keys(self) -> KeysView:
49+
def keys(self) -> List[K]:
5050
return self._cdict.keys()
5151

52-
def items(self) -> ItemsView:
52+
def items(self) -> List[Tuple[K, CacheItem[V]]]:
5353
return self._cdict.items()
5454

5555
def compute_if_absent(self, key: K, mapping_func: Callable, item_expiration_ns: int) -> Optional[V]:
@@ -73,32 +73,28 @@ def _remove_and_dispose(self, key: K):
7373
self._item_disposal_func(cache_item.item)
7474

7575
def _remove_if_expired(self, key: K):
76-
item = None
77-
7876
def _remove_if_expired_internal(_, cache_item):
7977
if self._should_cleanup_item(cache_item):
80-
nonlocal item
81-
item = cache_item.item
78+
# Dispose while holding the lock to prevent race conditions
79+
if self._item_disposal_func is not None:
80+
self._item_disposal_func(cache_item.item)
8281
return None
83-
8482
return cache_item
8583

8684
self._cdict.compute_if_present(key, _remove_if_expired_internal)
8785

88-
if item is None or self._item_disposal_func is None:
89-
return
90-
91-
self._item_disposal_func(item)
92-
9386
def _should_cleanup_item(self, cache_item: CacheItem) -> bool:
9487
if self._should_dispose_func is not None:
9588
return perf_counter_ns() > cache_item.expiration_time and self._should_dispose_func(cache_item.item)
9689
return perf_counter_ns() > cache_item.expiration_time
9790

9891
def clear(self):
99-
for _, cache_item in self._cdict.items():
100-
if cache_item is not None and self._item_disposal_func is not None:
101-
self._item_disposal_func(cache_item.item)
92+
# Dispose all items while holding the lock
93+
if self._item_disposal_func is not None:
94+
self._cdict.apply_if(
95+
lambda k, v: True, # Apply to all items
96+
lambda k, cache_item: self._item_disposal_func(cache_item.item)
97+
)
10298
self._cdict.clear()
10399

104100
def _cleanup(self):
@@ -107,7 +103,7 @@ def _cleanup(self):
107103
return
108104

109105
self._cleanup_time_ns.set(current_time + self._cleanup_interval_ns)
110-
keys = [key for key, _ in self._cdict.items()]
106+
keys = self._cdict.keys()
111107
for key in keys:
112108
self._remove_if_expired(key)
113109

@@ -129,29 +125,21 @@ def compute_if_absent_with_disposal(self, key: K, mapping_func: Callable, item_e
129125
return None if cache_item is None else cache_item.update_expiration(item_expiration_ns).item
130126

131127
def _remove_if_disposable(self, key: K):
132-
item = None
133-
134128
def _remove_if_disposable_internal(_, cache_item):
135129
if self._should_dispose_func is not None and self._should_dispose_func(cache_item.item):
136-
nonlocal item
137-
item = cache_item.item
130+
if self._item_disposal_func is not None:
131+
self._item_disposal_func(cache_item.item)
138132
return None
139-
140133
return cache_item
141134

142135
self._cdict.compute_if_present(key, _remove_if_disposable_internal)
143136

144-
if item is None or self._item_disposal_func is None:
145-
return
146-
147-
self._item_disposal_func(item)
148-
149137
def _cleanup_thread_internal(self):
150138
while True:
151139
try:
152140
sleep(self._cleanup_interval_ns / 1_000_000_000)
153141
self._cleanup_time_ns.set(perf_counter_ns() + self._cleanup_interval_ns)
154-
keys = [key for key, _ in self._cdict.items()]
142+
keys = self._cdict.keys()
155143
for key in keys:
156144
try:
157145
self._remove_if_expired(key)

tests/unit/test_sql_alchemy_pooled_connection_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def clear_cache():
6565

6666
def test_connect__default_mapping__default_pool_configuration(provider, host_info, mocker, mock_conn, mock_pool):
6767
expected_urls = {host_info.url}
68-
expected_keys = {PoolKey(host_info.url, "user1")}
68+
expected_keys = [PoolKey(host_info.url, "user1")]
6969
props = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "password"})
7070

7171
conn = provider.connect(mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock(), host_info, props)
@@ -76,7 +76,7 @@ def test_connect__default_mapping__default_pool_configuration(provider, host_inf
7676

7777

7878
def test_connect__custom_configuration_and_mapping(host_info, mocker, mock_conn, mock_pool):
79-
expected_keys = {PoolKey(host_info.url, f"{host_info.url}+some_unique_key")}
79+
expected_keys = [PoolKey(host_info.url, f"{host_info.url}+some_unique_key")]
8080
props = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "password"})
8181
attempt_creator_override_func = mocker.MagicMock()
8282

0 commit comments

Comments
 (0)