aboutsummaryrefslogtreecommitdiff
path: root/compiler/protogen/protogen.go
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/protogen/protogen.go')
-rw-r--r--compiler/protogen/protogen.go100
1 files changed, 98 insertions, 2 deletions
diff --git a/compiler/protogen/protogen.go b/compiler/protogen/protogen.go
index 2ee676fb..2d2171e5 100644
--- a/compiler/protogen/protogen.go
+++ b/compiler/protogen/protogen.go
@@ -36,10 +36,11 @@ import (
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
+ "google.golang.org/protobuf/types/dynamicpb"
"google.golang.org/protobuf/types/pluginpb"
)
-const goPackageDocURL = "https://developers.google.com/protocol-buffers/docs/reference/go-generated#package"
+const goPackageDocURL = "https://protobuf.dev/reference/go/go-generated#package"
// Run executes a function as a protoc plugin.
//
@@ -209,6 +210,7 @@ func (opts Options) New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) {
}
}
}
+
// When the module= option is provided, we strip the module name
// prefix from generated files. This only makes sense if generated
// filenames are based on the import path.
@@ -298,6 +300,8 @@ func (opts Options) New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) {
}
}
+ // The extracted types from the full import set
+ typeRegistry := newExtensionRegistry()
for _, fdesc := range gen.Request.ProtoFile {
filename := fdesc.GetName()
if gen.FilesByPath[filename] != nil {
@@ -309,6 +313,9 @@ func (opts Options) New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) {
}
gen.Files = append(gen.Files, f)
gen.FilesByPath[filename] = f
+ if err = typeRegistry.registerAllExtensionsFromFile(f.Desc); err != nil {
+ return nil, err
+ }
}
for _, filename := range gen.Request.FileToGenerate {
f, ok := gen.FilesByPath[filename]
@@ -317,6 +324,20 @@ func (opts Options) New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) {
}
f.Generate = true
}
+
+ // Create fully-linked descriptors if new extensions were found
+ if typeRegistry.hasNovelExtensions() {
+ for _, f := range gen.Files {
+ b, err := proto.Marshal(f.Proto.ProtoReflect().Interface())
+ if err != nil {
+ return nil, err
+ }
+ err = proto.UnmarshalOptions{Resolver: typeRegistry}.Unmarshal(b, f.Proto)
+ if err != nil {
+ return nil, err
+ }
+ }
+ }
return gen, nil
}
@@ -472,7 +493,7 @@ func newFile(gen *Plugin, p *descriptorpb.FileDescriptorProto, packageName GoPac
}
// splitImportPathAndPackageName splits off the optional Go package name
-// from the Go import path when seperated by a ';' delimiter.
+// from the Go import path when separated by a ';' delimiter.
func splitImportPathAndPackageName(s string) (GoImportPath, GoPackageName) {
if i := strings.Index(s, ";"); i >= 0 {
return GoImportPath(s[:i]), GoPackageName(s[i+1:])
@@ -1259,3 +1280,78 @@ func (c Comments) String() string {
}
return string(b)
}
+
+// extensionRegistry allows registration of new extensions defined in the .proto
+// file for which we are generating bindings.
+//
+// Lookups consult the local type registry first and fall back to the base type
+// registry which defaults to protoregistry.GlobalTypes
+type extensionRegistry struct {
+ base *protoregistry.Types
+ local *protoregistry.Types
+}
+
+func newExtensionRegistry() *extensionRegistry {
+ return &extensionRegistry{
+ base: protoregistry.GlobalTypes,
+ local: &protoregistry.Types{},
+ }
+}
+
+// FindExtensionByName implements proto.UnmarshalOptions.FindExtensionByName
+func (e *extensionRegistry) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
+ if xt, err := e.local.FindExtensionByName(field); err == nil {
+ return xt, nil
+ }
+
+ return e.base.FindExtensionByName(field)
+}
+
+// FindExtensionByNumber implements proto.UnmarshalOptions.FindExtensionByNumber
+func (e *extensionRegistry) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
+ if xt, err := e.local.FindExtensionByNumber(message, field); err == nil {
+ return xt, nil
+ }
+
+ return e.base.FindExtensionByNumber(message, field)
+}
+
+func (e *extensionRegistry) hasNovelExtensions() bool {
+ return e.local.NumExtensions() > 0
+}
+
+func (e *extensionRegistry) registerAllExtensionsFromFile(f protoreflect.FileDescriptor) error {
+ if err := e.registerAllExtensions(f.Extensions()); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (e *extensionRegistry) registerAllExtensionsFromMessage(ms protoreflect.MessageDescriptors) error {
+ for i := 0; i < ms.Len(); i++ {
+ m := ms.Get(i)
+ if err := e.registerAllExtensions(m.Extensions()); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (e *extensionRegistry) registerAllExtensions(exts protoreflect.ExtensionDescriptors) error {
+ for i := 0; i < exts.Len(); i++ {
+ if err := e.registerExtension(exts.Get(i)); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// registerExtension adds the given extension to the type registry if an
+// extension with that full name does not exist yet.
+func (e *extensionRegistry) registerExtension(xd protoreflect.ExtensionDescriptor) error {
+ if _, err := e.FindExtensionByName(xd.FullName()); err != protoregistry.NotFound {
+ // Either the extension already exists or there was an error, either way we're done.
+ return err
+ }
+ return e.local.RegisterExtension(dynamicpb.NewExtensionType(xd))
+}