33from sqlalchemy .dialects .postgresql .psycopg import PGDialect_psycopg
44import re
55
6+ from aws_advanced_python_wrapper import AwsWrapperConnection
7+
8+
69class SqlAlchemyOrmPgDialect (PGDialect_psycopg ):
710 """
8- SQLAlchemy dialect for AWS Advanced Python Wrapper.
9- Extends PostgreSQL psycopg dialect with Aurora-aware connection handling.
11+ SQLAlchemy dialect for AWS Advanced Python Wrapper with psycopg. Extends the SQLAlchemy PostgreSQL psycopg dialect.
12+ This dialect is not related to the DriverDialect or DatabaseDialect classes used by our driver. Instead, it is used
13+ directly by SQLAlchemy. This dialect is registered in pyproject.toml and is selected by prefixing the connection
14+ string passed to create_engine with "postgresql+aws_wrapper://" ("[name]+[driver]").
1015 """
1116
1217 name = 'postgresql'
1318 driver = 'aws_wrapper'
1419
1520 def __init__ (self , ** kwargs ):
16- # Skip parent's version check since we're a wrapper, not psycopg itself
21+ # PGDialect_psycopg's __init__ function checks the driver version and raises an exception if it is lower than
22+ # 3.0.2. If we call it, the exception is raised because it mistakenly interprets our driver version as its own.
23+ # As a workaround we call the grandparent __init__ instead of the parent's __init__.
24+ # TODO: since we are calling the grandparent's __init__ instead of the parent's __init__, we should investigate
25+ # whether any important code in the parent's __init__ needs to be executed.
1726 super (PGDialect_psycopg , self ).__init__ (** kwargs )
1827
19- # Dynamically detect the actual psycopg version we're wrapping to ensure
20- # SQLAlchemy uses the correct feature set and SQL generation
28+ # Dynamically detect the actual psycopg version installed and set it as self.psycopg_version. Note that setting
29+ # this field before calling super().__init__ does not avoid the issue noted above.
2130 try :
2231 import psycopg
2332 m = re .match (r"(\d+)\.(\d+)(?:\.(\d+))?" , psycopg .__version__ )
@@ -26,8 +35,10 @@ def __init__(self, **kwargs):
2635 int (x ) for x in m .group (1 , 2 , 3 ) if x is not None
2736 )
2837 else :
29- self .psycopg_version = (3 , 0 , 2 ) # Minimum supported
38+ # Fallback to 3.0.2 if version parsing fails, which is the minimum required psycopg version.
39+ self .psycopg_version = (3 , 0 , 2 )
3040 except (ImportError , AttributeError ):
41+ # Fallback to 3.0.2 if version parsing fails, which is the minimum required psycopg version.
3142 self .psycopg_version = (3 , 0 , 2 )
3243
3344 @classmethod
@@ -42,15 +53,15 @@ def import_dbapi(cls):
4253 def create_connect_args (self , url ):
4354 """
4455 Transform SQLAlchemy URL into connection arguments.
45- Must include 'target' parameter for the wrapper.
56+ Must include the 'target' parameter for our wrapper driver .
4657 """
4758 # Extract standard connection parameters
4859 opts = url .translate_connect_args (username = 'user' )
4960
5061 # Add query string parameters
5162 opts .update (url .query )
5263
53- # Add the required 'target' parameter for your wrapper
64+ # Add the required 'target' parameter for our wrapper
5465 if 'target' not in opts :
5566 opts ['target' ] = Connection .connect
5667
@@ -62,7 +73,6 @@ def on_connect(self):
6273 Return a callable that will be executed on new connections. This can be used if we need to set any session-level
6374 parameters.
6475 """
65-
6676 def set_session_params (conn ):
6777 # Set any Aurora-specific session parameters
6878 cursor = conn .cursor ()
@@ -75,64 +85,64 @@ def set_session_params(conn):
7585 return set_session_params
7686
7787 def get_isolation_level (self , dbapi_connection ):
78- """Get the current isolation level"""
7988 cursor = dbapi_connection .cursor ()
8089 try :
8190 cursor .execute ("SHOW transaction_isolation" )
8291 val = cursor .fetchone ()
8392 if val :
84- # Extract first element from tuple and format
8593 return val .upper ().replace (' ' , '_' )
86- return 'READ_COMMITTED' # PostgreSQL's default
94+ return 'READ_COMMITTED' # return Postgres' default isolation level.
8795 finally :
8896 cursor .close ()
8997
9098 def initialize (self , connection ):
9199 """
92100 Override initialization to handle type introspection.
93101 The parent class tries to use TypeInfo.fetch() which requires
94- a native psycopg connection, not our wrapper .
102+ a native psycopg connection, not AwsWrapperConnection .
95103 """
96- # Find the AwsWrapperConnection at whatever nesting level
97- wrapper_conn = self ._get_wrapper_connection (connection )
98-
99- if wrapper_conn and hasattr (wrapper_conn , 'connection' ):
100- # Get the underlying psycopg connection
101- underlying_conn = wrapper_conn .connection
104+ # Unwrap SQLAlchemy's connection object
105+ wrapper_conn , wrapper_parent = self ._get_wrapper_connection_and_parent (connection )
102106
103- # Temporarily swap the entire connection chain
104- original_dbapi_conn = connection .connection
105- connection .connection = underlying_conn
107+ # Check if wrapper_conn and wrapper_parent expose their underlying connections
108+ if wrapper_conn and hasattr (wrapper_conn , 'connection' ) and wrapper_parent and hasattr (wrapper_parent .connection , 'connection' ):
109+ # Temporarily remove the AwsWrapperConnection from the connection chain
110+ psycopg_conn = wrapper_conn .connection
111+ wrapper_parent .connection = psycopg_conn
106112
107113 try :
108- # Call parent initialization with native psycopg connection
109114 super ().initialize (connection )
110115 finally :
111- # Restore original connection chain
112- connection .connection = original_dbapi_conn
116+ # Restore wrapper connection in the connection chain.
117+ wrapper_parent .connection = wrapper_conn
113118 else :
114- # If we can't find wrapper or it doesn't expose underlying connection,
115- # skip type introspection ( custom types won't be auto-configured)
119+ # If unable to swap underlying pscyopg connection, skip type introspection.
120+ # This means custom types (hstore, json, etc.) won't be auto-configured.
116121 pass
117122
118- def _get_wrapper_connection (self , connection ):
123+ def _get_wrapper_connection_and_parent (self , connection ):
119124 """
120- Traverse the connection chain to find AwsWrapperConnection.
121- Handles variable nesting depths depending on pool configuration.
122- """
123- from aws_advanced_python_wrapper import AwsWrapperConnection
125+ Traverse the connection chain to find AwsWrapperConnection and its parent connection .
126+
127+ Args:
128+ connection: SQLAlchemy Connection object
124129
130+ Returns:
131+ AwsWrapperConnection instance or None, parent connection of AwsWrapperConnection or None
132+ """
125133 # Start with the DBAPI connection
126- current = connection .connection
134+ parent = connection
135+ child = connection .connection
127136
128137 # Traverse up to 5 levels deep (reasonable limit)
129138 for _ in range (5 ):
130- if isinstance (current , AwsWrapperConnection ):
131- return current
139+ if isinstance (child , AwsWrapperConnection ):
140+ return child , parent
132141
133142 # Try to go deeper if there's a .connection attribute
134- if hasattr (current , 'connection' ):
135- current = current .connection
143+ if hasattr (child , 'connection' ):
144+ parent = child
145+ child = child .connection
136146 else :
137147 break
138148
0 commit comments