Skip to content

Commit 5cbbf93

Browse files
authored
fix: type assertion of non-matching types
1 parent 8365f68 commit 5cbbf93

File tree

3 files changed

+159
-31
lines changed

3 files changed

+159
-31
lines changed

_test/type23.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package main
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
)
7+
8+
// main checks that comma-ok type assertions report ok == false, and never
// panic, when the dynamic type of the operand does not match the asserted
// type. If any assertion wrongly succeeds, "unexpected" is printed; the only
// expected output is "bye".
//
// Each operand is asserted twice: once binding the asserted value to a named
// variable (r1/r2/r3) and once discarding it with the blank identifier.
// NOTE(review): presumably the two forms take different interpreter paths
// (two-value result vs. status-only), which is why both are kept — confirm.
func main() {
9+
// v1: non-nil interface{} whose dynamic type is int, asserted to string.
var v1 interface{} = 1
10+
// v2: nil interface{} (no dynamic type at all), asserted to string.
var v2 interface{}
11+
// v3: interface-to-interface case; the dynamic type
// (*httptest.ResponseRecorder) lacks Push, so it is not an http.Pusher.
var v3 http.ResponseWriter = httptest.NewRecorder()
12+
13+
if r1, ok := v1.(string); ok {
14+
_ = r1
15+
// Reached only if the assertion wrongly succeeds.
println("unexpected")
16+
}
17+
if _, ok := v1.(string); ok {
18+
println("unexpected")
19+
}
20+
if r2, ok := v2.(string); ok {
21+
_ = r2
22+
println("unexpected")
23+
}
24+
if _, ok := v2.(string); ok {
25+
println("unexpected")
26+
}
27+
if r3, ok := v3.(http.Pusher); ok {
28+
_ = r3
29+
println("unexpected")
30+
}
31+
if _, ok := v3.(http.Pusher); ok {
32+
println("unexpected")
33+
}
34+
println("bye")
35+
}
36+
37+
// Output:
38+
// bye

_test/type24.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
"net/http/httptest"
7+
)
8+
9+
// main runs each failing single-value type assertion in its own function, so
// that every panic is recovered and printed independently, in the order given
// by the Output block at the end of the file.
func main() {
10+
assertInt()
11+
assertNil()
12+
assertValue()
13+
}
14+
15+
// assertInt asserts an interface{} holding an int to string using the
// single-value form, which must panic. The deferred function recovers the
// panic and prints its value; the expected line is
// "interface conversion: interface {} is int, not string".
func assertInt() {
16+
defer func() {
17+
// If the assertion did not panic, recover returns nil and "<nil>" is
// printed, which would not match the expected output.
r := recover()
18+
fmt.Println(r)
19+
}()
20+
21+
var v interface{} = 1
22+
println(v.(string))
23+
}
24+
25+
// assertNil asserts a nil interface{} (no dynamic type) to string using the
// single-value form, which must panic. The deferred function recovers and
// prints the panic value; the expected line is
// "interface conversion: interface {} is nil, not string".
func assertNil() {
26+
defer func() {
27+
r := recover()
28+
fmt.Println(r)
29+
}()
30+
31+
var v interface{}
32+
println(v.(string))
33+
}
34+
35+
// assertValue asserts an http.ResponseWriter whose dynamic type is
// *httptest.ResponseRecorder to the interface http.Pusher using the
// single-value form. The recorder has no Push method, so this must panic.
// The deferred function recovers and prints the panic value; the expected
// line is
// "interface conversion: *httptest.ResponseRecorder is not http.Pusher: missing method Push".
func assertValue() {
36+
defer func() {
37+
r := recover()
38+
fmt.Println(r)
39+
}()
40+
41+
var v http.ResponseWriter = httptest.NewRecorder()
42+
println(v.(http.Pusher))
43+
}
44+
45+
// Output:
46+
// interface conversion: interface {} is int, not string
47+
// interface conversion: interface {} is nil, not string
48+
// interface conversion: *httptest.ResponseRecorder is not http.Pusher: missing method Push

interp/run.go

