Skip to content

Commit 72457dc

Browse files
aaron-congojonathanl-bq
authored andcommitted
Simple PG workflow working
1 parent 03888f1 commit 72457dc

4 files changed

Lines changed: 230 additions & 1 deletion

File tree

aws_advanced_python_wrapper/__init__.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,20 @@
1515
from logging import DEBUG, getLogger
1616

1717
from .cleanup import release_resources
18+
from .driver_info import DriverInfo
1819
from .utils.utils import LogUtils
1920
from .wrapper import AwsWrapperConnection
21+
from aws_advanced_python_wrapper.pep249 import (
22+
Error,
23+
InterfaceError,
24+
DatabaseError,
25+
DataError,
26+
OperationalError,
27+
IntegrityError,
28+
InternalError,
29+
ProgrammingError,
30+
NotSupportedError
31+
)
2032

2133
# PEP249 compliance
2234
connect = AwsWrapperConnection.connect
@@ -32,9 +44,19 @@
3244
'set_logger',
3345
'apilevel',
3446
'threadsafety',
35-
'paramstyle'
47+
'paramstyle',
48+
'Error',
49+
'InterfaceError',
50+
'DatabaseError',
51+
'DataError',
52+
'OperationalError',
53+
'IntegrityError',
54+
'InternalError',
55+
'ProgrammingError',
56+
'NotSupportedError'
3657
]
3758

59+
__version__ = DriverInfo.DRIVER_VERSION
3860

3961
def set_logger(name='aws_advanced_python_wrapper', level=DEBUG, format_string=None):
4062
LogUtils.setup_logger(getLogger(name), level, format_string)
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_psycopg_dialect.py
2+
from psycopg import Connection
3+
from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
4+
import re
5+
6+
class SqlAlchemyOrmPgDialect(PGDialect_psycopg):
7+
"""
8+
SQLAlchemy dialect for AWS Advanced Python Wrapper.
9+
Extends PostgreSQL psycopg dialect with Aurora-aware connection handling.
10+
"""
11+
12+
name = 'postgresql'
13+
driver = 'aws_wrapper'
14+
15+
def __init__(self, **kwargs):
16+
# Skip parent's version check since we're a wrapper, not psycopg itself
17+
super(PGDialect_psycopg, self).__init__(**kwargs)
18+
19+
# Dynamically detect the actual psycopg version we're wrapping to ensure
20+
# SQLAlchemy uses the correct feature set and SQL generation
21+
try:
22+
import psycopg
23+
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", psycopg.__version__)
24+
if m:
25+
self.psycopg_version = tuple(
26+
int(x) for x in m.group(1, 2, 3) if x is not None
27+
)
28+
else:
29+
self.psycopg_version = (3, 0, 2) # Minimum supported
30+
except (ImportError, AttributeError):
31+
self.psycopg_version = (3, 0, 2)
32+
33+
@classmethod
34+
def import_dbapi(cls):
35+
"""
36+
Return the DB-API 2.0 module.
37+
SQLAlchemy calls this to get the driver module.
38+
"""
39+
import aws_advanced_python_wrapper
40+
return aws_advanced_python_wrapper
41+
42+
def create_connect_args(self, url):
43+
"""
44+
Transform SQLAlchemy URL into connection arguments.
45+
Must include 'target' parameter for the wrapper.
46+
"""
47+
# Extract standard connection parameters
48+
opts = url.translate_connect_args(username='user')
49+
50+
# Add query string parameters
51+
opts.update(url.query)
52+
53+
# Add the required 'target' parameter for your wrapper
54+
if 'target' not in opts:
55+
opts['target'] = Connection.connect
56+
57+
# Return empty args list and kwargs dict
58+
return ([], opts)
59+
60+
def on_connect(self):
61+
"""
62+
Return a callable that will be executed on new connections. This can be used if we need to set any session-level
63+
parameters.
64+
"""
65+
66+
def set_session_params(conn):
67+
# Set any Aurora-specific session parameters
68+
cursor = conn.cursor()
69+
try:
70+
# Example: Set statement timeout
71+
cursor.execute("SET statement_timeout = '60s'")
72+
finally:
73+
cursor.close()
74+
75+
return set_session_params
76+
77+
def get_isolation_level(self, dbapi_connection):
78+
"""Get the current isolation level"""
79+
cursor = dbapi_connection.cursor()
80+
try:
81+
cursor.execute("SHOW transaction_isolation")
82+
val = cursor.fetchone()
83+
if val:
84+
# Extract first element from tuple and format
85+
return val.upper().replace(' ', '_')
86+
return 'READ_COMMITTED' # PostgreSQL's default
87+
finally:
88+
cursor.close()
89+
90+
def initialize(self, connection):
91+
"""
92+
Override initialization to handle type introspection.
93+
The parent class tries to use TypeInfo.fetch() which requires
94+
a native psycopg connection, not our wrapper.
95+
"""
96+
# Find the AwsWrapperConnection at whatever nesting level
97+
wrapper_conn = self._get_wrapper_connection(connection)
98+
99+
if wrapper_conn and hasattr(wrapper_conn, 'connection'):
100+
# Get the underlying psycopg connection
101+
underlying_conn = wrapper_conn.connection
102+
103+
# Temporarily swap the entire connection chain
104+
original_dbapi_conn = connection.connection
105+
connection.connection = underlying_conn
106+
107+
try:
108+
# Call parent initialization with native psycopg connection
109+
super().initialize(connection)
110+
finally:
111+
# Restore original connection chain
112+
connection.connection = original_dbapi_conn
113+
else:
114+
# If we can't find wrapper or it doesn't expose underlying connection,
115+
# skip type introspection (custom types won't be auto-configured)
116+
pass
117+
118+
def _get_wrapper_connection(self, connection):
119+
"""
120+
Traverse the connection chain to find AwsWrapperConnection.
121+
Handles variable nesting depths depending on pool configuration.
122+
"""
123+
from aws_advanced_python_wrapper import AwsWrapperConnection
124+
125+
# Start with the DBAPI connection
126+
current = connection.connection
127+
128+
# Traverse up to 5 levels deep (reasonable limit)
129+
for _ in range(5):
130+
if isinstance(current, AwsWrapperConnection):
131+
return current
132+
133+
# Try to go deeper if there's a .connection attribute
134+
if hasattr(current, 'connection'):
135+
current = current.connection
136+
else:
137+
break
138+
139+
return None

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,5 @@ filterwarnings = [
8484
'ignore:Exception during reset or similar:pytest.PytestUnhandledThreadExceptionWarning'
8585
]
8686

