aboutsummaryrefslogtreecommitdiff
path: root/go/analysis/passes/sigchanyzer/sigchanyzer.go
blob: 0d6c8ebf1655f467d9e635c5d4259e8ae1ac042d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package sigchanyzer defines an Analyzer that detects misuse of an
// unbuffered os.Signal channel as the argument to signal.Notify.
package sigchanyzer

import (
	"bytes"
	"go/ast"
	"go/format"
	"go/token"
	"go/types"

	"golang.org/x/tools/go/analysis"
	"golang.org/x/tools/go/analysis/passes/inspect"
	"golang.org/x/tools/go/ast/inspector"
)

// Doc is the documentation for the sigchanyzer analyzer. The text before the
// first blank line is the analyzer's short title; the remainder describes the
// calls it reports.
const Doc = `check for unbuffered channel of os.Signal

This checker reports call expression of the form signal.Notify(c chan<- os.Signal, sig ...os.Signal),
where c is an unbuffered channel, which can be at risk of missing the signal.`

// Analyzer is the sigchanyzer analysis pass. It relies on the inspect
// analyzer for AST traversal and reports signal.Notify calls whose channel
// argument appears to be unbuffered.
var Analyzer = &analysis.Analyzer{
	Name:     "sigchanyzer",
	Doc:      Doc,
	Run:      run,
	Requires: []*analysis.Analyzer{inspect.Analyzer},
}

// run reports calls to signal.Notify whose channel argument appears to be
// unbuffered. The argument is considered suspicious when it is either a
// variable initialized by a one-argument call (such as make(chan os.Signal)),
// or an inline one-argument call other than the builtin make. When the
// channel comes from the builtin make, the diagnostic carries a suggested fix
// that adds a buffer of size 1.
func run(pass *analysis.Pass) (interface{}, error) {
	insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

	nodeFilter := []ast.Node{
		(*ast.CallExpr)(nil),
	}
	insp.Preorder(nodeFilter, func(n ast.Node) {
		call := n.(*ast.CallExpr)
		// Notify always has a channel as its first argument; the length check
		// only guards against malformed input.
		if !isSignalNotify(pass.TypesInfo, call) || len(call.Args) == 0 {
			return
		}
		var chanDecl *ast.CallExpr
		switch arg := call.Args[0].(type) {
		case *ast.Ident:
			// The channel is a variable: inspect its initializer, if any.
			if decl, ok := findDecl(arg).(*ast.CallExpr); ok {
				chanDecl = decl
			}
		case *ast.CallExpr:
			// Only signal.Notify(make(chan os.Signal), os.Interrupt) is safe;
			// conservatively treat other calls as not safe. See golang/go#45043.
			if isBuiltinMake(pass.TypesInfo, arg) {
				return
			}
			chanDecl = arg
		}
		// make(chan T) has one argument; make(chan T, n) has two and is
		// already buffered.
		if chanDecl == nil || len(chanDecl.Args) != 1 {
			return
		}

		diag := analysis.Diagnostic{
			Pos:     call.Pos(),
			End:     call.End(),
			Message: "misuse of unbuffered os.Signal channel as argument to signal.Notify",
		}
		// Appending a size argument is only meaningful for the builtin make.
		// For any other callee (a user-defined constructor, a conversion, a
		// shadowed make) it would produce wrong or non-compiling code, so the
		// diagnostic is reported without a fix in that case.
		if isBuiltinMake(pass.TypesInfo, chanDecl) {
			if fix, ok := bufferedMakeFix(chanDecl); ok {
				diag.SuggestedFixes = []analysis.SuggestedFix{fix}
			}
		}
		pass.Report(diag)
	})
	return nil, nil
}

// bufferedMakeFix returns a fix that replaces the one-argument make call decl
// with the same call plus a buffer size of 1. It reports false if the
// rewritten call cannot be formatted.
func bufferedMakeFix(decl *ast.CallExpr) (analysis.SuggestedFix, bool) {
	// Build the rewritten call on a shallow copy with a fresh Args slice so
	// the original AST is not mutated. See https://golang.org/issue/46129.
	withSize := *decl
	withSize.Args = append(append([]ast.Expr(nil), decl.Args...), &ast.BasicLit{
		Kind:  token.INT,
		Value: "1",
	})

	var buf bytes.Buffer
	if err := format.Node(&buf, token.NewFileSet(), &withSize); err != nil {
		return analysis.SuggestedFix{}, false
	}
	return analysis.SuggestedFix{
		Message: "Change to buffer channel",
		TextEdits: []analysis.TextEdit{{
			Pos:     decl.Pos(),
			End:     decl.End(),
			NewText: buf.Bytes(),
		}},
	}, true
}

// isSignalNotify reports whether call invokes os/signal.Notify. It recognizes
// a qualified call (signal.Notify(...)), an unqualified call under a dot
// import (Notify(...)), and a call through a variable whose declaration
// initializes it from a selector, e.g.
//
//	notify := signal.Notify
//	notify(c, os.Interrupt)
func isSignalNotify(info *types.Info, call *ast.CallExpr) bool {
	isNotify := func(id *ast.Ident) bool {
		obj := info.ObjectOf(id)
		// ObjectOf returns nil for identifiers without type information, and
		// Pkg is nil for objects in the universe scope; guard both.
		return obj != nil && obj.Name() == "Notify" &&
			obj.Pkg() != nil && obj.Pkg().Path() == "os/signal"
	}
	switch fun := call.Fun.(type) {
	case *ast.SelectorExpr:
		return isNotify(fun.Sel)
	case *ast.Ident:
		// Under `import . "os/signal"` the identifier itself is Notify.
		if isNotify(fun) {
			return true
		}
		// Otherwise fun may be a variable initialized from a selector.
		if sel, ok := findDecl(fun).(*ast.SelectorExpr); ok {
			return isNotify(sel.Sel)
		}
		return false
	default:
		return false
	}
}

// findDecl returns the initializer expression bound to arg at its
// declaration, using the syntactic object resolution stored in arg.Obj.
// Two declaration forms are understood, and only when the number of
// declared names equals the number of values:
//
//	x := expr      (*ast.AssignStmt)
//	var x = expr   (*ast.ValueSpec)
//
// It returns nil when arg is unresolved, when it was declared some other way,
// or when the counts differ (for example a, b := f()).
func findDecl(arg *ast.Ident) ast.Node {
	obj := arg.Obj
	if obj == nil {
		return nil
	}

	var targets, values []ast.Expr
	switch decl := obj.Decl.(type) {
	case *ast.AssignStmt:
		targets, values = decl.Lhs, decl.Rhs
	case *ast.ValueSpec:
		targets = make([]ast.Expr, len(decl.Names))
		for i, name := range decl.Names {
			targets[i] = name
		}
		values = decl.Values
	default:
		return nil
	}

	if len(targets) != len(values) {
		return nil
	}
	for i, target := range targets {
		// Non-identifier targets (e.g. s.f = ...) can never match.
		if id, isIdent := target.(*ast.Ident); isIdent && id.Obj == obj {
			return values[i]
		}
	}
	return nil
}

// isBuiltinMake reports whether call is a call of the predeclared make
// function. The callee must be a bare identifier that the type checker
// recorded as a builtin, and that builtin must be named "make". A
// user-declared function named make does not qualify, and neither does a
// parenthesized callee such as (make)(...).
func isBuiltinMake(info *types.Info, call *ast.CallExpr) bool {
	if tv := info.Types[call.Fun]; !tv.IsBuiltin() {
		return false
	}
	id, isIdent := call.Fun.(*ast.Ident)
	return isIdent && info.ObjectOf(id).Name() == "make"
}