Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fixed #27222 -- Refreshed model field values assigned expressions on save().

Removed the can_return_columns_from_insert skip gates from existing field_defaults
tests to confirm that the expected number of queries is performed and that
returning field overrides are respected.
  • Loading branch information
charettes committed Jun 14, 2025
commit 4b0253eb2358c0e84e3a2c2feabeccb1ef78d2a7
22 changes: 20 additions & 2 deletions django/db/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,11 @@ def _save_table(
and f.referenced_fields.intersection(non_pks_non_generated)
)
]
for field, _model, value in values:
if (update_fields is None or field.name in update_fields) and hasattr(
value, "resolve_expression"
):
returning_fields.append(field)
results = self._do_update(
base_qs,
using,
Expand Down Expand Up @@ -1140,7 +1145,15 @@ def _save_table(
for f in meta.local_concrete_fields
if not f.generated and (pk_set or f is not meta.auto_field)
]
returning_fields = meta.db_returning_fields
returning_fields = list(meta.db_returning_fields)
for field in fields:
value = (
getattr(self, field.attname) if raw else field.pre_save(self, False)
)
if hasattr(value, "resolve_expression"):
returning_fields.append(field)
elif field.db_returning:
returning_fields.remove(field)
results = self._do_insert(
cls._base_manager, using, fields, returning_fields, raw
)
Expand Down Expand Up @@ -1201,8 +1214,13 @@ def _do_insert(self, manager, using, fields, returning_fields, raw):
)

def _assign_returned_values(self, returned_values, returning_fields):
    """
    Set the database-returned values on this instance, field by field.

    ``returned_values`` is matched positionally against ``returning_fields``.
    Fields left without a value (e.g. every field when the backend returned
    nothing) are deferred by dropping their cached attribute. This keeps the
    stale Python-side value, possibly an unresolved expression, from being
    exposed, and the field is refreshed on its next access.
    """
    returning_fields_iter = iter(returning_fields)
    # returned_values must be zip()'s first argument. zip() stops as soon as
    # it is exhausted, before pulling the next item from
    # returning_fields_iter, so no unmatched field is consumed (and thereby
    # skipped by the deferral loop below).
    for value, field in zip(returned_values, returning_fields_iter):
        setattr(self, field.attname, value)
    # Defer all fields that were meant to be updated with their database
    # resolved values but couldn't be, as they are effectively stale.
    for field in returning_fields_iter:
        self.__dict__.pop(field.attname, None)

def _prepare_related_fields_for_save(self, operation_name, fields=None):
# Ensure that a model instance without a PK hasn't been assigned to
Expand Down
39 changes: 17 additions & 22 deletions docs/ref/models/expressions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ Some examples

# Create a new company using expressions.
>>> company = Company.objects.create(name="Google", ticker=Upper(Value("goog")))
# Be sure to refresh it if you need to access the field.
>>> company.refresh_from_db()
>>> company.ticker
'GOOG'

Expand Down Expand Up @@ -157,12 +155,6 @@ know about it - it is dealt with entirely by the database. All Python does,
through Django's ``F()`` class, is create the SQL syntax to refer to the field
and describe the operation.

To access the new value saved this way, the object must be reloaded::

reporter = Reporters.objects.get(pk=reporter.pk)
# Or, more succinctly:
reporter.refresh_from_db()

As well as being used in operations on single instances as above, ``F()`` can
be used with ``update()`` to perform bulk updates on a ``QuerySet``. This
reduces the two queries we were using above - the ``get()`` and the
Expand Down Expand Up @@ -199,7 +191,6 @@ array-slicing syntax. The indices are 0-based and the ``step`` argument to
>>> writer = Writers.objects.get(name="Priyansh")
>>> writer.name = F("name")[1:5]
>>> writer.save()
>>> writer.refresh_from_db()
>>> writer.name
'riya'

Expand All @@ -221,23 +212,27 @@ robust: it will only ever update the field based on the value of the field in
the database when the :meth:`~Model.save()` or ``update()`` is executed, rather
than based on its value when the instance was retrieved.