87+
[tool.poetry.plugins."sqlalchemy.dialects"]
88+
"postgresql.aws_wrapper" = "aws_advanced_python_wrapper.sqlalchemy.orm_dialect:SqlAlchemyOrmPgDialect"

tests/unit/test_sqlalchemy_orm.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from sqlalchemy import create_engine, Column, Integer, String
16+
from sqlalchemy.ext.declarative import declarative_base
17+
from sqlalchemy.orm import sessionmaker
18+
19+
class TestSqlAlchemyORM:
20+
def test_basic_workflow(self):
21+
# Step 1: Create engine (connection to database)
22+
engine = create_engine('postgresql+aws_wrapper://pguser:pgpassword@mydb.cluster-XYZ.us-west-1.rds.amazonaws.com:5432/somedb')
23+
24+
# Step 2: Define base class for declarative models
25+
Base = declarative_base()
26+
27+
# Step 3: Define model class (separate from database operations)
28+
class User(Base):
29+
__tablename__ = 'users'
30+
31+
id = Column(Integer, primary_key=True)
32+
name = Column(String(50))
33+
email = Column(String(100))
34+
35+
# Step 4: Create tables
36+
Base.metadata.create_all(engine)
37+
38+
# Step 5: Create session factory
39+
Session = sessionmaker(bind=engine)
40+
41+
# Step 6: Use session for database operations
42+
session = Session()
43+
44+
# INSERT - Create new object and add to session
45+
new_user = User(name='John Doe', email='john@example.com')
46+
session.add(new_user)
47+
session.commit() # Explicit commit required
48+
49+
# SELECT - Query using session
50+
users = session.query(User).filter(User.name == 'John Doe').all()
51+
for user in users:
52+
print(f"{user.name}: {user.email}")
53+
54+
55+
# UPDATE - Modify object and commit
56+
user = session.query(User).filter(User.name == "John Doe").first()
57+
user.email = 'newemail@example.com'
58+
session.commit() # Changes tracked by session
59+
60+
# DELETE - Remove object from session
61+
user_to_delete = session.query(User).filter(User.name == "John Doe").first()
62+
session.delete(user_to_delete)
63+
session.commit()
64+
65+
# Always close session when done
66+
session.close()

0 commit comments

Comments
 (0)