Lines changed: 73 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ func typeAssertStatus(n *node) {
125125
c0, c1 := n.child[0], n.child[1] // cO contains the input value, c1 the type to assert
126126
value := genValue(c0) // input value
127127
value1 := genValue(n.anc.child[1]) // returned status
128+
rtype := c1.typ.rtype // type to assert
128129
next := getExec(n.tnext)
129130

130131
switch {
@@ -136,24 +137,24 @@ func typeAssertStatus(n *node) {
136137
return next
137138
}
138139
case isInterface(c1.typ):
139-
rtype := c1.typ.rtype
140140
n.exec = func(f *frame) bltn {
141141
v := value(f)
142-
value1(f).SetBool(v.IsValid() && v.Type().Implements(rtype))
142+
ok := v.IsValid() && canAssertTypes(v.Elem().Type(), rtype)
143+
value1(f).SetBool(ok)
143144
return next
144145
}
145146
case c0.typ.cat == valueT:
146-
rtype := c1.typ.rtype
147147
n.exec = func(f *frame) bltn {
148148
v := value(f)
149-
value1(f).SetBool(v.IsValid() && v.Type() == rtype)
149+
ok := v.IsValid() && canAssertTypes(v.Elem().Type(), rtype)
150+
value1(f).SetBool(ok)
150151
return next
151152
}
152153
default:
153-
typID := c1.typ.id()
154154
n.exec = func(f *frame) bltn {
155155
v, ok := value(f).Interface().(valueInterface)
156-
value1(f).SetBool(ok && v.node.typ.id() == typID)
156+
ok = ok && v.value.IsValid() && canAssertTypes(v.value.Type(), rtype)
157+
value1(f).SetBool(ok)
157158
return next
158159
}
159160
}
@@ -162,7 +163,7 @@ func typeAssertStatus(n *node) {
162163
func typeAssert(n *node) {
163164
c0, c1 := n.child[0], n.child[1]
164165
value := genValue(c0) // input value
165-
dest := genValue(n) // returned result
166+
value0 := genValue(n) // returned result
166167
next := getExec(n.tnext)
167168

168169
switch {
@@ -178,34 +179,61 @@ func typeAssert(n *node) {
178179
if !vi.node.typ.implements(typ) {
179180
panic(n.cfgErrorf("interface conversion: %v is not %v", vi.node.typ.id(), typID))
180181
}
181-
dest(f).Set(v)
182+
value0(f).Set(v)
182183
return next
183184
}
184185
case isInterface(c1.typ):
185-
rtype := n.child[1].typ.rtype
186186
n.exec = func(f *frame) bltn {
187-
dest(f).Set(value(f).Convert(rtype))
187+
v := value(f).Elem()
188+
typ := value0(f).Type()
189+
if !v.IsValid() {
190+
panic(fmt.Sprintf("interface conversion: interface {} is nil, not %s", typ.String()))
191+
}
192+
if !canAssertTypes(v.Type(), typ) {
193+
method := firstMissingMethod(v.Type(), typ)
194+
panic(fmt.Sprintf("interface conversion: %s is not %s: missing method %s", v.Type().String(), typ.String(), method))
195+
}
196+
value0(f).Set(v)
188197
return next
189198
}
190199
case c0.typ.cat == valueT:
191200
n.exec = func(f *frame) bltn {
192-
dest(f).Set(value(f).Elem())
201+
v := value(f).Elem()
202+
typ := value0(f).Type()
203+
if !v.IsValid() {
204+
panic(fmt.Sprintf("interface conversion: interface {} is nil, not %s", typ.String()))
205+
}
206+
if !canAssertTypes(v.Type(), typ) {
207+
method := firstMissingMethod(v.Type(), typ)
208+
panic(fmt.Sprintf("interface conversion: %s is not %s: missing method %s", v.Type().String(), typ.String(), method))
209+
}
210+
value0(f).Set(v)
193211
return next
194212
}
195213
default:
196214
n.exec = func(f *frame) bltn {
197-
dest(f).Set(value(f).Interface().(valueInterface).value)
215+
v := value(f).Interface().(valueInterface)
216+
typ := value0(f).Type()
217+
if !v.value.IsValid() {
218+
panic(fmt.Sprintf("interface conversion: interface {} is nil, not %s", typ.String()))
219+
}
220+
if !canAssertTypes(v.value.Type(), typ) {
221+
panic(fmt.Sprintf("interface conversion: interface {} is %s, not %s", v.value.Type().String(), typ.String()))
222+
}
223+
value0(f).Set(v.value)
198224
return next
199225
}
200226
}
201227
}
202228

203229
func typeAssert2(n *node) {
204-
value := genValue(n.child[0]) // input value
230+
c0, c1 := n.child[0], n.child[1]
231+
value := genValue(c0) // input value
205232
value0 := genValue(n.anc.child[0]) // returned result
206233
value1 := genValue(n.anc.child[1]) // returned status
207-
typ := n.child[1].typ // type to assert or convert to
234+
typ := c1.typ // type to assert or convert to
208235
typID := typ.id()
236+
rtype := typ.rtype // type to assert
209237
next := getExec(n.tnext)
210238

211239
switch {
@@ -221,47 +249,61 @@ func typeAssert2(n *node) {
221249
return next
222250
}
223251
case isInterface(typ):
224-
rtype := typ.rtype
225252
n.exec = func(f *frame) bltn {
226-
v := value(f)
227-
ok := v.IsValid() && v.Type().Implements(rtype)
253+
v := value(f).Elem()
254+
ok := v.IsValid() && canAssertTypes(v.Type(), rtype)
228255
if ok {
229-
value0(f).Set(v.Convert(rtype))
256+
value0(f).Set(v)
230257
}
231258
value1(f).SetBool(ok)
232259
return next
233260
}
234261
case n.child[0].typ.cat == valueT:
235-
rtype := n.child[1].typ.rtype
236262
n.exec = func(f *frame) bltn {
237-
v := value(f)
238-
ok := v.IsValid() && !value(f).IsNil()
263+
v := value(f).Elem()
264+
ok := v.IsValid() && canAssertTypes(v.Type(), rtype)
239265
if ok {
240-
if e := v.Elem(); e.Type() == rtype {
241-
value0(f).Set(e)
242-
} else {
243-
ok = false
244-
}
266+
value0(f).Set(v)
245267
}
246268
value1(f).SetBool(ok)
247269
return next
248270
}
249271
default:
250272
n.exec = func(f *frame) bltn {
251273
v, ok := value(f).Interface().(valueInterface)
274+
ok = ok && v.value.IsValid() && canAssertTypes(v.value.Type(), rtype)
252275
if ok {
253-
if v.node.typ.id() == typID {
254-
value0(f).Set(v.value)
255-
} else {
256-
ok = false
257-
}
276+
value0(f).Set(v.value)
258277
}
259278
value1(f).SetBool(ok)
260279
return next
261280
}
262281
}
263282
}
264283

284+
// canAssertTypes reports whether a dynamic value of type src satisfies the
// type assertion x.(dest), following the Go specification:
//   - if dest is an interface type, src must implement dest;
//   - otherwise src must be identical to dest.
//
// reflect.Type.AssignableTo is deliberately not used. It also accepts pairs
// for which a Go type assertion fails, e.g. a defined type against its
// unnamed underlying type (type T []int vs []int), or chan int against
// <-chan int.
//
// A nil src or dest is treated as non-matching rather than panicking on a
// method call through a nil reflect.Type.
func canAssertTypes(src, dest reflect.Type) bool {
	if src == nil || dest == nil {
		return false
	}
	if dest.Kind() == reflect.Interface {
		// Implements also covers src == dest when src is itself an interface.
		return src.Implements(dest)
	}
	return src == dest
}
296+
297+
// firstMissingMethod returns the name of the first method of dest, in the
// order reflect reports them, that src does not provide with a matching
// signature. It returns "" if every method is satisfied. It supplies the
// "missing method X" part of interface-conversion panic messages.
//
// A method that exists on src under the right name but with a different
// signature is reported too. It still prevents src from implementing dest,
// and the Go runtime's own TypeAssertionError names such a method as the
// missing one. Checking names alone would yield "" here and produce a panic
// message ending in "missing method " with no name.
//
// NOTE(review): for a non-interface src, reflect's MethodByName only finds
// exported methods, so an unexported interface method may be reported even
// if src declares it.
func firstMissingMethod(src, dest reflect.Type) string {
	for i := 0; i < dest.NumMethod(); i++ {
		want := dest.Method(i)
		got, ok := src.MethodByName(want.Name)
		if !ok || !methodSignatureMatches(src, got.Type, want.Type) {
			return want.Name
		}
	}
	return ""
}

// methodSignatureMatches reports whether got, a method type found on src via
// MethodByName, has the same parameters, results and variadic-ness as the
// interface method type want. For a non-interface src, reflect includes the
// receiver as the first input of got, so that input is skipped.
func methodSignatureMatches(src, got, want reflect.Type) bool {
	skip := 0
	if src.Kind() != reflect.Interface {
		skip = 1
	}
	if got.NumIn()-skip != want.NumIn() || got.NumOut() != want.NumOut() || got.IsVariadic() != want.IsVariadic() {
		return false
	}
	for i := 0; i < want.NumIn(); i++ {
		if got.In(i+skip) != want.In(i) {
			return false
		}
	}
	for i := 0; i < want.NumOut(); i++ {
		if got.Out(i) != want.Out(i) {
			return false
		}
	}
	return true
}
306+
265307
func convert(n *node) {
266308
dest := genValue(n)
267309
c := n.child[1]

0 commit comments

Comments
 (0)