Skip to content

Commit c1f5005

Browse files
authored
fix: finish support of type assertions, which were incomplete (#657)
* fix: finish support of type assertions, which were incomplete. TypeAssert was optimistically returning ok without verifying that the value could be converted to the required interface (when asserting to an interface type), and did not check the type in all cases. There is now a working implements method for itype. Fixes #640. * style: appease lint * fix: remove useless code block * doc: improve comments * avoid test conflict
1 parent def57d5 commit c1f5005

File tree

5 files changed

+172
-45
lines changed

5 files changed

+172
-45
lines changed

_test/method33.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
)
6+
7+
// T1 is an empty struct. Its value-receiver method f gives it a method
// set that satisfies I; g is an extra method outside I.
type T1 struct{}
8+
9+
// f prints "T1.f()". T2 declares its own f, which shadows this one.
func (t1 T1) f() {
10+
fmt.Println("T1.f()")
11+
}
12+
13+
// g prints "T1.g()". It is not part of I, so callers holding an I reach it
// only after asserting to a concrete type. T2 gets it through embedding.
func (t1 T1) g() {
14+
fmt.Println("T1.g()")
15+
}
16+
17+
// T2 embeds T1, so T1's methods are promoted to T2 unless T2 declares a
// method with the same name. T2 is still a distinct type from T1.
type T2 struct {
18+
T1
19+
}
20+
21+
// f shadows the promoted T1.f. The expected output requires t2.f() to
// print "T2.f()", while t2.g() still resolves to the promoted T1.g.
func (t2 T2) f() {
22+
fmt.Println("T2.f()")
23+
}
24+
25+
// I requires only f, so both T1 and T2 (via its own f) satisfy it.
// printType receives values through I and recovers the concrete type by
// type assertion.
type I interface {
26+
f()
27+
}
28+
29+
// printType tries the comma-ok assertion i.(T1), then i.(T2), and calls
// both f and g on each match. Per the expected output, exactly one branch
// fires per call: a T2 value must NOT satisfy i.(T1) even though T2 embeds
// T1. This guards against the interpreter reporting ok for a type assertion
// without actually checking the dynamic type.
func printType(i I) {
30+
// Must be false when i holds a T2: embedding T1 does not make T2 a T1.
if t1, ok := i.(T1); ok {
31+
println("T1 ok")
32+
t1.f()
33+
t1.g()
34+
}
35+
36+
// Must be false when i holds a T1.
if t2, ok := i.(T2); ok {
37+
println("T2 ok")
38+
t2.f()
39+
// g is not declared on T2; this calls the method promoted from T1.
t2.g()
40+
}
41+
}
42+
43+
// main passes one value of each concrete type to printType, printing a
// label first. The trailing "// Output:" comment lists the expected output.
func main() {
44+
println("T1")
45+
printType(T1{})
46+
println("T2")
47+
printType(T2{})
48+
}
49+
50+
// Output:
51+
// T1
52+
// T1 ok
53+
// T1.f()
54+
// T1.g()
55+
// T2
56+
// T2 ok
57+
// T2.f()
58+
// T1.g()

_test/method34.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package main
2+
3+
// Root holds the Name that Root.Hello formats.
type Root struct {
4+
Name string
5+
}
6+
7+
// One embeds Root by value. Hello has a pointer receiver (*Root), so only
// *One, not One, includes the promoted Hello in its method set.
type One struct {
8+
Root
9+
}
10+
11+
// Hi is the interface that main asserts to; it requires Hello.
type Hi interface {
12+
Hello() string
13+
}
14+
15+
// Hello returns "Hello " followed by r.Name. The pointer receiver is
// deliberate: it makes the method reachable from *One but not from One.
func (r *Root) Hello() string { return "Hello " + r.Name }
16+
17+
// main stores a *One in an empty interface and asserts it to Hi. The
// assertion is valid only because Hello is promoted from the embedded Root
// into the method set of *One, so the interpreter must derive that method
// from the embedding. The single-value form one.(Hi) panics instead of
// printing if the assertion fails.
func main() {
18+
var one interface{} = &One{Root{Name: "test2"}}
19+
println(one.(Hi).Hello())
20+
}
21+
22+
// Output:
23+
// Hello test2

