diff options
Diffstat (limited to 'compiler/protogen/protogen.go')
-rw-r--r-- | compiler/protogen/protogen.go | 100 |
1 files changed, 98 insertions, 2 deletions
diff --git a/compiler/protogen/protogen.go b/compiler/protogen/protogen.go index 2ee676fb..2d2171e5 100644 --- a/compiler/protogen/protogen.go +++ b/compiler/protogen/protogen.go @@ -36,10 +36,11 @@ import ( "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/dynamicpb" "google.golang.org/protobuf/types/pluginpb" ) -const goPackageDocURL = "https://developers.google.com/protocol-buffers/docs/reference/go-generated#package" +const goPackageDocURL = "https://protobuf.dev/reference/go/go-generated#package" // Run executes a function as a protoc plugin. // @@ -209,6 +210,7 @@ func (opts Options) New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) { } } } + // When the module= option is provided, we strip the module name // prefix from generated files. This only makes sense if generated // filenames are based on the import path. @@ -298,6 +300,8 @@ func (opts Options) New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) { } } + // The extracted types from the full import set + typeRegistry := newExtensionRegistry() for _, fdesc := range gen.Request.ProtoFile { filename := fdesc.GetName() if gen.FilesByPath[filename] != nil { @@ -309,6 +313,9 @@ func (opts Options) New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) { } gen.Files = append(gen.Files, f) gen.FilesByPath[filename] = f + if err = typeRegistry.registerAllExtensionsFromFile(f.Desc); err != nil { + return nil, err + } } for _, filename := range gen.Request.FileToGenerate { f, ok := gen.FilesByPath[filename] @@ -317,6 +324,20 @@ func (opts Options) New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) { } f.Generate = true } + + // Create fully-linked descriptors if new extensions were found + if typeRegistry.hasNovelExtensions() { + for _, f := range gen.Files { + b, err := proto.Marshal(f.Proto.ProtoReflect().Interface()) + if err != nil { + return nil, err + } + err = proto.UnmarshalOptions{Resolver: typeRegistry}.Unmarshal(b, f.Proto) + if err != nil { + return nil, err + } + } + } return gen, nil } @@ -472,7 +493,7 @@ func newFile(gen *Plugin, p *descriptorpb.FileDescriptorProto, packageName GoPac } // splitImportPathAndPackageName splits off the optional Go package name -// from the Go import path when seperated by a ';' delimiter. +// from the Go import path when separated by a ';' delimiter. func splitImportPathAndPackageName(s string) (GoImportPath, GoPackageName) { if i := strings.Index(s, ";"); i >= 0 { return GoImportPath(s[:i]), GoPackageName(s[i+1:]) @@ -1259,3 +1280,78 @@ func (c Comments) String() string { } return string(b) } + +// extensionRegistry allows registration of new extensions defined in the .proto +// file for which we are generating bindings. +// +// Lookups consult the local type registry first and fall back to the base type +// registry which defaults to protoregistry.GlobalTypes +type extensionRegistry struct { + base *protoregistry.Types + local *protoregistry.Types +} + +func newExtensionRegistry() *extensionRegistry { + return &extensionRegistry{ + base: protoregistry.GlobalTypes, + local: &protoregistry.Types{}, + } +} + +// FindExtensionByName implements proto.UnmarshalOptions.FindExtensionByName +func (e *extensionRegistry) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) { + if xt, err := e.local.FindExtensionByName(field); err == nil { + return xt, nil + } + + return e.base.FindExtensionByName(field) +} + +// FindExtensionByNumber implements proto.UnmarshalOptions.FindExtensionByNumber +func (e *extensionRegistry) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) { + if xt, err := e.local.FindExtensionByNumber(message, field); err == nil { + return xt, nil + } + + return e.base.FindExtensionByNumber(message, field) +} + +func (e *extensionRegistry) hasNovelExtensions() bool { + return e.local.NumExtensions() > 0 +} + +func (e *extensionRegistry) registerAllExtensionsFromFile(f protoreflect.FileDescriptor) error { + if err := e.registerAllExtensions(f.Extensions()); err != nil { + return err + } + return nil +} + +func (e *extensionRegistry) registerAllExtensionsFromMessage(ms protoreflect.MessageDescriptors) error { + for i := 0; i < ms.Len(); i++ { + m := ms.Get(i) + if err := e.registerAllExtensions(m.Extensions()); err != nil { + return err + } + } + return nil +} + +func (e *extensionRegistry) registerAllExtensions(exts protoreflect.ExtensionDescriptors) error { + for i := 0; i < exts.Len(); i++ { + if err := e.registerExtension(exts.Get(i)); err != nil { + return err + } + } + return nil +} + +// registerExtension adds the given extension to the type registry if an +// extension with that full name does not exist yet. +func (e *extensionRegistry) registerExtension(xd protoreflect.ExtensionDescriptor) error { + if _, err := e.FindExtensionByName(xd.FullName()); err != protoregistry.NotFound { + // Either the extension already exists or there was an error, either way we're done. + return err + } + return e.local.RegisterExtension(dynamicpb.NewExtensionType(xd)) +} |