15
15
)
16
16
from django .db .models .functions import Collate
17
17
from django .db .models .lookups import GreaterThan
18
- from django .test import (
19
- SimpleTestCase ,
20
- TestCase ,
21
- override_settings ,
22
- skipIfDBFeature ,
23
- skipUnlessDBFeature ,
24
- )
18
+ from django .test import SimpleTestCase , TestCase , override_settings , skipUnlessDBFeature
25
19
from django .utils import timezone
26
20
27
21
from .models import (
@@ -44,47 +38,56 @@ def test_field_defaults(self):
44
38
self .assertEqual (a .headline , "Default headline" )
45
39
self .assertLess ((now - a .pub_date ).seconds , 5 )
46
40
47
- @skipUnlessDBFeature (
48
- "can_return_columns_from_insert" , "supports_expression_defaults"
49
- )
41
+ @skipUnlessDBFeature ("supports_expression_defaults" )
50
42
def test_field_db_defaults_returning (self ):
51
43
a = DBArticle ()
52
44
a .save ()
53
45
self .assertIsInstance (a .id , int )
54
- self .assertEqual (a .headline , "Default headline" )
55
- self .assertIsInstance (a .pub_date , datetime )
56
- self .assertEqual (a .cost , Decimal ("3.33" ))
46
+ expected_num_queries = (
47
+ 0 if connection .features .can_return_columns_from_insert else 3
48
+ )
49
+ with self .assertNumQueries (expected_num_queries ):
50
+ self .assertEqual (a .headline , "Default headline" )
51
+ self .assertIsInstance (a .pub_date , datetime )
52
+ self .assertEqual (a .cost , Decimal ("3.33" ))
57
53
58
- @skipIfDBFeature ("can_return_columns_from_insert" )
59
54
@skipUnlessDBFeature ("supports_expression_defaults" )
60
55
def test_field_db_defaults_refresh (self ):
61
56
a = DBArticle ()
62
57
a .save ()
63
- a .refresh_from_db ()
58
+ expected_num_queries = (
59
+ 0 if connection .features .can_return_columns_from_insert else 3
60
+ )
64
61
self .assertIsInstance (a .id , int )
65
- self .assertEqual (a .headline , "Default headline" )
66
- self .assertIsInstance (a .pub_date , datetime )
67
- self .assertEqual (a .cost , Decimal ("3.33" ))
62
+ with self .assertNumQueries (expected_num_queries ):
63
+ self .assertEqual (a .headline , "Default headline" )
64
+ self .assertIsInstance (a .pub_date , datetime )
65
+ self .assertEqual (a .cost , Decimal ("3.33" ))
68
66
69
67
def test_null_db_default (self ):
70
68
obj1 = DBDefaults .objects .create ()
71
- if not connection .features .can_return_columns_from_insert :
72
- obj1 .refresh_from_db ()
73
- self .assertEqual (obj1 .null , 1.1 )
69
+ expected_num_queries = (
70
+ 0 if connection .features .can_return_columns_from_insert else 1
71
+ )
72
+ with self .assertNumQueries (expected_num_queries ):
73
+ self .assertEqual (obj1 .null , 1.1 )
74
74
75
75
obj2 = DBDefaults .objects .create (null = None )
76
- self .assertIsNone (obj2 .null )
76
+ with self .assertNumQueries (0 ):
77
+ self .assertIsNone (obj2 .null )
77
78
78
79
@skipUnlessDBFeature ("supports_expression_defaults" )
79
80
@override_settings (USE_TZ = True )
80
81
def test_db_default_function (self ):
81
82
m = DBDefaultsFunction .objects .create ()
82
- if not connection .features .can_return_columns_from_insert :
83
- m .refresh_from_db ()
84
- self .assertAlmostEqual (m .number , pi )
85
- self .assertEqual (m .year , timezone .now ().year )
86
- self .assertAlmostEqual (m .added , pi + 4.5 )
87
- self .assertEqual (m .multiple_subfunctions , 4.5 )
83
+ expected_num_queries = (
84
+ 0 if connection .features .can_return_columns_from_insert else 4
85
+ )
86
+ with self .assertNumQueries (expected_num_queries ):
87
+ self .assertAlmostEqual (m .number , pi )
88
+ self .assertEqual (m .year , timezone .now ().year )
89
+ self .assertAlmostEqual (m .added , pi + 4.5 )
90
+ self .assertEqual (m .multiple_subfunctions , 4.5 )
88
91
89
92
@skipUnlessDBFeature ("insert_test_table_with_defaults" )
90
93
def test_both_default (self ):
@@ -125,14 +128,15 @@ def test_foreign_key_db_default(self):
125
128
child2 = DBDefaultsFK .objects .create (language_code = parent2 )
126
129
self .assertEqual (child2 .language_code , parent2 )
127
130
128
- @skipUnlessDBFeature (
129
- "can_return_columns_from_insert" , "supports_expression_defaults"
130
- )
131
+ @skipUnlessDBFeature ("supports_expression_defaults" )
131
132
def test_case_when_db_default_returning (self ):
132
133
m = DBDefaultsFunction .objects .create ()
133
- self .assertEqual (m .case_when , 3 )
134
+ expected_num_queries = (
135
+ 0 if connection .features .can_return_columns_from_insert else 1
136
+ )
137
+ with self .assertNumQueries (expected_num_queries ):
138
+ self .assertEqual (m .case_when , 3 )
134
139
135
- @skipIfDBFeature ("can_return_columns_from_insert" )
136
140
@skipUnlessDBFeature ("supports_expression_defaults" )
137
141
def test_case_when_db_default_no_returning (self ):
138
142
m = DBDefaultsFunction .objects .create ()
0 commit comments