Skip to content

Commit 13b0d30

Browse files
charettes authored and felixxm committed
Fixed #27222 -- Refreshed model field values assigned expressions on save().
Removed the can_return_columns_from_insert skip gates on existing field_defaults tests to confirm that the expected number of queries is performed and that returning field overrides are respected.
1 parent 92ef8f2 commit 13b0d30

File tree

7 files changed

+123
-63
lines changed

7 files changed

+123
-63
lines changed

django/db/models/base.py

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1102,6 +1102,11 @@ def _save_table(
11021102
and f.referenced_fields.intersection(non_pks_non_generated)
11031103
)
11041104
]
1105+
for field, _model, value in values:
1106+
if (update_fields is None or field.name in update_fields) and hasattr(
1107+
value, "resolve_expression"
1108+
):
1109+
returning_fields.append(field)
11051110
results = self._do_update(
11061111
base_qs,
11071112
using,
@@ -1142,7 +1147,15 @@ def _save_table(
11421147
for f in meta.local_concrete_fields
11431148
if not f.generated and (pk_set or f is not meta.auto_field)
11441149
]
1145-
returning_fields = meta.db_returning_fields
1150+
returning_fields = list(meta.db_returning_fields)
1151+
for field in fields:
1152+
value = (
1153+
getattr(self, field.attname) if raw else field.pre_save(self, False)
1154+
)
1155+
if hasattr(value, "resolve_expression"):
1156+
returning_fields.append(field)
1157+
elif field.db_returning:
1158+
returning_fields.remove(field)
11461159
results = self._do_insert(
11471160
cls._base_manager, using, fields, returning_fields, raw
11481161
)
@@ -1203,8 +1216,13 @@ def _do_insert(self, manager, using, fields, returning_fields, raw):
12031216
)
12041217

12051218
def _assign_returned_values(self, returned_values, returning_fields):
1206-
for value, field in zip(returned_values, returning_fields):
1219+
returning_fields_iter = iter(returning_fields)
1220+
for value, field in zip(returned_values, returning_fields_iter):
12071221
setattr(self, field.attname, value)
1222+
# Defer all fields that were meant to be updated with their database
1223+
# resolved values but couldn't as they are effectively stale.
1224+
for field in returning_fields_iter:
1225+
self.__dict__.pop(field.attname, None)
12081226

12091227
def _prepare_related_fields_for_save(self, operation_name, fields=None):
12101228
# Ensure that a model instance without a PK hasn't been assigned to

docs/ref/models/expressions.txt

Lines changed: 17 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -69,8 +69,6 @@ Some examples
6969

7070
# Create a new company using expressions.
7171
>>> company = Company.objects.create(name="Google", ticker=Upper(Value("goog")))
72-
# Be sure to refresh it if you need to access the field.
73-
>>> company.refresh_from_db()
7472
>>> company.ticker
7573
'GOOG'
7674

@@ -157,12 +155,6 @@ know about it - it is dealt with entirely by the database. All Python does,
157155
through Django's ``F()`` class, is create the SQL syntax to refer to the field
158156
and describe the operation.
159157

160-
To access the new value saved this way, the object must be reloaded::
161-
162-
reporter = Reporters.objects.get(pk=reporter.pk)
163-
# Or, more succinctly:
164-
reporter.refresh_from_db()
165-
166158
As well as being used in operations on single instances as above, ``F()`` can
167159
be used with ``update()`` to perform bulk updates on a ``QuerySet``. This
168160
reduces the two queries we were using above - the ``get()`` and the
@@ -199,7 +191,6 @@ array-slicing syntax. The indices are 0-based and the ``step`` argument to
199191
>>> writer = Writers.objects.get(name="Priyansh")
200192
>>> writer.name = F("name")[1:5]
201193
>>> writer.save()
202-
>>> writer.refresh_from_db()
203194
>>> writer.name
204195
'riya'
205196

