Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 91 additions & 24 deletions django/db/backends/postgresql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,24 @@ def _get_varchar_column(data):
return "varchar(%(max_length)s)" % data


def ensure_timezone(connection, ops, timezone_name):
    """Make *connection*'s session time zone match *timezone_name*.

    The connection's current "TimeZone" parameter is read first; a SET is
    issued (via ``ops.set_time_zone_sql()``) only when *timezone_name* is
    truthy and differs from it.

    Returns True if the SET statement was executed (the caller may need to
    commit), False if nothing had to change.
    """
    current_zone = connection.info.parameter_status("TimeZone")
    must_switch = bool(timezone_name) and current_zone != timezone_name
    if not must_switch:
        return False
    with connection.cursor() as cur:
        cur.execute(ops.set_time_zone_sql(), [timezone_name])
    return True


def ensure_role(connection, ops, role_name):
    """Run ``SET ROLE <role_name>`` on *connection* when a role is given.

    The statement is composed with ``ops.compose_sql`` so the role name is
    passed as a parameter rather than interpolated by hand.

    Returns True if the SET ROLE statement was executed, False when
    *role_name* is falsy and nothing was done.
    """
    if not role_name:
        return False
    with connection.cursor() as cur:
        cur.execute(ops.compose_sql("SET ROLE %s", [role_name]))
    return True


class DatabaseWrapper(BaseDatabaseWrapper):
vendor = "postgresql"
display_name = "PostgreSQL"
Expand Down Expand Up @@ -179,6 +197,42 @@ class DatabaseWrapper(BaseDatabaseWrapper):
ops_class = DatabaseOperations
# PostgreSQL backend-specific attributes.
_named_cursor_idx = 0
# Maps a connection alias to its psycopg ConnectionPool, or to None when
# pooling is disabled for that alias (see the `pool` property). This is a
# class attribute, so the mapping is shared by every DatabaseWrapper instance
# instead of being created per instance.
_connection_pools = {}

@cached_property
def pool(self):
    """Return the shared psycopg ConnectionPool for this alias, or None.

    Pooling is controlled by ``OPTIONS["pool"]``:

    * absent, ``None`` or ``False`` -- pooling is disabled (returns None);
    * ``True`` -- a pool with ConnectionPool's default options;
    * a dict -- passed as keyword arguments to ``ConnectionPool``.

    Pools are stored in the class-level ``_connection_pools`` dict keyed by
    ``self.alias``, so all wrappers for the same alias share one pool.

    Raises ImproperlyConfigured if pooling is combined with a non-zero
    CONN_MAX_AGE, or if the psycopg_pool package cannot be imported.
    """
    if self.alias not in self._connection_pools:
        pool = None
        pool_options = self.settings_dict.get("OPTIONS", {}).get("pool", None)

        # ``False`` must disable pooling exactly like ``None``; otherwise it
        # would reach ``**pool_options`` below and fail with a TypeError.
        if pool_options is not None and pool_options is not False:
            # Pooling and persistent connections are mutually exclusive.
            if self.settings_dict.get("CONN_MAX_AGE", 0) != 0:
                raise ImproperlyConfigured(
                    "Pooling doesn't support persistent connections"
                )
            if pool_options is True:  # Use ConnectionPool's default options.
                pool_options = {}

            try:
                from psycopg_pool import ConnectionPool
            except ImportError as err:
                raise ImproperlyConfigured(
                    "Error loading psycopg_pool module.\n"
                    "Did you install psycopg[pool]?"
                ) from err

            connect_kwargs = self.get_connection_params()
            # Ensure we run in autocommit; Django sets it properly later on.
            connect_kwargs["autocommit"] = True
            pool = ConnectionPool(
                kwargs=connect_kwargs,
                open=False,  # Do not open the pool during startup.
                configure=self._configure_connection,
                **pool_options,
            )

        # NOTE: ``setdefault`` ensures that multiple threads don't set this
        # in parallel. Since the pool is not opened during its init above,
        # at worst several threads build pool objects during startup and the
        # first one stored wins.
        self._connection_pools.setdefault(self.alias, pool)

    return self._connection_pools[self.alias]

