diff options
Diffstat (limited to 'reflect/protoregistry/registry_test.go')
-rw-r--r-- | reflect/protoregistry/registry_test.go | 655 |
1 files changed, 655 insertions, 0 deletions
diff --git a/reflect/protoregistry/registry_test.go b/reflect/protoregistry/registry_test.go new file mode 100644 index 00000000..d058c196 --- /dev/null +++ b/reflect/protoregistry/registry_test.go @@ -0,0 +1,655 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package protoregistry_test + +import ( + "fmt" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + + "google.golang.org/protobuf/encoding/prototext" + pimpl "google.golang.org/protobuf/internal/impl" + pdesc "google.golang.org/protobuf/reflect/protodesc" + pref "google.golang.org/protobuf/reflect/protoreflect" + preg "google.golang.org/protobuf/reflect/protoregistry" + + testpb "google.golang.org/protobuf/internal/testprotos/registry" + "google.golang.org/protobuf/types/descriptorpb" +) + +func mustMakeFile(s string) pref.FileDescriptor { + pb := new(descriptorpb.FileDescriptorProto) + if err := prototext.Unmarshal([]byte(s), pb); err != nil { + panic(err) + } + fd, err := pdesc.NewFile(pb, nil) + if err != nil { + panic(err) + } + return fd +} + +func TestFiles(t *testing.T) { + type ( + file struct { + Path string + Pkg pref.FullName + } + testFile struct { + inFile pref.FileDescriptor + wantErr string + } + testFindDesc struct { + inName pref.FullName + wantFound bool + } + testRangePkg struct { + inPkg pref.FullName + wantFiles []file + } + testFindPath struct { + inPath string + wantFiles []file + wantErr string + } + ) + + tests := []struct { + files []testFile + findDescs []testFindDesc + rangePkgs []testRangePkg + findPaths []testFindPath + }{{ + // Test that overlapping packages and files are permitted. + files: []testFile{ + {inFile: mustMakeFile(`syntax:"proto2" name:"test1.proto" package:"foo.bar"`)}, + {inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/test.proto" package:"my.test"`)}, + {inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/test.proto" package:"foo.bar.baz"`), wantErr: "already registered"}, + {inFile: mustMakeFile(`syntax:"proto2" name:"test2.proto" package:"my.test.package"`)}, + {inFile: mustMakeFile(`syntax:"proto2" name:"weird" package:"foo.bar"`)}, + {inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/baz/../test.proto" package:"my.test"`)}, + }, + + rangePkgs: []testRangePkg{{ + inPkg: "nothing", + }, { + inPkg: "", + }, { + inPkg: ".", + }, { + inPkg: "foo", + }, { + inPkg: "foo.", + }, { + inPkg: "foo..", + }, { + inPkg: "foo.bar", + wantFiles: []file{ + {"test1.proto", "foo.bar"}, + {"weird", "foo.bar"}, + }, + }, { + inPkg: "my.test", + wantFiles: []file{ + {"foo/bar/baz/../test.proto", "my.test"}, + {"foo/bar/test.proto", "my.test"}, + }, + }, { + inPkg: "fo", + }}, + + findPaths: []testFindPath{{ + inPath: "nothing", + wantErr: "not found", + }, { + inPath: "weird", + wantFiles: []file{ + {"weird", "foo.bar"}, + }, + }, { + inPath: "foo/bar/test.proto", + wantFiles: []file{ + {"foo/bar/test.proto", "my.test"}, + }, + }}, + }, { + // Test when new enum conflicts with existing package. + files: []testFile{{ + inFile: mustMakeFile(`syntax:"proto2" name:"test1a.proto" package:"foo.bar.baz"`), + }, { + inFile: mustMakeFile(`syntax:"proto2" name:"test1b.proto" enum_type:[{name:"foo" value:[{name:"VALUE" number:0}]}]`), + wantErr: `file "test1b.proto" has a name conflict over foo`, + }}, + }, { + // Test when new package conflicts with existing enum. + files: []testFile{{ + inFile: mustMakeFile(`syntax:"proto2" name:"test2a.proto" enum_type:[{name:"foo" value:[{name:"VALUE" number:0}]}]`), + }, { + inFile: mustMakeFile(`syntax:"proto2" name:"test2b.proto" package:"foo.bar.baz"`), + wantErr: `file "test2b.proto" has a package name conflict over foo`, + }}, + }, { + // Test when new enum conflicts with existing enum in same package. + files: []testFile{{ + inFile: mustMakeFile(`syntax:"proto2" name:"test3a.proto" package:"foo" enum_type:[{name:"BAR" value:[{name:"VALUE" number:0}]}]`), + }, { + inFile: mustMakeFile(`syntax:"proto2" name:"test3b.proto" package:"foo" enum_type:[{name:"BAR" value:[{name:"VALUE2" number:0}]}]`), + wantErr: `file "test3b.proto" has a name conflict over foo.BAR`, + }}, + }, { + files: []testFile{{ + inFile: mustMakeFile(` + syntax: "proto2" + name: "test1.proto" + package: "fizz.buzz" + message_type: [{ + name: "Message" + field: [ + {name:"Field" number:1 label:LABEL_OPTIONAL type:TYPE_STRING oneof_index:0} + ] + oneof_decl: [{name:"Oneof"}] + extension_range: [{start:1000 end:2000}] + + enum_type: [ + {name:"Enum" value:[{name:"EnumValue" number:0}]} + ] + nested_type: [ + {name:"Message" field:[{name:"Field" number:1 label:LABEL_OPTIONAL type:TYPE_STRING}]} + ] + extension: [ + {name:"Extension" number:1001 label:LABEL_OPTIONAL type:TYPE_STRING extendee:".fizz.buzz.Message"} + ] + }] + enum_type: [{ + name: "Enum" + value: [{name:"EnumValue" number:0}] + }] + extension: [ + {name:"Extension" number:1000 label:LABEL_OPTIONAL type:TYPE_STRING extendee:".fizz.buzz.Message"} + ] + service: [{ + name: "Service" + method: [{ + name: "Method" + input_type: ".fizz.buzz.Message" + output_type: ".fizz.buzz.Message" + client_streaming: true + server_streaming: true + }] + }] + `), + }, { + inFile: mustMakeFile(` + syntax: "proto2" + name: "test2.proto" + package: "fizz.buzz.gazz" + enum_type: [{ + name: "Enum" + value: [{name:"EnumValue" number:0}] + }] + `), + }, { + inFile: mustMakeFile(` + syntax: "proto2" + name: "test3.proto" + package: "fizz.buzz" + enum_type: [{ + name: "Enum1" + value: [{name:"EnumValue1" number:0}] + }, { + name: "Enum2" + value: [{name:"EnumValue2" number:0}] + }] + `), + }, { + // Make sure we can register without package name. + inFile: mustMakeFile(` + name: "weird" + syntax: "proto2" + message_type: [{ + name: "Message" + nested_type: [{ + name: "Message" + nested_type: [{ + name: "Message" + }] + }] + }] + `), + }}, + findDescs: []testFindDesc{ + {inName: "fizz.buzz.message", wantFound: false}, + {inName: "fizz.buzz.Message", wantFound: true}, + {inName: "fizz.buzz.Message.X", wantFound: false}, + {inName: "fizz.buzz.Field", wantFound: false}, + {inName: "fizz.buzz.Oneof", wantFound: false}, + {inName: "fizz.buzz.Message.Field", wantFound: true}, + {inName: "fizz.buzz.Message.Field.X", wantFound: false}, + {inName: "fizz.buzz.Message.Oneof", wantFound: true}, + {inName: "fizz.buzz.Message.Oneof.X", wantFound: false}, + {inName: "fizz.buzz.Message.Message", wantFound: true}, + {inName: "fizz.buzz.Message.Message.X", wantFound: false}, + {inName: "fizz.buzz.Message.Enum", wantFound: true}, + {inName: "fizz.buzz.Message.Enum.X", wantFound: false}, + {inName: "fizz.buzz.Message.EnumValue", wantFound: true}, + {inName: "fizz.buzz.Message.EnumValue.X", wantFound: false}, + {inName: "fizz.buzz.Message.Extension", wantFound: true}, + {inName: "fizz.buzz.Message.Extension.X", wantFound: false}, + {inName: "fizz.buzz.enum", wantFound: false}, + {inName: "fizz.buzz.Enum", wantFound: true}, + {inName: "fizz.buzz.Enum.X", wantFound: false}, + {inName: "fizz.buzz.EnumValue", wantFound: true}, + {inName: "fizz.buzz.EnumValue.X", wantFound: false}, + {inName: "fizz.buzz.Enum.EnumValue", wantFound: false}, + {inName: "fizz.buzz.Extension", wantFound: true}, + {inName: "fizz.buzz.Extension.X", wantFound: false}, + {inName: "fizz.buzz.service", wantFound: false}, + {inName: "fizz.buzz.Service", wantFound: true}, + {inName: "fizz.buzz.Service.X", wantFound: false}, + {inName: "fizz.buzz.Method", wantFound: false}, + {inName: "fizz.buzz.Service.Method", wantFound: true}, + {inName: "fizz.buzz.Service.Method.X", wantFound: false}, + + {inName: "fizz.buzz.gazz", wantFound: false}, + {inName: "fizz.buzz.gazz.Enum", wantFound: true}, + {inName: "fizz.buzz.gazz.EnumValue", wantFound: true}, + {inName: "fizz.buzz.gazz.Enum.EnumValue", wantFound: false}, + + {inName: "fizz.buzz", wantFound: false}, + {inName: "fizz.buzz.Enum1", wantFound: true}, + {inName: "fizz.buzz.EnumValue1", wantFound: true}, + {inName: "fizz.buzz.Enum1.EnumValue1", wantFound: false}, + {inName: "fizz.buzz.Enum2", wantFound: true}, + {inName: "fizz.buzz.EnumValue2", wantFound: true}, + {inName: "fizz.buzz.Enum2.EnumValue2", wantFound: false}, + {inName: "fizz.buzz.Enum3", wantFound: false}, + + {inName: "", wantFound: false}, + {inName: "Message", wantFound: true}, + {inName: "Message.Message", wantFound: true}, + {inName: "Message.Message.Message", wantFound: true}, + {inName: "Message.Message.Message.Message", wantFound: false}, + }, + }} + + sortFiles := cmpopts.SortSlices(func(x, y file) bool { + return x.Path < y.Path || (x.Path == y.Path && x.Pkg < y.Pkg) + }) + for _, tt := range tests { + t.Run("", func(t *testing.T) { + var files preg.Files + for i, tc := range tt.files { + gotErr := files.RegisterFile(tc.inFile) + if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) { + t.Errorf("file %d, Register() = %v, want %v", i, gotErr, tc.wantErr) + } + } + + for _, tc := range tt.findDescs { + d, _ := files.FindDescriptorByName(tc.inName) + gotFound := d != nil + if gotFound != tc.wantFound { + t.Errorf("FindDescriptorByName(%v) find mismatch: got %v, want %v", tc.inName, gotFound, tc.wantFound) + } + } + + for _, tc := range tt.rangePkgs { + var gotFiles []file + var gotCnt int + wantCnt := files.NumFilesByPackage(tc.inPkg) + files.RangeFilesByPackage(tc.inPkg, func(fd pref.FileDescriptor) bool { + gotFiles = append(gotFiles, file{fd.Path(), fd.Package()}) + gotCnt++ + return true + }) + if gotCnt != wantCnt { + t.Errorf("NumFilesByPackage(%v) = %v, want %v", tc.inPkg, gotCnt, wantCnt) + } + if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" { + t.Errorf("RangeFilesByPackage(%v) mismatch (-want +got):\n%v", tc.inPkg, diff) + } + } + + for _, tc := range tt.findPaths { + var gotFiles []file + fd, gotErr := files.FindFileByPath(tc.inPath) + if gotErr == nil { + gotFiles = append(gotFiles, file{fd.Path(), fd.Package()}) + } + if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) { + t.Errorf("FindFileByPath(%v) = %v, want %v", tc.inPath, gotErr, tc.wantErr) + } + if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" { + t.Errorf("FindFileByPath(%v) mismatch (-want +got):\n%v", tc.inPath, diff) + } + } + }) + } +} + +func TestTypes(t *testing.T) { + mt1 := pimpl.Export{}.MessageTypeOf(&testpb.Message1{}) + et1 := pimpl.Export{}.EnumTypeOf(testpb.Enum1_ONE) + xt1 := testpb.E_StringField + xt2 := testpb.E_Message4_MessageField + registry := new(preg.Types) + if err := registry.RegisterMessage(mt1); err != nil { + t.Fatalf("registry.RegisterMessage(%v) returns unexpected error: %v", mt1.Descriptor().FullName(), err) + } + if err := registry.RegisterEnum(et1); err != nil { + t.Fatalf("registry.RegisterEnum(%v) returns unexpected error: %v", et1.Descriptor().FullName(), err) + } + if err := registry.RegisterExtension(xt1); err != nil { + t.Fatalf("registry.RegisterExtension(%v) returns unexpected error: %v", xt1.TypeDescriptor().FullName(), err) + } + if err := registry.RegisterExtension(xt2); err != nil { + t.Fatalf("registry.RegisterExtension(%v) returns unexpected error: %v", xt2.TypeDescriptor().FullName(), err) + } + + t.Run("FindMessageByName", func(t *testing.T) { + tests := []struct { + name string + messageType pref.MessageType + wantErr bool + wantNotFound bool + }{{ + name: "testprotos.Message1", + messageType: mt1, + }, { + name: "testprotos.NoSuchMessage", + wantErr: true, + wantNotFound: true, + }, { + name: "testprotos.Enum1", + wantErr: true, + }, { + name: "testprotos.Enum2", + wantErr: true, + }, { + name: "testprotos.Enum3", + wantErr: true, + }} + for _, tc := range tests { + got, err := registry.FindMessageByName(pref.FullName(tc.name)) + gotErr := err != nil + if gotErr != tc.wantErr { + t.Errorf("FindMessageByName(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr) + continue + } + if tc.wantNotFound && err != preg.NotFound { + t.Errorf("FindMessageByName(%v) got error: %v, want NotFound error", tc.name, err) + continue + } + if got != tc.messageType { + t.Errorf("FindMessageByName(%v) got wrong value: %v", tc.name, got) + } + } + }) + + t.Run("FindMessageByURL", func(t *testing.T) { + tests := []struct { + name string + messageType pref.MessageType + wantErr bool + wantNotFound bool + }{{ + name: "testprotos.Message1", + messageType: mt1, + }, { + name: "type.googleapis.com/testprotos.Nada", + wantErr: true, + wantNotFound: true, + }, { + name: "testprotos.Enum1", + wantErr: true, + }} + for _, tc := range tests { + got, err := registry.FindMessageByURL(tc.name) + gotErr := err != nil + if gotErr != tc.wantErr { + t.Errorf("FindMessageByURL(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr) + continue + } + if tc.wantNotFound && err != preg.NotFound { + t.Errorf("FindMessageByURL(%v) got error: %v, want NotFound error", tc.name, err) + continue + } + if got != tc.messageType { + t.Errorf("FindMessageByURL(%v) got wrong value: %v", tc.name, got) + } + } + }) + + t.Run("FindEnumByName", func(t *testing.T) { + tests := []struct { + name string + enumType pref.EnumType + wantErr bool + wantNotFound bool + }{{ + name: "testprotos.Enum1", + enumType: et1, + }, { + name: "testprotos.None", + wantErr: true, + wantNotFound: true, + }, { + name: "testprotos.Message1", + wantErr: true, + }} + for _, tc := range tests { + got, err := registry.FindEnumByName(pref.FullName(tc.name)) + gotErr := err != nil + if gotErr != tc.wantErr { + t.Errorf("FindEnumByName(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr) + continue + } + if tc.wantNotFound && err != preg.NotFound { + t.Errorf("FindEnumByName(%v) got error: %v, want NotFound error", tc.name, err) + continue + } + if got != tc.enumType { + t.Errorf("FindEnumByName(%v) got wrong value: %v", tc.name, got) + } + } + }) + + t.Run("FindExtensionByName", func(t *testing.T) { + tests := []struct { + name string + extensionType pref.ExtensionType + wantErr bool + wantNotFound bool + }{{ + name: "testprotos.string_field", + extensionType: xt1, + }, { + name: "testprotos.Message4.message_field", + extensionType: xt2, + }, { + name: "testprotos.None", + wantErr: true, + wantNotFound: true, + }, { + name: "testprotos.Message1", + wantErr: true, + }} + for _, tc := range tests { + got, err := registry.FindExtensionByName(pref.FullName(tc.name)) + gotErr := err != nil + if gotErr != tc.wantErr { + t.Errorf("FindExtensionByName(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr) + continue + } + if tc.wantNotFound && err != preg.NotFound { + t.Errorf("FindExtensionByName(%v) got error: %v, want NotFound error", tc.name, err) + continue + } + if got != tc.extensionType { + t.Errorf("FindExtensionByName(%v) got wrong value: %v", tc.name, got) + } + } + }) + + t.Run("FindExtensionByNumber", func(t *testing.T) { + tests := []struct { + parent string + number int32 + extensionType pref.ExtensionType + wantErr bool + wantNotFound bool + }{{ + parent: "testprotos.Message1", + number: 11, + extensionType: xt1, + }, { + parent: "testprotos.Message1", + number: 13, + wantErr: true, + wantNotFound: true, + }, { + parent: "testprotos.Message1", + number: 21, + extensionType: xt2, + }, { + parent: "testprotos.Message1", + number: 23, + wantErr: true, + wantNotFound: true, + }, { + parent: "testprotos.NoSuchMessage", + number: 11, + wantErr: true, + wantNotFound: true, + }, { + parent: "testprotos.Message1", + number: 30, + wantErr: true, + wantNotFound: true, + }, { + parent: "testprotos.Message1", + number: 99, + wantErr: true, + wantNotFound: true, + }} + for _, tc := range tests { + got, err := registry.FindExtensionByNumber(pref.FullName(tc.parent), pref.FieldNumber(tc.number)) + gotErr := err != nil + if gotErr != tc.wantErr { + t.Errorf("FindExtensionByNumber(%v, %d) = (_, %v), want error? %t", tc.parent, tc.number, err, tc.wantErr) + continue + } + if tc.wantNotFound && err != preg.NotFound { + t.Errorf("FindExtensionByNumber(%v, %d) got error %v, want NotFound error", tc.parent, tc.number, err) + continue + } + if got != tc.extensionType { + t.Errorf("FindExtensionByNumber(%v, %d) got wrong value: %v", tc.parent, tc.number, got) + } + } + }) + + sortTypes := cmp.Options{ + cmpopts.SortSlices(func(x, y pref.EnumType) bool { + return x.Descriptor().FullName() < y.Descriptor().FullName() + }), + cmpopts.SortSlices(func(x, y pref.MessageType) bool { + return x.Descriptor().FullName() < y.Descriptor().FullName() + }), + cmpopts.SortSlices(func(x, y pref.ExtensionType) bool { + return x.TypeDescriptor().FullName() < y.TypeDescriptor().FullName() + }), + } + compare := cmp.Options{ + cmp.Comparer(func(x, y pref.EnumType) bool { + return x == y + }), + cmp.Comparer(func(x, y pref.ExtensionType) bool { + return x == y + }), + cmp.Comparer(func(x, y pref.MessageType) bool { + return x == y + }), + } + + t.Run("RangeEnums", func(t *testing.T) { + want := []pref.EnumType{et1} + var got []pref.EnumType + var gotCnt int + wantCnt := registry.NumEnums() + registry.RangeEnums(func(et pref.EnumType) bool { + got = append(got, et) + gotCnt++ + return true + }) + + if gotCnt != wantCnt { + t.Errorf("NumEnums() = %v, want %v", gotCnt, wantCnt) + } + if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" { + t.Errorf("RangeEnums() mismatch (-want +got):\n%v", diff) + } + }) + + t.Run("RangeMessages", func(t *testing.T) { + want := []pref.MessageType{mt1} + var got []pref.MessageType + var gotCnt int + wantCnt := registry.NumMessages() + registry.RangeMessages(func(mt pref.MessageType) bool { + got = append(got, mt) + gotCnt++ + return true + }) + + if gotCnt != wantCnt { + t.Errorf("NumMessages() = %v, want %v", gotCnt, wantCnt) + } + if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" { + t.Errorf("RangeMessages() mismatch (-want +got):\n%v", diff) + } + }) + + t.Run("RangeExtensions", func(t *testing.T) { + want := []pref.ExtensionType{xt1, xt2} + var got []pref.ExtensionType + var gotCnt int + wantCnt := registry.NumExtensions() + registry.RangeExtensions(func(xt pref.ExtensionType) bool { + got = append(got, xt) + gotCnt++ + return true + }) + + if gotCnt != wantCnt { + t.Errorf("NumExtensions() = %v, want %v", gotCnt, wantCnt) + } + if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" { + t.Errorf("RangeExtensions() mismatch (-want +got):\n%v", diff) + } + }) + + t.Run("RangeExtensionsByMessage", func(t *testing.T) { + want := []pref.ExtensionType{xt1, xt2} + var got []pref.ExtensionType + var gotCnt int + wantCnt := registry.NumExtensionsByMessage("testprotos.Message1") + registry.RangeExtensionsByMessage("testprotos.Message1", func(xt pref.ExtensionType) bool { + got = append(got, xt) + gotCnt++ + return true + }) + + if gotCnt != wantCnt { + t.Errorf("NumExtensionsByMessage() = %v, want %v", gotCnt, wantCnt) + } + if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" { + t.Errorf("RangeExtensionsByMessage() mismatch (-want +got):\n%v", diff) + } + }) +} |