@@ -28,7 +28,48 @@ struct C10_API PyObjectSlot {
      PyInterpreter* self_interpreter,
      PyObject* pyobj,
      PyInterpreterStatus status) {
-    pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed);
+    impl::PyInterpreter* expected = nullptr;
+    switch (status) {
+      case impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED:
+        // caller guarantees there is no multithreaded access; since there is
+        // no data race, it is OK to do a relaxed store
+        pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed);
+        break;
+      case impl::PyInterpreterStatus::TAGGED_BY_US:
+        // no tagging is necessary, the tag is already correct
+        break;
+      case impl::PyInterpreterStatus::MAYBE_UNINITIALIZED:
+        // attempt to claim this TensorImpl with the specified interpreter
+        // tag
+        if (pyobj_interpreter_.compare_exchange_strong(
+                expected, self_interpreter, std::memory_order_acq_rel)) {
+          break;
+        }
+        // test if, actually, it was already tagged by us!  This cannot be
+        // caused by a race, but it can happen when someone conservatively
+        // tagged the tensor as MAYBE_UNINITIALIZED (because they didn't
+        // pre-check the tag) when it was actually already owned by this
+        // interpreter
+        if (expected == self_interpreter) {
+          break;
+        }
+        // fallthrough: we lost the race.  We are guaranteed not to lose the
+        // race with ourselves, as calls to init_pyobj with the same
+        // interpreter ID must be sequentialized by the GIL
+        [[fallthrough]];
+      case impl::PyInterpreterStatus::TAGGED_BY_OTHER:
+        TORCH_CHECK(
+            false,
+            "cannot allocate PyObject for Tensor on interpreter ",
+            self_interpreter,
+            " that has already been used by another torch deploy interpreter ",
+            pyobj_interpreter_.load());
+    }
+
+    // we are the ONLY thread that can have gotten to this point.  It is not
+    // possible to conflict with another zero interpreter, as access is
+    // protected by the GIL
+    // NB: the owns_pyobj tag is initially false
    pyobj_ = pyobj;
  }
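The MAYBE_UNINITIALIZED branch above is a claim-by-compare-exchange pattern: the CAS only succeeds while the tag is still null, and on failure `expected` is overwritten with the current owner, which lets the caller recognize the benign "already tagged by us" case. Below is a minimal standalone sketch of the same pattern; `Interpreter`, `Slot`, and `try_claim` are hypothetical names for illustration only, not the c10 types.

```cpp
#include <atomic>
#include <cassert>

// Hypothetical stand-ins for illustration only (not the c10 types).
struct Interpreter {};

struct Slot {
  std::atomic<Interpreter*> owner{nullptr};

  // Try to claim the slot for `self`. Returns true if `self` ends up as the
  // owner (either we won the race or the slot was already tagged by us).
  bool try_claim(Interpreter* self) {
    Interpreter* expected = nullptr;
    // Succeeds only if `owner` is still nullptr; otherwise `expected` is
    // overwritten with the current owner.
    if (owner.compare_exchange_strong(
            expected, self, std::memory_order_acq_rel)) {
      return true;  // we won the race and tagged the slot
    }
    // Losing the CAS is fine if the slot was already tagged by us earlier
    // (e.g. the caller conservatively assumed MAYBE_UNINITIALIZED).
    return expected == self;
  }
};

int main() {
  Interpreter a, b;
  Slot slot;
  assert(slot.try_claim(&a));   // first claim succeeds
  assert(slot.try_claim(&a));   // re-claim by the same owner is a no-op
  assert(!slot.try_claim(&b));  // a different owner cannot steal the tag
}
```

Using the strong (rather than weak) compare-exchange avoids spurious failures, so a single attempt per call is enough to decide between the three outcomes.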
@@ -56,16 +97,30 @@ struct C10_API PyObjectSlot {
  std::optional<PyObject*> check_pyobj(
      PyInterpreter* self_interpreter,
      bool ignore_hermetic_tls = false) const {
+    // Note [Memory ordering on Python interpreter tag]
    impl::PyInterpreter* interpreter =
        pyobj_interpreter_.load(std::memory_order_acquire);
    if (interpreter == nullptr) {
+      // NB: This never returns DEFINITELY_UNINITIALIZED because there is
+      // always the possibility that another thread races to initialize
+      // after we query here.  The only time when we can conclude a tensor
+      // is definitely uninitialized is when we have just allocated it and
+      // it cannot have escaped to other threads yet
      return std::nullopt;
-    }
-
-    if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
-      return std::nullopt;
+    } else if (interpreter == self_interpreter) {
+      // NB: pyobj_ could still be null!
+      if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
+        return std::nullopt;
+      } else {
+        return _unchecked_untagged_pyobj();
+      }
    } else {
-      return _unchecked_untagged_pyobj();
+      TORCH_CHECK(
+          false,
+          "cannot access PyObject for Tensor on interpreter ",
+          (*self_interpreter)->name(),
+          " that has already been used by another torch deploy interpreter ",
+          (*pyobj_interpreter_.load())->name());
    }
  }
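The `memory_order_acquire` load in `check_pyobj` can pair with the release side of the tagging path above (the `acq_rel` compare-exchange), so a reader that observes the interpreter tag also observes writes made before the tag was published. The generic acquire/release publication sketch below illustrates that guarantee; it is not a restatement of the referenced Note, and all names are illustrative.

```cpp
#include <atomic>
#include <cassert>
#include <thread>

// Generic acquire/release publication: the writer fills in a payload and
// then release-stores a pointer; a reader that acquire-loads the pointer
// and sees it non-null is guaranteed to also see the payload.
int payload = 0;
std::atomic<int*> published{nullptr};

void writer() {
  payload = 42;                                          // plain write
  published.store(&payload, std::memory_order_release);  // publish
}

void reader() {
  int* p = published.load(std::memory_order_acquire);    // pairs with release
  if (p != nullptr) {
    assert(*p == 42);  // visible once the acquire load observes p
  }
}

int main() {
  std::thread t1(writer), t2(reader);
  t1.join();
  t2.join();
}
```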
@@ -75,6 +130,13 @@ struct C10_API PyObjectSlot {

  PyInterpreter& load_pyobj_interpreter() const;

+  // Check if the PyObjectSlot's interpreter is the same as the specified
+  // interpreter
+  bool check_interpreter(PyInterpreter* interpreter);
+
+  // Check if the PyObjectSlot is holding a PyObject, owned or non-owned
+  bool has_pyobj_nonhermetic();
+
  bool owns_pyobj();

  void set_owns_pyobj(bool b);
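`check_interpreter` and `has_pyobj_nonhermetic` are only declared in this hunk; their definitions are not part of the excerpt. As a rough, speculative illustration of the semantics the comments describe, here is a self-contained sketch using hypothetical stand-in types (`Slot`, `Interpreter`); it is not the actual c10 implementation.

```cpp
#include <atomic>
#include <cassert>

// Hypothetical stand-ins mirroring the commented semantics above; not the
// c10::impl types and not the definitions added by this commit.
struct Interpreter {};
using PyObjectPtr = void*;  // placeholder for PyObject*

struct Slot {
  std::atomic<Interpreter*> interpreter{nullptr};
  PyObjectPtr pyobj{nullptr};

  // "Is this slot's interpreter the same as the specified interpreter?"
  bool check_interpreter(Interpreter* candidate) const {
    return interpreter.load(std::memory_order_acquire) == candidate;
  }

  // "Is this slot holding a PyObject, owned or non-owned?"  Hermetic mode is
  // deliberately ignored here, hence the 'nonhermetic' suffix.
  bool has_pyobj_nonhermetic() const {
    return interpreter.load(std::memory_order_acquire) != nullptr &&
        pyobj != nullptr;
  }
};

int main() {
  Interpreter interp;
  Slot slot;
  int dummy_pyobj = 0;  // stands in for a real PyObject

  assert(!slot.check_interpreter(&interp));
  assert(!slot.has_pyobj_nonhermetic());

  slot.interpreter.store(&interp, std::memory_order_release);
  slot.pyobj = &dummy_pyobj;

  assert(slot.check_interpreter(&interp));
  assert(slot.has_pyobj_nonhermetic());
}
```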