Skip to content

Commit 6cd61c7

Browse files
aaron-congojonathanl-bq
authored andcommitted
Cleanup
1 parent 72457dc commit 6cd61c7

2 files changed

Lines changed: 68 additions & 62 deletions

File tree

aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,30 @@
33
from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
44
import re
55

6+
from aws_advanced_python_wrapper import AwsWrapperConnection
7+
8+
69
class SqlAlchemyOrmPgDialect(PGDialect_psycopg):
710
"""
8-
SQLAlchemy dialect for AWS Advanced Python Wrapper.
9-
Extends PostgreSQL psycopg dialect with Aurora-aware connection handling.
11+
SQLAlchemy dialect for AWS Advanced Python Wrapper with psycopg. Extends the SQLAlchemy PostgreSQL psycopg dialect.
12+
This dialect is not related to the DriverDialect or DatabaseDialect classes used by our driver. Instead, it is used
13+
directly by SQLAlchemy. This dialect is registered in pyproject.toml and is selected by prefixing the connection
14+
string passed to create_engine with "postgresql+aws_wrapper://" ("[name]+[driver]").
1015
"""
1116

1217
name = 'postgresql'
1318
driver = 'aws_wrapper'
1419

1520
def __init__(self, **kwargs):
16-
# Skip parent's version check since we're a wrapper, not psycopg itself
21+
# PGDialect_psycopg's __init__ function checks the driver version and raises an exception if it is lower than
22+
# 3.0.2. If we call it, the exception is raised because it mistakenly interprets our driver version as its own.
23+
# As a workaround we call the grandparent __init__ instead of the parent's __init__.
24+
# TODO: since we are calling the grandparent's __init__ instead of the parent's __init__, we should investigate
25+
# whether any important code in the parent's __init__ needs to be executed.
1726
super(PGDialect_psycopg, self).__init__(**kwargs)
1827

19-
# Dynamically detect the actual psycopg version we're wrapping to ensure
20-
# SQLAlchemy uses the correct feature set and SQL generation
28+
# Dynamically detect the actual psycopg version installed and set it as self.psycopg_version. Note that setting
29+
# this field before calling super().__init__ does not avoid the issue noted above.
2130
try:
2231
import psycopg
2332
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", psycopg.__version__)
@@ -26,8 +35,10 @@ def __init__(self, **kwargs):
2635
int(x) for x in m.group(1, 2, 3) if x is not None
2736
)
2837
else:
29-
self.psycopg_version = (3, 0, 2) # Minimum supported
38+
# Fallback to 3.0.2 if version parsing fails, which is the minimum required psycopg version.
39+
self.psycopg_version = (3, 0, 2)
3040
except (ImportError, AttributeError):
41+
# Fallback to 3.0.2 if version parsing fails, which is the minimum required psycopg version.
3142
self.psycopg_version = (3, 0, 2)
3243

3344
@classmethod
@@ -42,15 +53,15 @@ def import_dbapi(cls):
4253
def create_connect_args(self, url):
4354
"""
4455
Transform SQLAlchemy URL into connection arguments.
45-
Must include 'target' parameter for the wrapper.
56+
Must include the 'target' parameter for our wrapper driver.
4657
"""
4758
# Extract standard connection parameters
4859
opts = url.translate_connect_args(username='user')
4960

5061
# Add query string parameters
5162
opts.update(url.query)
5263

53-
# Add the required 'target' parameter for your wrapper
64+
# Add the required 'target' parameter for our wrapper
5465
if 'target' not in opts:
5566
opts['target'] = Connection.connect
5667