@@ -221,23 +212,27 @@ robust: it will only ever update the field based on the value of the field in
221212
the database when the :meth:`~Model.save` or ``update()`` is executed, rather
222213
than based on its value when the instance was retrieved.
223214

224-
``F()`` assignments persist after ``Model.save()``
225-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
215+
``F()`` assignments are refreshed after ``Model.save()``
216+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
226217

227-
``F()`` objects assigned to model fields persist after saving the model
228-
instance and will be applied on each :meth:`~Model.save`. For example::
218+
``F()`` objects assigned to model fields are refreshed from the database
219+
on :meth:`~Model.save` without a subsequent query on backends that support
220+
it (SQLite, PostgreSQL, and Oracle) and deferred otherwise (MySQL and
221+
MariaDB). For example:
229222

230-
reporter = Reporters.objects.get(name="Tintin")
231-
reporter.stories_filed = F("stories_filed") + 1
232-
reporter.save()
223+
.. code-block:: pycon
233224

234-
reporter.name = "Tintin Jr."
235-
reporter.save()
225+
>>> reporter = Reporters.objects.get(name="Tintin")
226+
>>> reporter.stories_filed = F("stories_filed") + 1
227+
>>> reporter.save()
228+
>>> reporter.stories_filed # This triggers a refresh query on MySQL
229+
14 # Assuming the database value was 13 when the object was saved.
230+
231+
.. versionchanged:: 6.0
236232

237-
``stories_filed`` will be updated twice in this case. If it's initially ``1``,
238-
the final value will be ``3``. This persistence can be avoided by reloading the
239-
model object after saving it, for example, by using
240-
:meth:`~Model.refresh_from_db`.
233+
In previous versions of Django, ``F()`` objects were not refreshed from the
234+
database on :meth:`~Model.save`, which resulted in them being evaluated and
235+
persisted every time the instance was saved.
241236

242237
Using ``F()`` in filters
243238
~~~~~~~~~~~~~~~~~~~~~~~~

docs/releases/6.0.txt

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -331,7 +331,8 @@ Models
331331
value from the non-null input values. This is supported on SQLite, MySQL,
332332
Oracle, and PostgreSQL 16+.
333333

334-
* :class:`~django.db.models.GeneratedField`\s are now refreshed from the
334+
* :class:`~django.db.models.GeneratedField`\s and :ref:`fields assigned
335+
expressions <avoiding-race-conditions-using-f>` are now refreshed from the
335336
database after :meth:`~django.db.models.Model.save` on backends that support
336337
the ``RETURNING`` clause (SQLite, PostgreSQL, and Oracle). On backends that
337338
don't support it (MySQL and MariaDB), the fields are marked as deferred to

tests/basic/tests.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import inspect
22
import threading
3+
import time
34
from datetime import datetime, timedelta
45
from unittest import mock
56

