Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 68 additions & 36 deletions hook/assume_return.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,27 @@ import (
"go.uber.org/nilaway/annotation"
"go.uber.org/nilaway/util"
"go.uber.org/nilaway/util/analysishelper"
"go.uber.org/nilaway/util/typeshelper"
)

// AssumeReturn yields a producer carrying the assumed nilability of the value returned by the
// given call expression. It exists to model stdlib and 3rd party functions that NilAway does not
// analyze itself; "errors.New", for instance, is assumed to return a nonnil value. Trusted
// functions are consulted first and the error-wrapper heuristic second. A nil result means the
// call expression matched no known function.
func AssumeReturn(pass *analysishelper.EnhancedPass, call *ast.CallExpr) *annotation.ProduceTrigger {
	trusted := matchTrustedFuncs(pass, call)
	if trusted == nil {
		// Not a trusted function: fall back to the error-wrapper heuristic.
		return AssumeReturnForErrorWrapperFunc(pass, call)
	}
	return trusted
}

// matchTrustedFuncs looks the call expression up in the _assumeReturns table and, when a
// signature matches, returns the producer built by the associated action. It returns nil when no
// entry matches.
//
// The error-wrapper heuristic is intentionally NOT applied here: AssumeReturn performs that
// fallback itself, and isErrorWrapperFunc relies on a nil result to tell trusted calls apart from
// everything else.
func matchTrustedFuncs(pass *analysishelper.EnhancedPass, call *ast.CallExpr) *annotation.ProduceTrigger {
	// Map iteration order is unspecified, so the first matching entry encountered wins.
	for sig, act := range _assumeReturns {
		if sig.match(pass, call) {
			return act(call)
		}
	}

	return nil
}

// AssumeReturnForErrorWrapperFunc returns the producer for the return value of the given call expression which is
Expand All @@ -48,6 +55,8 @@ func AssumeReturnForErrorWrapperFunc(pass *analysishelper.EnhancedPass, call *as
return nil
}

// _newErrorFuncNameRegex matches, case-insensitively, any name that contains "new" followed later
// by "error" with no space in between (e.g., `NewInternalError`). The pattern is unanchored, so the
// match may occur anywhere within the name.
var _newErrorFuncNameRegex = regexp.MustCompile(`(?i)new[^ ]*error[^ ]*`)

// isErrorWrapperFunc implements a heuristic to identify error wrapper functions (e.g., `errors.Wrapf(err, "message")`).
// It does this by applying the following criteria:
// - the function must have at least one argument of error-implementing type, and
Expand All @@ -58,53 +67,64 @@ func isErrorWrapperFunc(pass *analysishelper.EnhancedPass, call *ast.CallExpr) b
return false
}