@@ -62,7 +73,6 @@ def on_connect(self):
6273
Return a callable that will be executed on new connections. This can be used if we need to set any session-level
6374
parameters.
6475
"""
65-
6676
def set_session_params(conn):
6777
# Set any Aurora-specific session parameters
6878
cursor = conn.cursor()
@@ -75,64 +85,64 @@ def set_session_params(conn):
7585
return set_session_params
7686

7787
def get_isolation_level(self, dbapi_connection):
78-
"""Get the current isolation level"""
7988
cursor = dbapi_connection.cursor()
8089
try:
8190
cursor.execute("SHOW transaction_isolation")
8291
val = cursor.fetchone()
8392
if val:
84-
# Extract first element from tuple and format
8593
return val.upper().replace(' ', '_')
86-
return 'READ_COMMITTED' # PostgreSQL's default
94+
return 'READ_COMMITTED' # return Postgres' default isolation level.
8795
finally:
8896
cursor.close()
8997

9098
def initialize(self, connection):
9199
"""
92100
Override initialization to handle type introspection.
93101
The parent class tries to use TypeInfo.fetch() which requires
94-
a native psycopg connection, not our wrapper.
102+
a native psycopg connection, not AwsWrapperConnection.
95103
"""
96-
# Find the AwsWrapperConnection at whatever nesting level
97-
wrapper_conn = self._get_wrapper_connection(connection)
98-
99-
if wrapper_conn and hasattr(wrapper_conn, 'connection'):
100-
# Get the underlying psycopg connection
101-
underlying_conn = wrapper_conn.connection
104+
# Unwrap SQLAlchemy's connection object
105+
wrapper_conn, wrapper_parent = self._get_wrapper_connection_and_parent(connection)
102106

103-
# Temporarily swap the entire connection chain
104-
original_dbapi_conn = connection.connection
105-
connection.connection = underlying_conn
107+
# Check if wrapper_conn and wrapper_parent expose their underlying connections
108+
if wrapper_conn and hasattr(wrapper_conn, 'connection') and wrapper_parent and hasattr(wrapper_parent.connection, 'connection'):
109+
# Temporarily remove the AwsWrapperConnection from the connection chain
110+
psycopg_conn = wrapper_conn.connection
111+
wrapper_parent.connection = psycopg_conn
106112

107113
try:
108-
# Call parent initialization with native psycopg connection
109114
super().initialize(connection)
110115
finally:
111-
# Restore original connection chain
112-
connection.connection = original_dbapi_conn
116+
# Restore wrapper connection in the connection chain.
117+
wrapper_parent.connection = wrapper_conn
113118
else:
114-
# If we can't find wrapper or it doesn't expose underlying connection,
115-
# skip type introspection (custom types won't be auto-configured)
119+
# If unable to swap underlying pscyopg connection, skip type introspection.
120+
# This means custom types (hstore, json, etc.) won't be auto-configured.
116121
pass
117122

118-
def _get_wrapper_connection(self, connection):
123+
def _get_wrapper_connection_and_parent(self, connection):
119124
"""
120-
Traverse the connection chain to find AwsWrapperConnection.
121-
Handles variable nesting depths depending on pool configuration.
122-
"""
123-
from aws_advanced_python_wrapper import AwsWrapperConnection
125+
Traverse the connection chain to find AwsWrapperConnection and its parent connection.
126+
127+
Args:
128+
connection: SQLAlchemy Connection object
124129
130+
Returns:
131+
AwsWrapperConnection instance or None, parent connection of AwsWrapperConnection or None
132+
"""
125133
# Start with the DBAPI connection
126-
current = connection.connection
134+
parent = connection
135+
child = connection.connection
127136

128137
# Traverse up to 5 levels deep (reasonable limit)
129138
for _ in range(5):
130-
if isinstance(current, AwsWrapperConnection):
131-
return current
139+
if isinstance(child, AwsWrapperConnection):
140+
return child, parent
132141

133142
# Try to go deeper if there's a .connection attribute
134-
if hasattr(current, 'connection'):
135-
current = current.connection
143+
if hasattr(child, 'connection'):
144+
parent = child
145+
child = child.connection
136146
else:
137147
break
138148

tests/unit/test_sqlalchemy_orm.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -39,28 +39,24 @@ class User(Base):
3939
Session = sessionmaker(bind=engine)
4040

4141
# Step 6: Use session for database operations
42-
session = Session()
43-
44-
# INSERT - Create new object and add to session
45-
new_user = User(name='John Doe', email='john@example.com')
46-
session.add(new_user)
47-
session.commit() # Explicit commit required
48-
49-
# SELECT - Query using session
50-
users = session.query(User).filter(User.name == 'John Doe').all()
51-
for user in users:
52-
print(f"{user.name}: {user.email}")
53-
54-
55-
# UPDATE - Modify object and commit
56-
user = session.query(User).filter(User.name == "John Doe").first()
57-
user.email = 'newemail@example.com'
58-
session.commit() # Changes tracked by session
59-
60-
# DELETE - Remove object from session
61-
user_to_delete = session.query(User).filter(User.name == "John Doe").first()
62-
session.delete(user_to_delete)
63-
session.commit()
64-
65-
# Always close session when done
66-
session.close()
42+
with Session() as session:
43+
# INSERT - Create new object and add to session
44+
new_user = User(name='John Doe', email='john@example.com')
45+
session.add(new_user)
46+
session.commit() # Explicit commit required
47+
48+
# SELECT - Query using session
49+
users = session.query(User).filter(User.name == 'John Doe').all()
50+
for user in users:
51+
print(f"{user.name}: {user.email}")
52+
53+
54+
# UPDATE - Modify object and commit
55+
user = session.query(User).filter(User.name == "John Doe").first()
56+
user.email = 'newemail@example.com'
57+
session.commit()
58+
59+
# DELETE - Remove object from session
60+
user_to_delete = session.query(User).filter(User.name == "John Doe").first()
61+
session.delete(user_to_delete)
62+
session.commit()

0 commit comments

Comments
 (0)