interp/gta.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ func (interp *Interpreter) gta(root *node, rpath, pkgID string) ([]*node, error)
137137
}
138138
}
139139
rcvrtype.method = append(rcvrtype.method, n)
140+
n.child[0].child[0].lastChild().typ = rcvrtype
140141
} else {
141142
// Add a function symbol in the package name space
142143
sc.sym[n.child[1].ident] = &symbol{kind: funcSym, typ: n.typ, node: n, index: -1}

interp/run.go

Lines changed: 76 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -122,34 +122,38 @@ func runCfg(n *node, f *frame) {
122122
}
123123

124124
func typeAssertStatus(n *node) {
125-
c0, c1 := n.child[0], n.child[1]
125+
c0, c1 := n.child[0], n.child[1] // cO contains the input value, c1 the type to assert
126126
value := genValue(c0) // input value
127127
value1 := genValue(n.anc.child[1]) // returned status
128-
typ := c1.typ.rtype // type to assert
129128
next := getExec(n.tnext)
130129

131130
switch {
132-
case c0.typ.cat == valueT:
131+
case isInterfaceSrc(c1.typ):
132+
typ := c1.typ
133+
n.exec = func(f *frame) bltn {
134+
v, ok := value(f).Interface().(valueInterface)
135+
value1(f).SetBool(ok && v.node.typ.implements(typ))
136+
return next
137+
}
138+
case isInterface(c1.typ):
139+
rtype := c1.typ.rtype
133140
n.exec = func(f *frame) bltn {
134141
v := value(f)
135-
if !v.IsValid() || v.IsNil() {
136-
value1(f).SetBool(false)
137-
}
138-
value1(f).SetBool(v.Type().Implements(typ))
142+
value1(f).SetBool(v.IsValid() && v.Type().Implements(rtype))
139143
return next
140144
}
141-
case c1.typ.cat == interfaceT:
145+
case c0.typ.cat == valueT:
146+
rtype := c1.typ.rtype
142147
n.exec = func(f *frame) bltn {
143-
_, ok := value(f).Interface().(valueInterface)
144-
// TODO: verify that value(f) implements asserted type.
145-
value1(f).SetBool(ok)
148+
v := value(f)
149+
value1(f).SetBool(v.IsValid() && v.Type() == rtype)
146150
return next
147151
}
148152
default:
153+
typID := c1.typ.id()
149154
n.exec = func(f *frame) bltn {
150-
_, ok := value(f).Interface().(valueInterface)
151-
// TODO: verify that value(f) implements asserted type.
152-
value1(f).SetBool(ok)
155+
v, ok := value(f).Interface().(valueInterface)
156+
value1(f).SetBool(ok && v.node.typ.id() == typID)
153157
return next
154158
}
155159
}
@@ -162,24 +166,35 @@ func typeAssert(n *node) {
162166
next := getExec(n.tnext)
163167

164168
switch {
165-
case c0.typ.cat == valueT:
169+
case isInterfaceSrc(c1.typ):
170+
typ := n.child[1].typ
171+
typID := n.child[1].typ.id()
166172
n.exec = func(f *frame) bltn {
167173
v := value(f)
168-
dest(f).Set(v.Elem())
174+
vi, ok := v.Interface().(valueInterface)
175+
if !ok {
176+
panic(n.cfgErrorf("interface conversion: nil is not %v", typID))
177+
}
178+
if !vi.node.typ.implements(typ) {
179+
panic(n.cfgErrorf("interface conversion: %v is not %v", vi.node.typ.id(), typID))
180+
}
181+
dest(f).Set(v)
169182
return next
170183
}
171-
case c1.typ.cat == interfaceT:
184+
case isInterface(c1.typ):
185+
rtype := n.child[1].typ.rtype
186+
n.exec = func(f *frame) bltn {
187+
dest(f).Set(value(f).Convert(rtype))
188+
return next
189+
}
190+
case c0.typ.cat == valueT:
172191
n.exec = func(f *frame) bltn {
173-
v := value(f).Interface().(valueInterface)
174-
// TODO: verify that value(f) implements asserted type.
175-
dest(f).Set(reflect.ValueOf(valueInterface{v.node, v.value}))
192+
dest(f).Set(value(f).Elem())
176193
return next
177194
}
178195
default:
179196
n.exec = func(f *frame) bltn {
180-
v := value(f).Interface().(valueInterface)
181-
// TODO: verify that value(f) implements asserted type.
182-
dest(f).Set(v.value)
197+
dest(f).Set(value(f).Interface().(valueInterface).value)
183198
return next
184199
}
185200
}
@@ -189,30 +204,58 @@ func typeAssert2(n *node) {
189204
value := genValue(n.child[0]) // input value
190205
value0 := genValue(n.anc.child[0]) // returned result
191206
value1 := genValue(n.anc.child[1]) // returned status
207+
typ := n.child[1].typ // type to assert or convert to
208+
typID := typ.id()
192209
next := getExec(n.tnext)
193210

194211
switch {
195-
case n.child[0].typ.cat == valueT:
212+
case isInterfaceSrc(typ):
196213
n.exec = func(f *frame) bltn {
197-
if value(f).IsValid() && !value(f).IsNil() {
198-
value0(f).Set(value(f).Elem())
214+
v, ok := value(f).Interface().(valueInterface)
215+
if ok && v.node.typ.id() == typID {
216+
value0(f).Set(value(f))
217+
} else {
218+
ok = false
199219
}
200-
value1(f).SetBool(true)
220+
value1(f).SetBool(ok)
201221
return next
202222
}
203-
case n.child[1].typ.cat == interfaceT:
223+
case isInterface(typ):
224+
rtype := typ.rtype
204225
n.exec = func(f *frame) bltn {
205-
v, ok := value(f).Interface().(valueInterface)
206-
// TODO: verify that value(f) implements asserted type.
207-
value0(f).Set(reflect.ValueOf(valueInterface{v.node, v.value}))
226+
v := value(f)
227+
ok := v.IsValid() && v.Type().Implements(rtype)
228+
if ok {
229+
value0(f).Set(v.Convert(rtype))
230+
}
231+
value1(f).SetBool(ok)
232+
return next
233+
}
234+
case n.child[0].typ.cat == valueT:
235+
rtype := n.child[1].typ.rtype
236+
n.exec = func(f *frame) bltn {
237+
v := value(f)
238+
ok := v.IsValid() && !value(f).IsNil()
239+
if ok {
240+
if e := v.Elem(); e.Type() == rtype {
241+
value0(f).Set(e)
242+
} else {
243+
ok = false
244+
}
245+
}
208246
value1(f).SetBool(ok)
209247
return next
210248
}
211249
default:
212250
n.exec = func(f *frame) bltn {
213251
v, ok := value(f).Interface().(valueInterface)
214-
// TODO: verify that value(f) implements asserted type.
215-
value0(f).Set(v.value)
252+
if ok {
253+
if v.node.typ.id() == typID {
254+
value0(f).Set(v.value)
255+
} else {
256+
ok = false
257+
}
258+
}
216259
value1(f).SetBool(ok)
217260
return next
218261
}

interp/type.go

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ func (t *itype) methods() methodSet {
818818
res := make(methodSet)
819819
switch t.cat {
820820
case interfaceT:
821-
// Get methods from recursive analysis of interface fields
821+
// Get methods from recursive analysis of interface fields.
822822
for _, f := range t.field {
823823
if f.typ.cat == funcT {
824824
res[f.name] = f.typ.TypeOf().String()
@@ -829,23 +829,26 @@ func (t *itype) methods() methodSet {
829829
}
830830
}
831831
case valueT, errorT:
832-
// Get method from corresponding reflect.Type
832+
// Get method from corresponding reflect.Type.
833833
for i := t.rtype.NumMethod() - 1; i >= 0; i-- {
834834
m := t.rtype.Method(i)
835835
res[m.Name] = m.Type.String()
836836
}
837837
case ptrT:
838-
// Consider only methods where receiver is a pointer to type t
839-
for _, m := range t.val.method {
840-
if m.child[0].child[0].lastChild().typ.cat == ptrT {
841-
res[m.ident] = m.typ.TypeOf().String()
842-
}
838+
for k, v := range t.val.methods() {
839+
res[k] = v
843840
}
844-
default:
845-
for _, m := range t.method {
846-
res[m.ident] = m.typ.TypeOf().String()
841+
case structT:
842+
for _, f := range t.field {
843+
for k, v := range f.typ.methods() {
844+
res[k] = v
845+
}
847846
}
848847
}
848+
// Get all methods defined on this type.
849+
for _, m := range t.method {
850+
res[m.ident] = m.typ.TypeOf().String()
851+
}
849852
return res
850853
}
851854

@@ -1192,8 +1195,7 @@ func (t *itype) implements(it *itype) bool {
11921195
if t.cat == valueT {
11931196
return t.TypeOf().Implements(it.TypeOf())
11941197
}
1195-
// TODO: implement method check for interpreted types
1196-
return true
1198+
return t.methods().contains(it.methods())
11971199
}
11981200

11991201
func defRecvType(n *node) *itype {

0 commit comments

Comments
 (0)