``F()`` assignments persist after ``Model.save()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``F()`` assignments are refreshed after ``Model.save()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``F()`` objects assigned to model fields persist after saving the model
instance and will be applied on each :meth:`~Model.save()`. For example::
``F()`` objects assigned to model fields are refreshed from the database
on :meth:`~Model.save()`. On backends that support the ``RETURNING`` clause
(SQLite, PostgreSQL, and Oracle), this happens without an additional query;
on other backends (MySQL and MariaDB), the fields are deferred and reloaded
on next access. For example:

reporter = Reporters.objects.get(name="Tintin")
reporter.stories_filed = F("stories_filed") + 1
reporter.save()
.. code-block:: pycon

reporter.name = "Tintin Jr."
reporter.save()
>>> reporter = Reporters.objects.get(name="Tintin")
>>> reporter.stories_filed = F("stories_filed") + 1
>>> reporter.save()
>>> reporter.stories_filed # This triggers a refresh query on MySQL and MariaDB
14 # Assuming the database value was 13 when the object was saved.

.. versionchanged:: 6.0

``stories_filed`` will be updated twice in this case. If it's initially ``1``,
the final value will be ``3``. This persistence can be avoided by reloading the
model object after saving it, for example, by using
:meth:`~Model.refresh_from_db()`.
In previous versions of Django, ``F()`` objects were not refreshed from the
database on :meth:`~Model.save()`, which resulted in them being evaluated
and persisted again every time the instance was saved.

Using ``F()`` in filters
~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
7 changes: 7 additions & 0 deletions docs/releases/6.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,13 @@ Models
support it (MySQL and MariaDB), the fields are marked as deferred to trigger
a refresh on subsequent accesses.

* :class:`~django.db.models.GeneratedField` and :ref:`fields assigned
expressions <avoiding-race-conditions-using-f>` are now refreshed from the
database after :meth:`~django.db.models.Model.save` on backends that support
the ``RETURNING`` clause (SQLite, PostgreSQL, and Oracle). On backends that
don't support it (MySQL and MariaDB), the fields are marked as deferred to
trigger a refresh on subsequent accesses.

Pagination
~~~~~~~~~~

Expand Down
22 changes: 22 additions & 0 deletions tests/basic/tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import threading
import time
from datetime import datetime, timedelta
from unittest import mock

