aboutsummaryrefslogtreecommitdiff
path: root/cmp/cmpopts/equate.go
blob: 3d8d0cd3ae37b7b5db6755a0f247fd9435dcb294 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
// Copyright 2017, The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package cmpopts provides common options for the cmp package.
package cmpopts

import (
	"errors"
	"fmt"
	"math"
	"reflect"
	"time"

	"github.com/google/go-cmp/cmp"
)

// equateAlways is a comparer that ignores both operands and reports them
// as equal. It is meant to be paired with a filter that has already decided
// the two values should match.
func equateAlways(_, _ interface{}) bool {
	return true
}

// EquateEmpty returns a [cmp.Comparer] option under which any two maps or
// slices of zero length compare equal, whether or not either one is nil.
//
// EquateEmpty can be used in conjunction with [SortSlices] and [SortMaps].
func EquateEmpty() cmp.Option {
	// Once isEmpty has accepted a pair there is nothing left to compare.
	alwaysEqual := cmp.Comparer(equateAlways)
	return cmp.FilterValues(isEmpty, alwaysEqual)
}

// isEmpty reports whether x and y are non-nil interface values holding the
// same slice or map type and both have a length of zero.
func isEmpty(x, y interface{}) bool {
	// A nil interface carries no type, so it can never qualify.
	if x == nil || y == nil {
		return false
	}
	vx, vy := reflect.ValueOf(x), reflect.ValueOf(y)
	if vx.Type() != vy.Type() {
		return false
	}
	// The types are identical, so inspecting the kind of one side suffices.
	switch vx.Kind() {
	case reflect.Slice, reflect.Map:
		return vx.Len() == 0 && vy.Len() == 0
	default:
		return false
	}
}

// EquateApprox returns a [cmp.Comparer] option that treats two float32 or
// float64 values as equal when they differ by no more than a relative
// fraction or an absolute margin.
// The option does not apply when either x or y is NaN or infinite.
//
// The fraction bounds the difference relative to the smaller magnitude of the
// two values, while the margin bounds the difference by a fixed amount.
// Pass 0 for one parameter to rely solely on the other.
// Both the fraction and the margin must be non-negative.
//
// The mathematical expression used is equivalent to:
//
//	|x-y| ≤ max(fraction*min(|x|, |y|), margin)
//
// EquateApprox can be used in conjunction with [EquateNaNs].
func EquateApprox(fraction, margin float64) cmp.Option {
	// A tolerance is unusable if it is negative or not a number.
	invalid := func(v float64) bool { return v < 0 || math.IsNaN(v) }
	if invalid(margin) || invalid(fraction) {
		panic("margin or fraction must be a non-negative number")
	}

	tol := approximator{frac: fraction, marg: margin}
	f64 := cmp.FilterValues(areRealF64s, cmp.Comparer(tol.compareF64))
	f32 := cmp.FilterValues(areRealF32s, cmp.Comparer(tol.compareF32))
	return cmp.Options{f64, f32}
}

// approximator holds the tolerances applied by [EquateApprox]: frac is the
// relative fraction and marg is the absolute margin.
type approximator struct{ frac, marg float64 }

// areRealF64s reports whether neither x nor y is NaN or infinite.
func areRealF64s(x, y float64) bool {
	return !math.IsNaN(x) && !math.IsNaN(y) && !math.IsInf(x, 0) && !math.IsInf(y, 0)
}

// areRealF32s is the float32 counterpart of areRealF64s.
func areRealF32s(x, y float32) bool {
	return areRealF64s(float64(x), float64(y))
}

// compareF64 reports whether x and y satisfy
//
//	|x-y| ≤ max(frac*min(|x|, |y|), marg)
//
// A relative margin that evaluates to NaN is treated as zero so that the
// absolute margin alone decides the outcome.
func (a approximator) compareF64(x, y float64) bool {
	relMarg := a.frac * math.Min(math.Abs(x), math.Abs(y))
	if math.IsNaN(relMarg) {
		// An infinite fraction multiplied by a zero magnitude yields NaN,
		// and math.Max propagates NaN, which would make every comparison
		// fail, including 0 against 0. Any fraction of zero is zero.
		relMarg = 0
	}
	return math.Abs(x-y) <= math.Max(a.marg, relMarg)
}

// compareF32 widens both operands to float64 and defers to compareF64.
func (a approximator) compareF32(x, y float32) bool {
	return a.compareF64(float64(x), float64(y))
}

// EquateNaNs returns a [cmp.Comparer] option under which two NaN values of
// type float32 or float64 compare equal to each other.
//
// EquateNaNs can be used in conjunction with [EquateApprox].
func EquateNaNs() cmp.Option {
	// One filter per float width; each only fires when both sides are NaN.
	f64 := cmp.FilterValues(areNaNsF64s, cmp.Comparer(equateAlways))
	f32 := cmp.FilterValues(areNaNsF32s, cmp.Comparer(equateAlways))
	return cmp.Options{f64, f32}
}