@@ -12,6 +13,7 @@
1213
models,
1314
transaction,
1415
)
16+
from django.db.models.functions import Now
1517
from django.db.models.manager import BaseManager
1618
from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet
1719
from django.test import (
@@ -558,6 +560,26 @@ def new_instance():
558560
with self.subTest(case=case):
559561
self.assertIs(case._is_pk_set(), True)
560562

563+
def test_save_expressions(self):
564+
article = Article(pub_date=Now())
565+
article.save()
566+
expected_num_queries = (
567+
0 if connection.features.can_return_columns_from_insert else 1
568+
)
569+
with self.assertNumQueries(expected_num_queries):
570+
article_pub_date = article.pub_date
571+
self.assertIsInstance(article_pub_date, datetime)
572+
# Sleep slightly to ensure a different database level NOW().
573+
time.sleep(0.1)
574+
article.pub_date = Now()
575+
article.save()
576+
expected_num_queries = (
577+
0 if connection.features.can_return_rows_from_update else 1
578+
)
579+
with self.assertNumQueries(expected_num_queries):
580+
self.assertIsInstance(article.pub_date, datetime)
581+
self.assertGreater(article.pub_date, article_pub_date)
582+
561583

562584
class ModelLookupTest(TestCase):
563585
@classmethod

tests/expressions/tests.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -420,8 +420,11 @@ def test_object_update(self):
420420
# F expressions can be used to update attributes on single objects
421421
self.gmbh.num_employees = F("num_employees") + 4
422422
self.gmbh.save()
423-
self.gmbh.refresh_from_db()
424-
self.assertEqual(self.gmbh.num_employees, 36)
423+
expected_num_queries = (
424+
0 if connection.features.can_return_rows_from_update else 1
425+
)
426+
with self.assertNumQueries(expected_num_queries):
427+
self.assertEqual(self.gmbh.num_employees, 36)
425428

426429
def test_new_object_save(self):
427430
# We should be able to use Funcs when inserting new data
@@ -1644,8 +1647,11 @@ def test_decimal_expression(self):
16441647
n = Number.objects.create(integer=1, decimal_value=Decimal("0.5"))
16451648
n.decimal_value = F("decimal_value") - Decimal("0.4")
16461649
n.save()
1647-
n.refresh_from_db()
1648-
self.assertEqual(n.decimal_value, Decimal("0.1"))
1650+
expected_num_queries = (
1651+
0 if connection.features.can_return_rows_from_update else 1
1652+
)
1653+
with self.assertNumQueries(expected_num_queries):
1654+
self.assertEqual(n.decimal_value, Decimal("0.1"))
16491655

16501656

16511657
class ExpressionOperatorTests(TestCase):

tests/field_defaults/tests.py

Lines changed: 37 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -15,13 +15,7 @@
1515
)
1616
from django.db.models.functions import Collate
1717
from django.db.models.lookups import GreaterThan
18-
from django.test import (
19-
SimpleTestCase,
20-
TestCase,
21-
override_settings,
22-
skipIfDBFeature,
23-
skipUnlessDBFeature,
24-
)
18+
from django.test import SimpleTestCase, TestCase, override_settings, skipUnlessDBFeature
2519
from django.utils import timezone
2620

