diff --git a/ast.go b/ast.go index 084cc36d..444fd554 100644 --- a/ast.go +++ b/ast.go @@ -175,6 +175,20 @@ func toBasicType(pkg *Package, t *types.Basic) ast.Expr { return &ast.Ident{Name: t.Name()} } +func isBasicUntyped(typ types.Type) bool { + if t, ok := typ.(*types.Basic); ok { + return (t.Info() & types.IsUntyped) != 0 + } + return false +} + +func isBasicUntypedKind(typ types.Type) (bool, types.BasicKind) { + if t, ok := typ.(*types.Basic); ok { + return (t.Info() & types.IsUntyped) != 0, t.Kind() + } + return false, types.Invalid +} + func isUntyped(pkg *Package, typ types.Type) bool { switch t := typ.(type) { case *types.Basic: @@ -682,6 +696,7 @@ retry: cval = unaryOp(pkg, t.tok(), args) } else if t.isOp() { cval = binaryOp(&pkg.cb, t.tok(), args) + pkg.cb.recordUpdateUntypedBinaryOp(t.tok(), args, nil) } else if t.hasApproxType() { flags |= instrFlagApproxType } @@ -707,6 +722,7 @@ retry: switch t := fn.Val.(type) { case *ast.BinaryExpr: t.X, t.Y = checkParenExpr(args[0].Val), checkParenExpr(args[1].Val) + pkg.cb.recordUpdateUntypedBinaryOp(t.Op, args, tyRet) return &internal.Elem{Val: t, Type: tyRet, CVal: cval}, nil case *ast.UnaryExpr: t.X = args[0].Val @@ -734,6 +750,7 @@ func matchTypeCast(pkg *Package, typ types.Type, fn *internal.Elem, args []*inte fnVal = &ast.ParenExpr{X: fnVal} } if len(args) == 1 && ConvertibleTo(pkg, args[0].Type, typ) { + pkg.cb.recordUpdateUntyped(args[0], typ) if args[0].CVal != nil { if t, ok := typ.(*types.Named); ok { o := t.Obj() @@ -1020,9 +1037,11 @@ func matchFuncType( func matchFuncArgs( pkg *Package, args []*internal.Elem, sig *types.Signature, at interface{}) error { for i, arg := range args { - if err := matchType(pkg, arg, getParam(sig, i).Type(), at); err != nil { + typ := getParam(sig, i).Type() + if err := matchType(pkg, arg, typ, at); err != nil { return err } + pkg.cb.recordUpdateUntyped(arg, typ) } return nil } @@ -1060,9 +1079,11 @@ func checkFuncResults(pkg *Package, rets []*internal.Elem, results *types.Tuple, } if n == need { for i := 0; i < need; i++ { - if err := matchType(pkg, rets[i], results.At(i).Type(), "return argument"); err != nil { + typ := results.At(i).Type() + if err := matchType(pkg, rets[i], typ, "return argument"); err != nil { panic(err) } + pkg.cb.recordUpdateUntyped(rets[i], typ) } return } @@ -1100,6 +1121,7 @@ func matchElemType(pkg *Package, vals []*internal.Elem, elt types.Type, at inter if err := matchType(pkg, val, elt, at); err != nil { return err } + pkg.cb.recordUpdateUntyped(val, elt) } return nil } diff --git a/codebuild.go b/codebuild.go index ece39978..0583d7aa 100644 --- a/codebuild.go +++ b/codebuild.go @@ -828,6 +828,8 @@ func (p *CodeBuilder) MapLit(typ types.Type, arity int, src ...ast.Node) *CodeBu pos, "cannot use %s (type %v) as type %v in map value", src, args[i+1].Type, val) } } + p.recordUpdateUntyped(args[i], key) + p.recordUpdateUntyped(args[i+1], val) } p.stk.Ret(arity, &internal.Elem{ Type: typ, Val: &ast.CompositeLit{Type: typExpr, Elts: elts}, Src: getSrc(src), @@ -876,6 +878,7 @@ func (p *CodeBuilder) indexElemExpr(args []*internal.Elem, i int) ast.Expr { return args[i+1].Val } p.toIntVal(args[i], "index which must be non-negative integer constant") + p.recordUpdateUntyped(args[i], types.Typ[types.Int]) return &ast.KeyValueExpr{Key: key, Value: args[i+1].Val} } @@ -921,6 +924,7 @@ func (p *CodeBuilder) SliceLitEx(typ types.Type, arity int, keyVal bool, src ... p.panicCodeErrorf( pos, "cannot use %s (type %v) as type %v in slice literal", src, args[i+1].Type, val) } + p.recordUpdateUntyped(arg, val) elts[i>>1] = p.indexElemExpr(args, i) } } else { @@ -956,6 +960,7 @@ func (p *CodeBuilder) SliceLitEx(typ types.Type, arity int, keyVal bool, src ... pos, "cannot use %s (type %v) as type %v in slice literal", src, arg.Type, val) } } + p.recordUpdateUntyped(arg, val) } } p.stk.Ret(arity, &internal.Elem{ @@ -1008,6 +1013,7 @@ func (p *CodeBuilder) ArrayLitEx(typ types.Type, arity int, keyVal bool, src ... p.panicCodeErrorf( pos, "cannot use %s (type %v) as type %v in array literal", src, args[i+1].Type, val) } + p.recordUpdateUntyped(args[i+1], val) elts[i>>1] = p.indexElemExpr(args, i) } } else { @@ -1028,6 +1034,7 @@ func (p *CodeBuilder) ArrayLitEx(typ types.Type, arity int, keyVal bool, src ... p.panicCodeErrorf( pos, "cannot use %s (type %v) as type %v in array literal", src, arg.Type, val) } + p.recordUpdateUntyped(arg, val) } } p.stk.Ret(arity, &internal.Elem{ @@ -1075,6 +1082,7 @@ func (p *CodeBuilder) StructLit(typ types.Type, arity int, keyVal bool, src ...a pos, "cannot use %s (type %v) as type %v in value of field %s", src, args[i+1].Type, eltTy, eltName) } + pkg.cb.recordUpdateUntyped(args[i+1], eltTy) elts[i>>1] = &ast.KeyValueExpr{Key: ident(eltName), Value: args[i+1].Val} } } else if arity != n { @@ -1097,6 +1105,7 @@ func (p *CodeBuilder) StructLit(typ types.Type, arity int, keyVal bool, src ...a pos, "cannot use %s (type %v) as type %v in value of field %s", src, arg.Type, eltTy, t.Field(i).Name()) } + pkg.cb.recordUpdateUntyped(arg, eltTy) } } p.stk.Ret(arity, &internal.Elem{ @@ -1145,6 +1154,13 @@ func (p *CodeBuilder) Slice(slice3 bool, src ...ast.Node) *CodeBuilder { // a[i: if slice3 { exprMax = args[3].Val } + + p.recordUpdateUntyped(args[1], tyInt) + p.recordUpdateUntyped(args[2], tyInt) + if slice3 { + p.recordUpdateUntyped(args[3], tyInt) + } + // TODO: check type elem := &internal.Elem{ Val: &ast.SliceExpr{ @@ -1185,6 +1201,7 @@ func (p *CodeBuilder) Index(nidx int, twoValue bool, src ...ast.Node) *CodeBuild } else { // elem = a[key] tyRet = typs[1] } + p.recordUpdateUntyped(args[1], tyInt) elem := &internal.Elem{ Val: &ast.IndexExpr{X: args[0].Val, Index: args[1].Val}, Type: tyRet, Src: srcExpr, } @@ -1203,6 +1220,7 @@ func (p *CodeBuilder) IndexRef(nidx int, src ...ast.Node) *CodeBuilder { } args := p.stk.GetArgs(2) typ := args[0].Type + p.recordUpdateUntyped(args[1], tyInt) elemRef := &internal.Elem{ Val: &ast.IndexExpr{X: args[0].Val, Index: args[1].Val}, Src: getSrc(src), @@ -1988,6 +2006,7 @@ func (p *CodeBuilder) doAssignWith(lhs, rhs int, src ast.Node) *CodeBuilder { lhsType = &refType{typ: bfr.typ} } checkAssignType(p.pkg, lhsType, args[lhs+i]) + p.recordUpdateUntyped(args[lhs+i], lhsType) stmt.Lhs[i] = args[i].Val stmt.Rhs[i] = args[lhs+i].Val if bfAssign { @@ -2054,6 +2073,7 @@ retry: if !ComparableTo(pkg, args[0], args[1]) { return nil, errors.New("mismatched types") } + cb.recordUpdateUntypedBinaryOp(op, args, nil) ret = &internal.Elem{ Val: &ast.BinaryExpr{ X: checkParenExpr(args[0].Val), Op: op, @@ -2169,6 +2189,9 @@ func (p *CodeBuilder) Send() *CodeBuilder { val := p.stk.Pop() ch := p.stk.Pop() // TODO: check types + if typ, ok := ch.Type.(*types.Chan); ok { + p.recordUpdateUntyped(val, typ.Elem()) + } p.emitStmt(&ast.SendStmt{Chan: ch.Val, Value: val.Val}) return p } @@ -2637,3 +2660,64 @@ func (p *CodeBuilder) InternalStack() *InternalStack { } // ---------------------------------------------------------------------------- + +func (p *CodeBuilder) recordUpdateUntyped(e *internal.Elem, param types.Type) { + if p.rec != nil && isBasicUntyped(e.Type) { + typ := param + retry: + switch t := typ.(type) { + case *unboundFuncParam: + typ = t.tBound + goto retry + case *refType: + typ = t.typ + goto retry + case *types.Basic: + if e.Type == t { + return + } + case *types.Named: + if t.Underlying() == TyEmptyInterface { + typ = types.Default(e.Type) + } + case *types.Interface: + if t == TyEmptyInterface { + typ = types.Default(e.Type) + } + } + p.rec.UpdateUntyped(e, typ) + } +} + +func (p *CodeBuilder) recordUpdateUntypedBinaryOp(tok token.Token, args []*internal.Elem, tyRet types.Type) { + if p.rec == nil { + return + } + kind := binaryOpKinds[tok] + if kind == binaryOpCompare { + b0, k0 := isBasicUntypedKind(args[0].Type) + b1, k1 := isBasicUntypedKind(args[1].Type) + if b0 && !b1 { + p.rec.UpdateUntyped(args[0], args[1].Type) + } else if !b0 && b1 { + p.rec.UpdateUntyped(args[1], args[0].Type) + } else if b0 && b1 && k0 != k1 { + // UntypedInt + // UntypedRune + // UntypedFloat + // UntypedComplex + if k0 < k1 { + p.rec.UpdateUntyped(args[0], args[1].Type) + } else { + p.rec.UpdateUntyped(args[1], args[0].Type) + } + } + } else if tyRet != nil { + if isBasicUntyped(args[0].Type) && tyRet != args[0].Type { + p.rec.UpdateUntyped(args[0], tyRet) + } + if kind != binaryOpShift && tyRet != args[1].Type { + p.rec.UpdateUntyped(args[1], tyRet) + } + } +} diff --git a/package.go b/package.go index 861e05b4..b88aba01 100644 --- a/package.go +++ b/package.go @@ -73,6 +73,7 @@ func fatal(msg string) { type Recorder interface { // Member maps identifiers to the objects they denote. Member(id ast.Node, obj types.Object) + UpdateUntyped(e *Element, typ types.Type) } // ---------------------------------------------------------------------------- diff --git a/package_test.go b/package_test.go index b10031d4..042ba543 100644 --- a/package_test.go +++ b/package_test.go @@ -52,6 +52,8 @@ type eventRecorder struct{} func (p eventRecorder) Member(id ast.Node, obj types.Object) { } +func (p eventRecorder) UpdateUntyped(e *gox.Element, typ types.Type) { +} func newMainPackage( implicitCast ...func(pkg *gox.Package, V, T types.Type, pv *gox.Element) bool) *gox.Package { @@ -3436,4 +3438,41 @@ func main() { `) } +func TestVarBinary(t *testing.T) { + pkg := newMainPackage() + pkg.CB().NewVarStart(types.Typ[types.Int], "a"). + Val(1). + EndInit(1). + NewVarStart(types.Typ[types.Int], "b"). + Val(4).Val(pkg.Ref("a")).BinaryOp(token.ADD). + EndInit(1). + NewVarStart(types.Typ[types.Bool], "c"). + Val(1).Val(1.0).BinaryOp(token.EQL).EndInit(1). + NewVarStart(types.Typ[types.Bool], "d"). + Val(1.0).Val(1).BinaryOp(token.NEQ).EndInit(1) + domTest(t, pkg, `package main + +var a int = 1 +var b int = 4 + a +var c bool = 1 == 1.0 +var d bool = 1.0 != 1 +`) +} + +func TestVarEmptyInterface(t *testing.T) { + pkg := newMainPackage() + named := pkg.NewType("T").InitType(pkg, gox.TyEmptyInterface) + pkg.CB().NewVarStart(named, "a").Val(1).EndInit(1). + NewVarStart(gox.TyEmptyInterface, "b").Val(2).EndInit(1) + domTest(t, pkg, `package main + +type T interface { +} + +var a T = 1 +var b interface { +} = 2 +`) +} + // ---------------------------------------------------------------------------- diff --git a/stmt.go b/stmt.go index 65272102..2b04dbee 100644 --- a/stmt.go +++ b/stmt.go @@ -167,11 +167,13 @@ func (p *switchStmt) Case(cb *CodeBuilder, n int, src ...ast.Node) { cb.panicCodeErrorf( pos, "cannot use %s (type %v) as type %v", src, arg.Type, types.Default(p.tag.Type)) } + cb.recordUpdateUntyped(arg, p.tag.Type) } else { // switch {...} if !types.AssignableTo(arg.Type, types.Typ[types.Bool]) && arg.Type != TyEmptyInterface { src, pos := cb.loadExpr(arg.Src) cb.panicCodeErrorf(pos, "cannot use %s (type %v) as type bool", src, arg.Type) } + cb.recordUpdateUntyped(arg, types.Typ[types.Bool]) } list[i] = arg.Val } diff --git a/type_var_and_const.go b/type_var_and_const.go index 86c01bae..1f803874 100644 --- a/type_var_and_const.go +++ b/type_var_and_const.go @@ -328,6 +328,7 @@ func (p *ValueDecl) endInit(cb *CodeBuilder, arity int) *ValueDecl { if err := matchType(pkg, ret, typ, "assignment"); err != nil { panic(err) } + cb.recordUpdateUntyped(ret, typ) if values != nil { // ret.Val may be changed values[i] = ret.Val } @@ -363,6 +364,7 @@ func (p *ValueDecl) endInit(cb *CodeBuilder, arity int) *ValueDecl { if values != nil { values[i] = parg.Val } + pkg.cb.recordUpdateUntyped(rets[i], retType) if old := p.scope.Insert(types.NewVar(p.pos, pkg.Types, name, retType)); old != nil { if p.tok != token.DEFINE { oldpos := cb.fset.Position(old.Pos())