// areNaNsF64s reports whether both x and y are NaN.
func areNaNsF64s(x, y float64) bool {
	if !math.IsNaN(x) {
		return false
	}
	return math.IsNaN(y)
}

// areNaNsF32s is the float32 counterpart of areNaNsF64s.
func areNaNsF32s(x, y float32) bool {
	wideX, wideY := float64(x), float64(y)
	return areNaNsF64s(wideX, wideY)
}

// EquateApproxTime returns a [cmp.Comparer] option that treats two non-zero
// [time.Time] values as equal when they lie within margin of each other.
// When both times carry a monotonic clock reading, the monotonic difference
// is the one that is used. The margin must be non-negative.
func EquateApproxTime(margin time.Duration) cmp.Option {
	if margin < 0 {
		panic("margin must be a non-negative number")
	}
	withinMargin := timeApproximator{margin: margin}.compare
	return cmp.FilterValues(areNonZeroTimes, cmp.Comparer(withinMargin))
}

// areNonZeroTimes reports whether neither x nor y is the zero time.
func areNonZeroTimes(x, y time.Time) bool {
	anyZero := x.IsZero() || y.IsZero()
	return !anyZero
}

// timeApproximator carries the tolerance used when comparing two times.
type timeApproximator struct {
	margin time.Duration
}

// compare reports whether x and y are no further apart than a.margin.
func (a timeApproximator) compare(x, y time.Time) bool {
	// Order the operands so that earlier never follows later.
	earlier, later := x, y
	if earlier.After(later) {
		earlier, later = later, earlier
	}
	// Rather than subtracting the times, which can overflow when the gap
	// exceeds the largest representable duration, shift the earlier time
	// forward by the margin and check that it reaches the later one.
	limit := earlier.Add(a.margin)
	// time.Time has no "at or after" method, so negate Before instead.
	return !limit.Before(later)
}

// AnyError is an error that matches any non-nil error.
var AnyError = anyError{}

// anyError is the concrete type behind [AnyError].
type anyError struct{}

// Error returns a fixed description of the wildcard error.
func (anyError) Error() string {
	return "any error"
}

// Is reports whether err is non-nil. [errors.Is] consults this method when
// an anyError appears in the chain being searched.
func (anyError) Is(err error) bool {
	return err != nil
}

// EquateErrors returns a [cmp.Comparer] option under which two errors are
// equal when [errors.Is] reports that they match. Use [AnyError] to match
// any non-nil error.
func EquateErrors() cmp.Option {
	matchByIs := cmp.Comparer(compareErrors)
	return cmp.FilterValues(areConcreteErrors, matchByIs)
}

// areConcreteErrors reports whether the dynamic types of x and y both
// implement error.
// The parameters are typed as interface{} rather than error on purpose: that
// way the filter also applies when the value being compared is statically an
// interface{} whose concrete contents happen to satisfy the error interface.
func areConcreteErrors(x, y interface{}) bool {
	if _, isErr := x.(error); !isErr {
		return false
	}
	_, isErr := y.(error)
	return isErr
}

// compareErrors reports whether either error matches the other according to
// errors.Is. Both arguments must hold values implementing error; the type
// assertions panic otherwise.
func compareErrors(x, y interface{}) bool {
	first, second := x.(error), y.(error)
	// errors.Is is directional, so try both orientations.
	if errors.Is(first, second) {
		return true
	}
	return errors.Is(second, first)
}

// EquateComparable returns a [cmp.Option] that determines equality
// of comparable types by directly comparing them using the == operator in Go.
// The types to compare are specified by passing a value of that type.
// This option should only be used on types that are documented as being
// safe for direct == comparison. For example, [net/netip.Addr] is documented
// as being semantically safe to use with ==, while [time.Time] is documented
// to discourage the use of == on time values.
//
// EquateComparable panics if an argument is a nil interface value, if its
// type is not comparable, or if the same type is specified more than once.
func EquateComparable(typs ...interface{}) cmp.Option {
	types := make(typesFilter, len(typs))
	for _, typ := range typs {
		t := reflect.TypeOf(typ)
		switch {
		case t == nil:
			// reflect.TypeOf yields a nil Type for a nil interface value.
			// Fail with a clear message instead of a nil dereference when
			// a method is called on that nil Type.
			panic("nil interface value does not specify a Go type")
		case !t.Comparable():
			panic(fmt.Sprintf("%T is not a comparable Go type", typ))
		case types[t]:
			panic(fmt.Sprintf("%T is already specified", typ))
		default:
			types[t] = true
		}
	}
	return cmp.FilterPath(types.filter, cmp.Comparer(equateAny))
}

// typesFilter is the set of types for which a path filter should fire.
type typesFilter map[reflect.Type]bool

// filter reports whether the type of the last step in p belongs to the set.
func (tf typesFilter) filter(p cmp.Path) bool {
	lastStep := p.Last()
	return tf[lastStep.Type()]
}

// equateAny compares the two interface values with the == operator.
func equateAny(x, y interface{}) bool {
	return x == y
}