Expand All @@ -12,6 +13,7 @@
models,
transaction,
)
from django.db.models.functions import Now
from django.db.models.manager import BaseManager
from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet
from django.test import (
Expand Down Expand Up @@ -557,6 +559,26 @@ def new_instance():
with self.subTest(case=case):
self.assertIs(case._is_pk_set(), True)

def test_save_expressions(self):
    """save() refreshes (or defers) a field that was assigned Now()."""
    article = Article(pub_date=Now())
    article.save()
    # After the INSERT, reading pub_date is free when the backend returned
    # it and costs a single reload otherwise.
    if connection.features.can_return_columns_from_insert:
        insert_queries = 0
    else:
        insert_queries = 1
    with self.assertNumQueries(insert_queries):
        first_pub_date = article.pub_date
    self.assertIsInstance(first_pub_date, datetime)

    # Pause briefly so the database's NOW() yields a strictly later value
    # for the second save.
    time.sleep(0.1)
    article.pub_date = Now()
    article.save()
    # Same check after the UPDATE, gated on RETURNING support for updates.
    if connection.features.can_return_rows_from_update:
        update_queries = 0
    else:
        update_queries = 1
    with self.assertNumQueries(update_queries):
        self.assertIsInstance(article.pub_date, datetime)
    self.assertGreater(article.pub_date, first_pub_date)


class ModelLookupTest(TestCase):
@classmethod
Expand Down
14 changes: 10 additions & 4 deletions tests/expressions/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,11 @@ def test_object_update(self):
# F expressions can be used to update attributes on single objects
self.gmbh.num_employees = F("num_employees") + 4
self.gmbh.save()
self.gmbh.refresh_from_db()
self.assertEqual(self.gmbh.num_employees, 36)
expected_num_queries = (
0 if connection.features.can_return_rows_from_update else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(self.gmbh.num_employees, 36)

def test_new_object_save(self):
# We should be able to use Funcs when inserting new data
Expand Down Expand Up @@ -1596,8 +1599,11 @@ def test_decimal_expression(self):
n = Number.objects.create(integer=1, decimal_value=Decimal("0.5"))
n.decimal_value = F("decimal_value") - Decimal("0.4")
n.save()
n.refresh_from_db()
self.assertEqual(n.decimal_value, Decimal("0.1"))
expected_num_queries = (
0 if connection.features.can_return_rows_from_update else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(n.decimal_value, Decimal("0.1"))


class ExpressionOperatorTests(TestCase):
Expand Down
70 changes: 37 additions & 33 deletions tests/field_defaults/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,7 @@
)
from django.db.models.functions import Collate
from django.db.models.lookups import GreaterThan
from django.test import (
SimpleTestCase,
TestCase,
override_settings,
skipIfDBFeature,
skipUnlessDBFeature,
)
from django.test import SimpleTestCase, TestCase, override_settings, skipUnlessDBFeature
from django.utils import timezone

from .models import (
Expand All @@ -44,47 +38,56 @@ def test_field_defaults(self):
self.assertEqual(a.headline, "Default headline")
self.assertLess((now - a.pub_date).seconds, 5)

# NOTE(review): this scraped diff lost its +/- markers, so pre- and
# post-change lines are interleaved. The three-line decorator directly below
# appears to be the pre-change gate, superseded by the single-feature
# @skipUnlessDBFeature("supports_expression_defaults") that follows it.
@skipUnlessDBFeature(
"can_return_columns_from_insert", "supports_expression_defaults"
)
@skipUnlessDBFeature("supports_expression_defaults")
def test_field_db_defaults_returning(self):
# Save a DBArticle without explicit values; the assertions below check the
# default values it ends up with.
a = DBArticle()
a.save()
self.assertIsInstance(a.id, int)
# NOTE(review): these three unwrapped assertions appear to be the
# pre-change version of the assertNumQueries-wrapped ones further down.
self.assertEqual(a.headline, "Default headline")
self.assertIsInstance(a.pub_date, datetime)
self.assertEqual(a.cost, Decimal("3.33"))
# No query when the backend returns the values from the INSERT; otherwise 3,
# presumably one deferred reload per field read below (headline, pub_date,
# cost).
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 3
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(a.headline, "Default headline")
self.assertIsInstance(a.pub_date, datetime)
self.assertEqual(a.cost, Decimal("3.33"))

# NOTE(review): this scraped diff lost its +/- markers. Three things below
# appear to be pre-change lines:
#   - the @skipIfDBFeature decorator (the reworked django.test import in this
#     module no longer imports skipIfDBFeature);
#   - the refresh_from_db() call;
#   - the first, unwrapped trio of headline/pub_date/cost assertions.
@skipIfDBFeature("can_return_columns_from_insert")
@skipUnlessDBFeature("supports_expression_defaults")
def test_field_db_defaults_refresh(self):
a = DBArticle()
a.save()
a.refresh_from_db()
# No query when the backend returns the values from the INSERT; otherwise 3,
# presumably one deferred reload per field read below.
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 3
)
self.assertIsInstance(a.id, int)
self.assertEqual(a.headline, "Default headline")
self.assertIsInstance(a.pub_date, datetime)
self.assertEqual(a.cost, Decimal("3.33"))
# NOTE(review): without the refresh_from_db() call, the post-change body
# appears to match test_field_db_defaults_returning line for line; consider
# removing or renaming one of the two tests.
with self.assertNumQueries(expected_num_queries):
self.assertEqual(a.headline, "Default headline")
self.assertIsInstance(a.pub_date, datetime)
self.assertEqual(a.cost, Decimal("3.33"))

def test_null_db_default(self):
# obj1 is created without a value for `null`; 1.1 is expected to come from
# its database default.
obj1 = DBDefaults.objects.create()
# NOTE(review): this diff rendering lost its +/- markers. The conditional
# refresh and the unwrapped assertion below appear to be the pre-change
# version of the assertNumQueries block that follows.
if not connection.features.can_return_columns_from_insert:
obj1.refresh_from_db()
self.assertEqual(obj1.null, 1.1)
# Reading the value costs no query when the backend returns it from the
# INSERT, and a single reload otherwise.
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(obj1.null, 1.1)

# obj2 passes an explicit None, which must be read back without any query.
obj2 = DBDefaults.objects.create(null=None)
# NOTE(review): appears to be the pre-change form of the wrapped assertion
# below.
self.assertIsNone(obj2.null)
with self.assertNumQueries(0):
self.assertIsNone(obj2.null)

@skipUnlessDBFeature("supports_expression_defaults")
@override_settings(USE_TZ=True)
def test_db_default_function(self):
# Create a row without explicit values; the assertions below check the
# function-based values it ends up with.
m = DBDefaultsFunction.objects.create()
# NOTE(review): this diff rendering lost its +/- markers. The conditional
# refresh and the first, unwrapped block of four assertions appear to be the
# pre-change version of the assertNumQueries block further down.
if not connection.features.can_return_columns_from_insert:
m.refresh_from_db()
self.assertAlmostEqual(m.number, pi)
self.assertEqual(m.year, timezone.now().year)
self.assertAlmostEqual(m.added, pi + 4.5)
self.assertEqual(m.multiple_subfunctions, 4.5)
# No query when the backend returns the values from the INSERT; otherwise 4,
# presumably one deferred reload per field read below.
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 4
)
with self.assertNumQueries(expected_num_queries):
self.assertAlmostEqual(m.number, pi)
self.assertEqual(m.year, timezone.now().year)
self.assertAlmostEqual(m.added, pi + 4.5)
self.assertEqual(m.multiple_subfunctions, 4.5)

@skipUnlessDBFeature("insert_test_table_with_defaults")
def test_both_default(self):
Expand Down Expand Up @@ -125,14 +128,15 @@ def test_foreign_key_db_default(self):
child2 = DBDefaultsFK.objects.create(language_code=parent2)
self.assertEqual(child2.language_code, parent2)

# NOTE(review): this scraped diff lost its +/- markers. The three-line
# decorator directly below appears to be the pre-change gate, superseded by
# the single-feature @skipUnlessDBFeature("supports_expression_defaults")
# that follows it.
@skipUnlessDBFeature(
"can_return_columns_from_insert", "supports_expression_defaults"
)
@skipUnlessDBFeature("supports_expression_defaults")
def test_case_when_db_default_returning(self):
m = DBDefaultsFunction.objects.create()
# NOTE(review): appears to be the pre-change form of the wrapped assertion
# below.
self.assertEqual(m.case_when, 3)
# No query when case_when is returned from the INSERT; one reload otherwise.
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.case_when, 3)

