diff options
Diffstat (limited to 'testing/protocmp/util.go')
-rw-r--r-- | testing/protocmp/util.go | 684 |
1 files changed, 684 insertions, 0 deletions
diff --git a/testing/protocmp/util.go b/testing/protocmp/util.go new file mode 100644 index 00000000..79f3072b --- /dev/null +++ b/testing/protocmp/util.go @@ -0,0 +1,684 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package protocmp + +import ( + "bytes" + "fmt" + "math" + "reflect" + "strings" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" +) + +var ( + enumReflectType = reflect.TypeOf(Enum{}) + messageReflectType = reflect.TypeOf(Message{}) +) + +// FilterEnum filters opt to only be applicable on standalone Enums, +// singular fields of enums, list fields of enums, or map fields of enum values, +// where the enum is the same type as the specified enum. +// +// The Go type of the last path step may be an: +// • Enum for singular fields, elements of a repeated field, +// values of a map field, or standalone Enums +// • []Enum for list fields +// • map[K]Enum for map fields +// • interface{} for a Message map entry value +// +// This must be used in conjunction with Transform. +func FilterEnum(enum protoreflect.Enum, opt cmp.Option) cmp.Option { + return FilterDescriptor(enum.Descriptor(), opt) +} + +// FilterMessage filters opt to only be applicable on standalone Messages, +// singular fields of messages, list fields of messages, or map fields of +// message values, where the message is the same type as the specified message. +// +// The Go type of the last path step may be an: +// • Message for singular fields, elements of a repeated field, +// values of a map field, or standalone Messages +// • []Message for list fields +// • map[K]Message for map fields +// • interface{} for a Message map entry value +// +// This must be used in conjunction with Transform. +func FilterMessage(message proto.Message, opt cmp.Option) cmp.Option { + return FilterDescriptor(message.ProtoReflect().Descriptor(), opt) +} + +// FilterField filters opt to only be applicable on the specified field +// in the message. It panics if a field of the given name does not exist. +// +// The Go type of the last path step may be an: +// • T for singular fields +// • []T for list fields +// • map[K]T for map fields +// • interface{} for a Message map entry value +// +// This must be used in conjunction with Transform. +func FilterField(message proto.Message, name protoreflect.Name, opt cmp.Option) cmp.Option { + md := message.ProtoReflect().Descriptor() + return FilterDescriptor(mustFindFieldDescriptor(md, name), opt) +} + +// FilterOneof filters opt to only be applicable on all fields within the +// specified oneof in the message. It panics if a oneof of the given name +// does not exist. +// +// The Go type of the last path step may be an: +// • T for singular fields +// • []T for list fields +// • map[K]T for map fields +// • interface{} for a Message map entry value +// +// This must be used in conjunction with Transform. +func FilterOneof(message proto.Message, name protoreflect.Name, opt cmp.Option) cmp.Option { + md := message.ProtoReflect().Descriptor() + return FilterDescriptor(mustFindOneofDescriptor(md, name), opt) +} + +// FilterDescriptor ignores the specified descriptor. +// +// The following descriptor types may be specified: +// • protoreflect.EnumDescriptor +// • protoreflect.MessageDescriptor +// • protoreflect.FieldDescriptor +// • protoreflect.OneofDescriptor +// +// For the behavior of each, see the corresponding filter function. +// Since this filter accepts a protoreflect.FieldDescriptor, it can be used +// to also filter for extension fields as a protoreflect.ExtensionDescriptor +// is just an alias to protoreflect.FieldDescriptor. +// +// This must be used in conjunction with Transform. +func FilterDescriptor(desc protoreflect.Descriptor, opt cmp.Option) cmp.Option { + f := newNameFilters(desc) + return cmp.FilterPath(f.Filter, opt) +} + +// IgnoreEnums ignores all enums of the specified types. +// It is equivalent to FilterEnum(enum, cmp.Ignore()) for each enum. +// +// This must be used in conjunction with Transform. +func IgnoreEnums(enums ...protoreflect.Enum) cmp.Option { + var ds []protoreflect.Descriptor + for _, e := range enums { + ds = append(ds, e.Descriptor()) + } + return IgnoreDescriptors(ds...) +} + +// IgnoreMessages ignores all messages of the specified types. +// It is equivalent to FilterMessage(message, cmp.Ignore()) for each message. +// +// This must be used in conjunction with Transform. +func IgnoreMessages(messages ...proto.Message) cmp.Option { + var ds []protoreflect.Descriptor + for _, m := range messages { + ds = append(ds, m.ProtoReflect().Descriptor()) + } + return IgnoreDescriptors(ds...) +} + +// IgnoreFields ignores the specified fields in the specified message. +// It is equivalent to FilterField(message, name, cmp.Ignore()) for each field +// in the message. +// +// This must be used in conjunction with Transform. +func IgnoreFields(message proto.Message, names ...protoreflect.Name) cmp.Option { + var ds []protoreflect.Descriptor + md := message.ProtoReflect().Descriptor() + for _, s := range names { + ds = append(ds, mustFindFieldDescriptor(md, s)) + } + return IgnoreDescriptors(ds...) +} + +// IgnoreOneofs ignores fields of the specified oneofs in the specified message. +// It is equivalent to FilterOneof(message, name, cmp.Ignore()) for each oneof +// in the message. +// +// This must be used in conjunction with Transform. +func IgnoreOneofs(message proto.Message, names ...protoreflect.Name) cmp.Option { + var ds []protoreflect.Descriptor + md := message.ProtoReflect().Descriptor() + for _, s := range names { + ds = append(ds, mustFindOneofDescriptor(md, s)) + } + return IgnoreDescriptors(ds...) +} + +// IgnoreDescriptors ignores the specified set of descriptors. +// It is equivalent to FilterDescriptor(desc, cmp.Ignore()) for each descriptor. +// +// This must be used in conjunction with Transform. +func IgnoreDescriptors(descs ...protoreflect.Descriptor) cmp.Option { + return cmp.FilterPath(newNameFilters(descs...).Filter, cmp.Ignore()) +} + +func mustFindFieldDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.FieldDescriptor { + d := findDescriptor(md, s) + if fd, ok := d.(protoreflect.FieldDescriptor); ok && fd.TextName() == string(s) { + return fd + } + + var suggestion string + switch d := d.(type) { + case protoreflect.FieldDescriptor: + suggestion = fmt.Sprintf("; consider specifying field %q instead", d.TextName()) + case protoreflect.OneofDescriptor: + suggestion = fmt.Sprintf("; consider specifying oneof %q with IgnoreOneofs instead", d.Name()) + } + panic(fmt.Sprintf("message %q has no field %q%s", md.FullName(), s, suggestion)) +} + +func mustFindOneofDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.OneofDescriptor { + d := findDescriptor(md, s) + if od, ok := d.(protoreflect.OneofDescriptor); ok && d.Name() == s { + return od + } + + var suggestion string + switch d := d.(type) { + case protoreflect.OneofDescriptor: + suggestion = fmt.Sprintf("; consider specifying oneof %q instead", d.Name()) + case protoreflect.FieldDescriptor: + suggestion = fmt.Sprintf("; consider specifying field %q with IgnoreFields instead", d.TextName()) + } + panic(fmt.Sprintf("message %q has no oneof %q%s", md.FullName(), s, suggestion)) +} + +func findDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.Descriptor { + // Exact match. + if fd := md.Fields().ByTextName(string(s)); fd != nil { + return fd + } + if od := md.Oneofs().ByName(s); od != nil && !od.IsSynthetic() { + return od + } + + // Best-effort match. + // + // It's a common user mistake to use the CamelCased field name as it appears + // in the generated Go struct. Instead of complaining that it doesn't exist, + // suggest the real protobuf name that the user may have desired. + normalize := func(s protoreflect.Name) string { + return strings.Replace(strings.ToLower(string(s)), "_", "", -1) + } + for i := 0; i < md.Fields().Len(); i++ { + if fd := md.Fields().Get(i); normalize(fd.Name()) == normalize(s) { + return fd + } + } + for i := 0; i < md.Oneofs().Len(); i++ { + if od := md.Oneofs().Get(i); normalize(od.Name()) == normalize(s) { + return od + } + } + return nil +} + +type nameFilters struct { + names map[protoreflect.FullName]bool +} + +func newNameFilters(descs ...protoreflect.Descriptor) *nameFilters { + f := &nameFilters{names: make(map[protoreflect.FullName]bool)} + for _, d := range descs { + switch d := d.(type) { + case protoreflect.EnumDescriptor: + f.names[d.FullName()] = true + case protoreflect.MessageDescriptor: + f.names[d.FullName()] = true + case protoreflect.FieldDescriptor: + f.names[d.FullName()] = true + case protoreflect.OneofDescriptor: + for i := 0; i < d.Fields().Len(); i++ { + f.names[d.Fields().Get(i).FullName()] = true + } + default: + panic("invalid descriptor type") + } + } + return f +} + +func (f *nameFilters) Filter(p cmp.Path) bool { + vx, vy := p.Last().Values() + return (f.filterValue(vx) && f.filterValue(vy)) || f.filterFields(p) +} + +func (f *nameFilters) filterFields(p cmp.Path) bool { + // Trim off trailing type-assertions so that the filter can match on the + // concrete value held within an interface value. + if _, ok := p.Last().(cmp.TypeAssertion); ok { + p = p[:len(p)-1] + } + + // Filter for Message maps. + mi, ok := p.Index(-1).(cmp.MapIndex) + if !ok { + return false + } + ps := p.Index(-2) + if ps.Type() != messageReflectType { + return false + } + + // Check field name. + vx, vy := ps.Values() + mx := vx.Interface().(Message) + my := vy.Interface().(Message) + k := mi.Key().String() + if f.filterFieldName(mx, k) && f.filterFieldName(my, k) { + return true + } + + // Check field value. + vx, vy = mi.Values() + if f.filterFieldValue(vx) && f.filterFieldValue(vy) { + return true + } + + return false +} + +func (f *nameFilters) filterFieldName(m Message, k string) bool { + if _, ok := m[k]; !ok { + return true // treat missing fields as already filtered + } + var fd protoreflect.FieldDescriptor + switch mm := m[messageTypeKey].(messageMeta); { + case protoreflect.Name(k).IsValid(): + fd = mm.md.Fields().ByTextName(k) + default: + fd = mm.xds[k] + } + if fd != nil { + return f.names[fd.FullName()] + } + return false +} + +func (f *nameFilters) filterFieldValue(v reflect.Value) bool { + if !v.IsValid() { + return true // implies missing slice element or map entry + } + v = v.Elem() // map entries are always populated values + switch t := v.Type(); { + case t == enumReflectType || t == messageReflectType: + // Check for singular message or enum field. + return f.filterValue(v) + case t.Kind() == reflect.Slice && (t.Elem() == enumReflectType || t.Elem() == messageReflectType): + // Check for list field of enum or message type. + return f.filterValue(v.Index(0)) + case t.Kind() == reflect.Map && (t.Elem() == enumReflectType || t.Elem() == messageReflectType): + // Check for map field of enum or message type. + return f.filterValue(v.MapIndex(v.MapKeys()[0])) + } + return false +} + +func (f *nameFilters) filterValue(v reflect.Value) bool { + if !v.IsValid() { + return true // implies missing slice element or map entry + } + if !v.CanInterface() { + return false // implies unexported struct field + } + switch v := v.Interface().(type) { + case Enum: + return v.Descriptor() != nil && f.names[v.Descriptor().FullName()] + case Message: + return v.Descriptor() != nil && f.names[v.Descriptor().FullName()] + } + return false +} + +// IgnoreDefaultScalars ignores singular scalars that are unpopulated or +// explicitly set to the default value. +// This option does not effect elements in a list or entries in a map. +// +// This must be used in conjunction with Transform. +func IgnoreDefaultScalars() cmp.Option { + return cmp.FilterPath(func(p cmp.Path) bool { + // Filter for Message maps. + mi, ok := p.Index(-1).(cmp.MapIndex) + if !ok { + return false + } + ps := p.Index(-2) + if ps.Type() != messageReflectType { + return false + } + + // Check whether both fields are default or unpopulated scalars. + vx, vy := ps.Values() + mx := vx.Interface().(Message) + my := vy.Interface().(Message) + k := mi.Key().String() + return isDefaultScalar(mx, k) && isDefaultScalar(my, k) + }, cmp.Ignore()) +} + +func isDefaultScalar(m Message, k string) bool { + if _, ok := m[k]; !ok { + return true + } + + var fd protoreflect.FieldDescriptor + switch mm := m[messageTypeKey].(messageMeta); { + case protoreflect.Name(k).IsValid(): + fd = mm.md.Fields().ByTextName(k) + default: + fd = mm.xds[k] + } + if fd == nil || !fd.Default().IsValid() { + return false + } + switch fd.Kind() { + case protoreflect.BytesKind: + v, ok := m[k].([]byte) + return ok && bytes.Equal(fd.Default().Bytes(), v) + case protoreflect.FloatKind: + v, ok := m[k].(float32) + return ok && equalFloat64(fd.Default().Float(), float64(v)) + case protoreflect.DoubleKind: + v, ok := m[k].(float64) + return ok && equalFloat64(fd.Default().Float(), float64(v)) + case protoreflect.EnumKind: + v, ok := m[k].(Enum) + return ok && fd.Default().Enum() == v.Number() + default: + return reflect.DeepEqual(fd.Default().Interface(), m[k]) + } +} + +func equalFloat64(x, y float64) bool { + return x == y || (math.IsNaN(x) && math.IsNaN(y)) +} + +// IgnoreEmptyMessages ignores messages that are empty or unpopulated. +// It applies to standalone Messages, singular message fields, +// list fields of messages, and map fields of message values. +// +// This must be used in conjunction with Transform. +func IgnoreEmptyMessages() cmp.Option { + return cmp.FilterPath(func(p cmp.Path) bool { + vx, vy := p.Last().Values() + return (isEmptyMessage(vx) && isEmptyMessage(vy)) || isEmptyMessageFields(p) + }, cmp.Ignore()) +} + +func isEmptyMessageFields(p cmp.Path) bool { + // Filter for Message maps. + mi, ok := p.Index(-1).(cmp.MapIndex) + if !ok { + return false + } + ps := p.Index(-2) + if ps.Type() != messageReflectType { + return false + } + + // Check field value. + vx, vy := mi.Values() + if isEmptyMessageFieldValue(vx) && isEmptyMessageFieldValue(vy) { + return true + } + + return false +} + +func isEmptyMessageFieldValue(v reflect.Value) bool { + if !v.IsValid() { + return true // implies missing slice element or map entry + } + v = v.Elem() // map entries are always populated values + switch t := v.Type(); { + case t == messageReflectType: + // Check singular field for empty message. + if !isEmptyMessage(v) { + return false + } + case t.Kind() == reflect.Slice && t.Elem() == messageReflectType: + // Check list field for all empty message elements. + for i := 0; i < v.Len(); i++ { + if !isEmptyMessage(v.Index(i)) { + return false + } + } + case t.Kind() == reflect.Map && t.Elem() == messageReflectType: + // Check map field for all empty message values. + for _, k := range v.MapKeys() { + if !isEmptyMessage(v.MapIndex(k)) { + return false + } + } + default: + return false + } + return true +} + +func isEmptyMessage(v reflect.Value) bool { + if !v.IsValid() { + return true // implies missing slice element or map entry + } + if !v.CanInterface() { + return false // implies unexported struct field + } + if m, ok := v.Interface().(Message); ok { + for k := range m { + if k != messageTypeKey && k != messageInvalidKey { + return false + } + } + return true + } + return false +} + +// IgnoreUnknown ignores unknown fields in all messages. +// +// This must be used in conjunction with Transform. +func IgnoreUnknown() cmp.Option { + return cmp.FilterPath(func(p cmp.Path) bool { + // Filter for Message maps. + mi, ok := p.Index(-1).(cmp.MapIndex) + if !ok { + return false + } + ps := p.Index(-2) + if ps.Type() != messageReflectType { + return false + } + + // Filter for unknown fields (which always have a numeric map key). + return strings.Trim(mi.Key().String(), "0123456789") == "" + }, cmp.Ignore()) +} + +// SortRepeated sorts repeated fields of the specified element type. +// The less function must be of the form "func(T, T) bool" where T is the +// Go element type for the repeated field kind. +// +// The element type T can be one of the following: +// • Go type for a protobuf scalar kind except for an enum +// (i.e., bool, int32, int64, uint32, uint64, float32, float64, string, and []byte) +// • E where E is a concrete enum type that implements protoreflect.Enum +// • M where M is a concrete message type that implement proto.Message +// +// This option only applies to repeated fields within a protobuf message. +// It does not operate on higher-order Go types that seem like a repeated field. +// For example, a []T outside the context of a protobuf message will not be +// handled by this option. To sort Go slices that are not repeated fields, +// consider using "github.com/google/go-cmp/cmp/cmpopts".SortSlices instead. +// +// This must be used in conjunction with Transform. +func SortRepeated(lessFunc interface{}) cmp.Option { + t, ok := checkTTBFunc(lessFunc) + if !ok { + panic(fmt.Sprintf("invalid less function: %T", lessFunc)) + } + + var opt cmp.Option + var sliceType reflect.Type + switch vf := reflect.ValueOf(lessFunc); { + case t.Implements(enumV2Type): + et := reflect.Zero(t).Interface().(protoreflect.Enum).Type() + lessFunc = func(x, y Enum) bool { + vx := reflect.ValueOf(et.New(x.Number())) + vy := reflect.ValueOf(et.New(y.Number())) + return vf.Call([]reflect.Value{vx, vy})[0].Bool() + } + opt = FilterDescriptor(et.Descriptor(), cmpopts.SortSlices(lessFunc)) + sliceType = reflect.SliceOf(enumReflectType) + case t.Implements(messageV2Type): + mt := reflect.Zero(t).Interface().(protoreflect.ProtoMessage).ProtoReflect().Type() + lessFunc = func(x, y Message) bool { + mx := mt.New().Interface() + my := mt.New().Interface() + proto.Merge(mx, x) + proto.Merge(my, y) + vx := reflect.ValueOf(mx) + vy := reflect.ValueOf(my) + return vf.Call([]reflect.Value{vx, vy})[0].Bool() + } + opt = FilterDescriptor(mt.Descriptor(), cmpopts.SortSlices(lessFunc)) + sliceType = reflect.SliceOf(messageReflectType) + default: + switch t { + case reflect.TypeOf(bool(false)): + case reflect.TypeOf(int32(0)): + case reflect.TypeOf(int64(0)): + case reflect.TypeOf(uint32(0)): + case reflect.TypeOf(uint64(0)): + case reflect.TypeOf(float32(0)): + case reflect.TypeOf(float64(0)): + case reflect.TypeOf(string("")): + case reflect.TypeOf([]byte(nil)): + default: + panic(fmt.Sprintf("invalid element type: %v", t)) + } + opt = cmpopts.SortSlices(lessFunc) + sliceType = reflect.SliceOf(t) + } + + return cmp.FilterPath(func(p cmp.Path) bool { + // Filter to only apply to repeated fields within a message. + if t := p.Index(-1).Type(); t == nil || t != sliceType { + return false + } + if t := p.Index(-2).Type(); t == nil || t.Kind() != reflect.Interface { + return false + } + if t := p.Index(-3).Type(); t == nil || t != messageReflectType { + return false + } + return true + }, opt) +} + +func checkTTBFunc(lessFunc interface{}) (reflect.Type, bool) { + switch t := reflect.TypeOf(lessFunc); { + case t == nil: + return nil, false + case t.NumIn() != 2 || t.In(0) != t.In(1) || t.IsVariadic(): + return nil, false + case t.NumOut() != 1 || t.Out(0) != reflect.TypeOf(false): + return nil, false + default: + return t.In(0), true + } +} + +// SortRepeatedFields sorts the specified repeated fields. +// Sorting a repeated field is useful for treating the list as a multiset +// (i.e., a set where each value can appear multiple times). +// It panics if the field does not exist or is not a repeated field. +// +// The sort ordering is as follows: +// • Booleans are sorted where false is sorted before true. +// • Integers are sorted in ascending order. +// • Floating-point numbers are sorted in ascending order according to +// the total ordering defined by IEEE-754 (section 5.10). +// • Strings and bytes are sorted lexicographically in ascending order. +// • Enums are sorted in ascending order based on its numeric value. +// • Messages are sorted according to some arbitrary ordering +// which is undefined and may change in future implementations. +// +// The ordering chosen for repeated messages is unlikely to be aesthetically +// preferred by humans. Consider using a custom sort function: +// +// FilterField(m, "foo_field", SortRepeated(func(x, y *foopb.MyMessage) bool { +// ... // user-provided definition for less +// })) +// +// This must be used in conjunction with Transform. +func SortRepeatedFields(message proto.Message, names ...protoreflect.Name) cmp.Option { + var opts cmp.Options + md := message.ProtoReflect().Descriptor() + for _, name := range names { + fd := mustFindFieldDescriptor(md, name) + if !fd.IsList() { + panic(fmt.Sprintf("message field %q is not repeated", fd.FullName())) + } + + var lessFunc interface{} + switch fd.Kind() { + case protoreflect.BoolKind: + lessFunc = func(x, y bool) bool { return !x && y } + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: + lessFunc = func(x, y int32) bool { return x < y } + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + lessFunc = func(x, y int64) bool { return x < y } + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: + lessFunc = func(x, y uint32) bool { return x < y } + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + lessFunc = func(x, y uint64) bool { return x < y } + case protoreflect.FloatKind: + lessFunc = lessF32 + case protoreflect.DoubleKind: + lessFunc = lessF64 + case protoreflect.StringKind: + lessFunc = func(x, y string) bool { return x < y } + case protoreflect.BytesKind: + lessFunc = func(x, y []byte) bool { return bytes.Compare(x, y) < 0 } + case protoreflect.EnumKind: + lessFunc = func(x, y Enum) bool { return x.Number() < y.Number() } + case protoreflect.MessageKind, protoreflect.GroupKind: + lessFunc = func(x, y Message) bool { return x.String() < y.String() } + default: + panic(fmt.Sprintf("invalid kind: %v", fd.Kind())) + } + opts = append(opts, FilterDescriptor(fd, cmpopts.SortSlices(lessFunc))) + } + return opts +} + +func lessF32(x, y float32) bool { + // Bit-wise implementation of IEEE-754, section 5.10. + xi := int32(math.Float32bits(x)) + yi := int32(math.Float32bits(y)) + xi ^= int32(uint32(xi>>31) >> 1) + yi ^= int32(uint32(yi>>31) >> 1) + return xi < yi +} +func lessF64(x, y float64) bool { + // Bit-wise implementation of IEEE-754, section 5.10. + xi := int64(math.Float64bits(x)) + yi := int64(math.Float64bits(y)) + xi ^= int64(uint64(xi>>63) >> 1) + yi ^= int64(uint64(yi>>63) >> 1) + return xi < yi +} |