2721
from .models import (
@@ -44,47 +38,56 @@ def test_field_defaults(self):
4438
self.assertEqual(a.headline, "Default headline")
4539
self.assertLess((now - a.pub_date).seconds, 5)
4640

47-
@skipUnlessDBFeature(
48-
"can_return_columns_from_insert", "supports_expression_defaults"
49-
)
41+
@skipUnlessDBFeature("supports_expression_defaults")
5042
def test_field_db_defaults_returning(self):
5143
a = DBArticle()
5244
a.save()
5345
self.assertIsInstance(a.id, int)
54-
self.assertEqual(a.headline, "Default headline")
55-
self.assertIsInstance(a.pub_date, datetime)
56-
self.assertEqual(a.cost, Decimal("3.33"))
46+
expected_num_queries = (
47+
0 if connection.features.can_return_columns_from_insert else 3
48+
)
49+
with self.assertNumQueries(expected_num_queries):
50+
self.assertEqual(a.headline, "Default headline")
51+
self.assertIsInstance(a.pub_date, datetime)
52+
self.assertEqual(a.cost, Decimal("3.33"))
5753

58-
@skipIfDBFeature("can_return_columns_from_insert")
5954
@skipUnlessDBFeature("supports_expression_defaults")
6055
def test_field_db_defaults_refresh(self):
6156
a = DBArticle()
6257
a.save()
63-
a.refresh_from_db()
58+
expected_num_queries = (
59+
0 if connection.features.can_return_columns_from_insert else 3
60+
)
6461
self.assertIsInstance(a.id, int)
65-
self.assertEqual(a.headline, "Default headline")
66-
self.assertIsInstance(a.pub_date, datetime)
67-
self.assertEqual(a.cost, Decimal("3.33"))
62+
with self.assertNumQueries(expected_num_queries):
63+
self.assertEqual(a.headline, "Default headline")
64+
self.assertIsInstance(a.pub_date, datetime)
65+
self.assertEqual(a.cost, Decimal("3.33"))
6866

6967
def test_null_db_default(self):
7068
obj1 = DBDefaults.objects.create()
71-
if not connection.features.can_return_columns_from_insert:
72-
obj1.refresh_from_db()
73-
self.assertEqual(obj1.null, 1.1)
69+
expected_num_queries = (
70+
0 if connection.features.can_return_columns_from_insert else 1
71+
)
72+
with self.assertNumQueries(expected_num_queries):
73+
self.assertEqual(obj1.null, 1.1)
7474

7575
obj2 = DBDefaults.objects.create(null=None)
76-
self.assertIsNone(obj2.null)
76+
with self.assertNumQueries(0):
77+
self.assertIsNone(obj2.null)
7778

7879
@skipUnlessDBFeature("supports_expression_defaults")
7980
@override_settings(USE_TZ=True)
8081
def test_db_default_function(self):
8182
m = DBDefaultsFunction.objects.create()
82-
if not connection.features.can_return_columns_from_insert:
83-
m.refresh_from_db()
84-
self.assertAlmostEqual(m.number, pi)
85-
self.assertEqual(m.year, timezone.now().year)
86-
self.assertAlmostEqual(m.added, pi + 4.5)
87-
self.assertEqual(m.multiple_subfunctions, 4.5)
83+
expected_num_queries = (
84+
0 if connection.features.can_return_columns_from_insert else 4
85+
)
86+
with self.assertNumQueries(expected_num_queries):
87+
self.assertAlmostEqual(m.number, pi)
88+
self.assertEqual(m.year, timezone.now().year)
89+
self.assertAlmostEqual(m.added, pi + 4.5)
90+
self.assertEqual(m.multiple_subfunctions, 4.5)
8891

8992
@skipUnlessDBFeature("insert_test_table_with_defaults")
9093
def test_both_default(self):
@@ -125,14 +128,15 @@ def test_foreign_key_db_default(self):
125128
child2 = DBDefaultsFK.objects.create(language_code=parent2)
126129
self.assertEqual(child2.language_code, parent2)
127130

128-
@skipUnlessDBFeature(
129-
"can_return_columns_from_insert", "supports_expression_defaults"
130-
)
131+
@skipUnlessDBFeature("supports_expression_defaults")
131132
def test_case_when_db_default_returning(self):
132133
m = DBDefaultsFunction.objects.create()
133-
self.assertEqual(m.case_when, 3)
134+
expected_num_queries = (
135+
0 if connection.features.can_return_columns_from_insert else 1
136+
)
137+
with self.assertNumQueries(expected_num_queries):
138+
self.assertEqual(m.case_when, 3)
134139

135-
@skipIfDBFeature("can_return_columns_from_insert")
136140
@skipUnlessDBFeature("supports_expression_defaults")
137141
def test_case_when_db_default_no_returning(self):
138142
m = DBDefaultsFunction.objects.create()

tests/update_only_fields/tests.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from django.core.exceptions import ObjectNotUpdated
2-
from django.db import DatabaseError, transaction
2+
from django.db import DatabaseError, connection, transaction
3+
from django.db.models import F
34
from django.db.models.signals import post_save, pre_save
45
from django.test import TestCase
56

@@ -308,3 +309,16 @@ def test_update_fields_not_updated(self):
308309
transaction.atomic(),
309310
):
310311
obj.save(update_fields=["name"])
312+
313+
def test_update_fields_expression(self):
314+
obj = Person.objects.create(name="Valerie", gender="F", pid=42)
315+
updated_pid = F("pid") + 1
316+
obj.pid = updated_pid
317+
obj.save(update_fields={"gender"})
318+
self.assertIs(obj.pid, updated_pid)
319+
obj.save(update_fields={"pid"})
320+
expected_num_queries = (
321+
0 if connection.features.can_return_rows_from_update else 1
322+
)
323+
with self.assertNumQueries(expected_num_queries):
324+
self.assertEqual(obj.pid, 43)

0 commit comments

Comments
 (0)