@skipIfDBFeature("can_return_columns_from_insert")
@skipUnlessDBFeature("supports_expression_defaults")
def test_case_when_db_default_no_returning(self):
m = DBDefaultsFunction.objects.create()
Expand Down
16 changes: 15 additions & 1 deletion tests/update_only_fields/tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.core.exceptions import ObjectNotUpdated
from django.db import DatabaseError, transaction
from django.db import DatabaseError, connection, transaction
from django.db.models import F
from django.db.models.signals import post_save, pre_save
from django.test import TestCase

Expand Down Expand Up @@ -308,3 +309,16 @@ def test_update_fields_not_updated(self):
transaction.atomic(),
):
obj.save(update_fields=["name"])

def test_update_fields_expression(self):
    """An expression is only applied and refreshed when its field is saved."""
    person = Person.objects.create(name="Valerie", gender="F", pid=42)
    increment = F("pid") + 1
    person.pid = increment
    # pid is excluded from update_fields, so the pending expression must be
    # left on the instance as-is.
    person.save(update_fields={"gender"})
    self.assertIs(person.pid, increment)
    # Saving pid applies the expression; reading it back is free when the
    # UPDATE returned the new value, and costs one reload otherwise.
    person.save(update_fields={"pid"})
    if connection.features.can_return_rows_from_update:
        expected = 0
    else:
        expected = 1
    with self.assertNumQueries(expected):
        self.assertEqual(person.pid, 43)