def get_database_version(self):
"""
Expand Down Expand Up @@ -223,6 +277,7 @@ def get_connection_params(self):

conn_params.pop("assume_role", None)
conn_params.pop("isolation_level", None)
conn_params.pop("pool", None)
server_side_binding = conn_params.pop("server_side_binding", None)
conn_params.setdefault(
"cursor_factory",
Expand Down Expand Up @@ -272,7 +327,12 @@ def get_new_connection(self, conn_params):
f"Invalid transaction isolation level {isolation_level_value} "
f"specified. Use one of the psycopg.IsolationLevel values."
)
connection = self.Database.connect(**conn_params)
if self.pool:
# If nothing else has opened the pool, open it now
self.pool.open()
connection = self.pool.getconn()
else:
connection = self.Database.connect(**conn_params)
if set_isolation_level:
connection.isolation_level = self.isolation_level
if not is_psycopg3:
Expand All @@ -287,36 +347,43 @@ def get_new_connection(self, conn_params):
def ensure_timezone(self):
# Set self.connection's session time zone to self.timezone_name if it
# differs; returns True when a SET was executed and False otherwise
# (including when there is no open connection).
# NOTE(review): this span is diff output with the +/- markers and all
# indentation stripped. The post-change method appears to be the
# `self.connection is None` guard below plus the final
# `return ensure_timezone(...)` line, which delegates to the module-level
# helper. The inline time-zone body and the whole ensure_role method in
# between look like removed pre-change lines; confirm against the diff.
if self.connection is None:
return False
conn_timezone_name = self.connection.info.parameter_status("TimeZone")
timezone_name = self.timezone_name
if timezone_name and conn_timezone_name != timezone_name:
with self.connection.cursor() as cursor:
cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
return True
return False

def ensure_role(self):
if self.connection is None:
return False
if new_role := self.settings_dict.get("OPTIONS", {}).get("assume_role"):
with self.connection.cursor() as cursor:
sql = self.ops.compose_sql("SET ROLE %s", [new_role])
cursor.execute(sql)
return True
return False
return ensure_timezone(self.connection, self.ops, self.timezone_name)

def init_connection_state(self):
super().init_connection_state()
# NOTE(review): this span is diff output with +/- markers and indentation
# stripped. The two lines above look like the removed pre-change header of
# init_connection_state (its post-change version appears after _close); the
# post-change _configure_connection() method starts on the next line.
def _configure_connection(self, connection):
# Apply per-session settings (time zone, then role) to `connection` and
# return True if either step executed a SET statement, so the caller knows
# a commit may be needed when not in autocommit.
# NOTE: This is called from init_connection_state for non-pool connections
# and by the psycopg pool itself (as its `configure` callback) after it
# opens a connection. Operate only on the `connection` argument: read plain
# attributes of self (ops, timezone_name, settings_dict) but do not use
# self.connection, which is not the connection being configured when the
# pool calls this.

# Commit after setting the time zone.
# NOTE(review): of each duplicated assignment below, the `self.ensure_*()`
# form looks like the removed pre-change line; the post-change line passes
# `connection` explicitly to the module-level helper.
commit_tz = self.ensure_timezone()
commit_tz = ensure_timezone(connection, self.ops, self.timezone_name)
# Set the role on the connection. This is useful if the credential used
# to log in is not the same as the role that owns database resources, as
# can be the case when using temporary or ephemeral credentials.
commit_role = self.ensure_role()
role_name = self.settings_dict.get("OPTIONS", {}).get("assume_role")
commit_role = ensure_role(connection, self.ops, role_name)

return commit_role or commit_tz

def _close(self):
    """Release the wrapper's underlying database connection.

    With pooling enabled the connection is handed back via
    ``pool.putconn()`` and the wrapper drops its reference to it. Without a
    pool the connection is closed and the result of ``close()`` is returned.
    Does nothing when there is no connection.
    """
    if self.connection is not None:
        # NOTE: `wrap_database_errors` only covers `putconn` as long as
        # there is no `reset` function set on the pool, because `reset` is
        # deferred to a worker thread instead of being executed here.
        with self.wrap_database_errors:
            if self.pool:
                self.pool.putconn(self.connection)
                # The connection now belongs to the pool and may be handed
                # to another thread, so the wrapper must not keep using it.
                # Clear it here rather than relying on the caller, which may
                # keep it set (e.g. when closing inside an atomic block).
                self.connection = None
            else:
                return self.connection.close()

def init_connection_state(self):
# Initialize per-session state for a freshly obtained connection, then
# apply time zone / role settings to non-pooled connections.
super().init_connection_state()

# Pooled connections are configured by the pool itself (the pool is
# created with configure=self._configure_connection), so only non-pooled
# connections are configured here.
if self.connection is not None and not self.pool:
commit = self._configure_connection(self.connection)

# NOTE(review): the next two lines appear to be the removed pre-change
# version of the commit check (diff markers and indentation stripped);
# the post-change check on `commit` follows them.
if (commit_role or commit_tz) and not self.get_autocommit():
self.connection.commit()
if commit and not self.get_autocommit():
self.connection.commit()

@async_unsafe
def create_cursor(self, name=None):
Expand Down
2 changes: 1 addition & 1 deletion tests/requirements/postgres.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
psycopg[binary]>=3.1.8
psycopg[binary,pool]>=3.1.8