obj := pass.TypesInfo.ObjectOf(funcIdent)
if obj == nil {
// Return early if the function object is nil or does not return an error.
var funcObj *types.Func
if obj := pass.TypesInfo.ObjectOf(funcIdent); obj != nil {
if fObj, ok := obj.(*types.Func); ok {
if util.FuncIsErrReturning(typeshelper.GetFuncSignature(fObj.Signature())) {
funcObj = fObj
}
}
}
if funcObj == nil {
return false
}

// If the call expr is built-in `new`, then we check if its argument type implements the error interface.
// This case particularly gets triggered for the expression: `Wrap(new(MyErrorStruct), "message")`.
if obj == util.BuiltinNew {
if argIdent := util.IdentOf(call.Args[0]); argIdent != nil {
ptr := types.NewPointer(pass.TypesInfo.TypeOf(argIdent))
if types.Implements(ptr, util.ErrorInterface) {
// Check if the function is an error wrapper: consumes an error and returns an error.
for _, arg := range call.Args {
// Check if the argument is a call expression.
if callExpr, ok := arg.(*ast.CallExpr); ok {
if matchTrustedFuncs(pass, callExpr) != nil {
// Check if the argument is a trusted error returning function call.
// Example: `wrapError(errors.New("new error"))`
return true
}
}

return false
}

funcObj, ok := obj.(*types.Func)
if !ok {
return false
}
if util.FuncIsErrReturning(funcObj.Signature()) {
args := call.Args

// If the function is a method, we need to check if the receiver is an error-implementing type.
// This is to cover the case where some error wrappers facilitate a chaining functionality, i.e., the receiver
// is an error-implementing type (e.g., Wrap().WithOtherFields()). By adding the receiver to the argument list,
// we can check if it is an error-implementing type and support this case.
if funcObj.Type().(*types.Signature).Recv() != nil {
args = append(args, call.Fun)
}
for _, arg := range args {
if callExpr, ok := ast.Unparen(arg).(*ast.CallExpr); ok {
if isErrorWrapperFunc(pass, callExpr) {
// This is to cover the case `NewInternalError(err.Error())` where the argument is a method call on an error.
// We want to extract the raw error argument `err` in this case.
if s, ok := callExpr.Fun.(*ast.SelectorExpr); ok {
t := pass.TypesInfo.TypeOf(s.X)
if t != nil && util.ImplementsError(t) {
return true
}
}

if argIdent := util.IdentOf(arg); argIdent != nil {
argObj := pass.TypesInfo.ObjectOf(argIdent)
if util.ImplementsError(argObj) {
return true
}
// Recursively check if the argument is an error wrapper function call.
// Example: `wrapError(wrapError(wrapError(err)))`
if isErrorWrapperFunc(pass, callExpr) {
return true
}
}

argType := pass.TypesInfo.TypeOf(arg)
if argType != nil && util.ImplementsError(argType) {
// Return the raw error argument expression
return true
}
}

// Check if the function is creating a new error:
// - consumes a message string and returns an error
// - matches regex "new*error" as its function name (e.g., `NewInternalError()`)
if _newErrorFuncNameRegex.MatchString(funcObj.Name()) {
for i := 0; i < funcObj.Signature().Params().Len(); i++ {
param := funcObj.Signature().Params().At(i)
if t, ok := param.Type().(*types.Basic); ok && t.Kind() == types.String && i < len(call.Args) {
return true
}
}
}

return false
}

Expand Down Expand Up @@ -136,6 +156,18 @@ var _assumeReturns = map[trustedFuncSig]assumeReturnAction{
enclosingRegex: regexp.MustCompile(`^(stubs/)?github\.com/pkg/errors$`),
funcNameRegex: regexp.MustCompile(`^New$`),
}: nonnilProducer,

// `errors.Join`
// Note that `errors.Join` can return nil if all arguments are nil [1]. However, in practice this rarely
// happens, so we assume it returns a non-nil error for simplicity. Here we are making a conscious trade-off
// between soundness and practicality.
//
// [1] https://pkg.go.dev/errors#Join
{
kind: _func,
enclosingRegex: regexp.MustCompile(`^errors$`),
funcNameRegex: regexp.MustCompile(`^Join$`),
}: nonnilProducer,
}

var nonnilProducer assumeReturnAction = func(call *ast.CallExpr) *annotation.ProduceTrigger {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@ package inference

import (
"errors"
"fmt"

"go.uber.org/errorreturn"
"go.uber.org/errorreturn/inference/otherPkg"
)

var dummy2 bool

// myErr2 is a minimal error implementation used by the test cases in this file. The msg field
// holds the text supplied by constructors such as NewError; Error does not report it.
type myErr2 struct {
	msg string
}

// Error implements the error interface and always returns the same fixed message.
func (myErr2) Error() string { return "myErr2 message" }

Expand All @@ -39,6 +42,14 @@ func retNonNilErr2() error {
return &myErr2{}
}

// NewError returns a non-nil error: a pointer to a myErr2 whose msg field is set to msg.
func NewError(msg string) error {
return &myErr2{msg: msg}
}

// NewInternalError returns a non-nil error: a pointer to a myErr2 whose msg field is set to msg.
// Its body is the same as NewError's; only the function name differs.
func NewInternalError(msg string) error {
return &myErr2{msg: msg}
}

// ***** the below test case checks error return via a function and assigned to a variable *****
func retPtrAndErr2(i int) (*int, error) {
if dummy2 {
Expand Down Expand Up @@ -439,21 +450,18 @@ func Wrap(err error, msg string) WrappedErr {

// GetFirstErr returns the first of the given errors (which may itself be nil). When called with
// no arguments it returns a descriptive non-nil error rather than nil.
func GetFirstErr(errs ...error) error {
	if len(errs) == 0 {
		return fmt.Errorf("GetFirstErr called with no errors")
	}
	return errs[0]
}

// GetFirstErrArr returns the first element of errs, which is nil when that slot holds no error.
// No separate "all nil" check is needed: in that case errs[0] is already nil.
func GetFirstErrArr(errs [2]error) error {
	return errs[0]
}

// GetErrPtr returns the error that e points to (which may itself be nil). When e is a nil
// pointer it returns a descriptive non-nil error rather than nil.
func GetErrPtr(e *error) error {
	if e == nil {
		return fmt.Errorf("GetErrPtr called with nil error")
	}
	return *e
}
Expand Down Expand Up @@ -554,7 +562,7 @@ func callTestErrorWrapper(i int) {
if err != nil {
return
}
_ = *x //want "dereferenced"
_ = *x

case 3:
x, err := testErrorWrapper3()
Expand Down Expand Up @@ -1069,3 +1077,94 @@ func TestGenericFunc(s string) {
}
}
}

// retPtrErrForNewInternalErrorLitString forwards the pointer obtained from retPtrAndErr3 when that
// call reports no error. Otherwise it returns a nil pointer together with an error built by
// NewInternalError from a string literal.
func retPtrErrForNewInternalErrorLitString() (*int, error) {
resp, err := retPtrAndErr3()
if err != nil {
return nil, NewInternalError("some other error")
}
return resp, nil
}

// retPtrErrForNewInternalErrorCallExpr forwards the pointer obtained from retPtrAndErr3 when that
// call reports no error. Otherwise it returns a nil pointer together with an error built by
// NewInternalError from the message of the original error (err.Error()).
func retPtrErrForNewInternalErrorCallExpr() (*int, error) {
resp, err := retPtrAndErr3()
if err != nil {
return nil, NewInternalError(err.Error())
}
return resp, nil
}

// retPtrErrForNewErrorLitString forwards the pointer obtained from retPtrAndErr3 when that call
// reports no error. Otherwise it returns a nil pointer together with an error built by NewError
// from a string literal.
func retPtrErrForNewErrorLitString() (*int, error) {
resp, err := retPtrAndErr3()
if err != nil {
return nil, NewError("some other error")
}
return resp, nil
}

// retPtrErrWrappedNewErrorUnsafe returns resp alongside a NewError wrapped in three nested Wrapf
// calls. It calls err.Error() without first checking that err is non-nil; the `//want` comment on
// the return line records the diagnostic text expected at that position (presumably consumed by
// the analyzer's test harness — confirm against the test driver).
func retPtrErrWrappedNewErrorUnsafe() (*int, error) {
resp, err := retPtrAndErr3()
return resp, Wrapf(Wrapf(Wrapf(NewError(err.Error())))) //want "called `Error"
}

// retPtrErrWrappedNewErrorSafe mirrors retPtrErrWrappedNewErrorUnsafe, but only calls err.Error()
// inside the `err != nil` branch. On error it returns resp (not nil) together with a NewError
// wrapped in three nested Wrapf calls; otherwise it returns resp with a nil error.
func retPtrErrWrappedNewErrorSafe() (*int, error) {
resp, err := retPtrAndErr3()
if err != nil {
return resp, Wrapf(Wrapf(Wrapf(NewError(err.Error()))))
}
return resp, nil
}

// retPtrErrUnwrap returns resp with a nil error when retPtrAndErr3 reports no error. Otherwise it
// returns resp together with a NewError built from the message of errors.Unwrap(err).
func retPtrErrUnwrap() (*int, error) {
resp, err := retPtrAndErr3()
if err != nil {
return resp, NewError(errors.Unwrap(err).Error())
}
return resp, nil
}

// TestNewError exercises the error-constructing helpers above. Each case, selected by s, calls one
// helper, returns early when the helper reports an error, and then dereferences the returned
// pointer. None of the dereferences carries a `//want` comment, i.e. no diagnostic is expected at
// those positions.
func TestNewError(s string) {
switch s {
case "NewInternalError with literal string message":
resp, err := retPtrErrForNewInternalErrorLitString()
if err != nil {
return
}
_ = *resp

case "NewInternalError with err.Error()":
resp, err := retPtrErrForNewInternalErrorCallExpr()
if err != nil {
return
}
_ = *resp

case "NewError with literal string message":
resp, err := retPtrErrForNewErrorLitString()
if err != nil {
return
}
_ = *resp

case "wrapped error with NewError, unsafe":
resp, err := retPtrErrWrappedNewErrorUnsafe()
if err != nil {
return
}
_ = *resp

case "wrapped error with NewError, safe":
resp, err := retPtrErrWrappedNewErrorSafe()
if err != nil {
return
}
_ = *resp

case "errors.Unwrap with NewError":
resp, err := retPtrErrUnwrap()
if err != nil {
return
}
_ = *resp
}
}
24 changes: 4 additions & 20 deletions util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,26 +320,10 @@ func IsEmptyExpr(expr ast.Expr) bool {

// ImplementsError checks if the given object implements the error interface. It also covers the case of
// interfaces that embed the error interface.
func ImplementsError(obj types.Object) bool {
if ErrorInterface == nil || obj == nil {
func ImplementsError(t types.Type) bool {
if t == nil {
return false
}

underlyingType := func(t types.Type) types.Type {
switch t := t.(type) {
case *types.Pointer:
return UnwrapPtr(t)
case *types.Slice:
return t.Elem().Underlying()
case *types.Array:
return t.Elem().Underlying()
default:
return t
}
}

t := underlyingType(obj.Type())

return types.Implements(t, ErrorInterface)
}

Expand All @@ -359,12 +343,12 @@ func FuncIsErrReturning(sig *types.Signature) bool {
}

errRes := results.At(n - 1)
if !ImplementsError(errRes) {
if !ImplementsError(errRes.Type()) {
return false
}

for i := 0; i < n-1; i++ {
if ImplementsError(results.At(i)) {
if ImplementsError(results.At(i).Type()) {
return false
}
}
Expand Down
Loading