aboutsummaryrefslogtreecommitdiff
path: root/pw_protobuf/py/pw_protobuf/codegen_pwpb.py
diff options
context:
space:
mode:
Diffstat (limited to 'pw_protobuf/py/pw_protobuf/codegen_pwpb.py')
-rw-r--r--pw_protobuf/py/pw_protobuf/codegen_pwpb.py742
1 files changed, 656 insertions, 86 deletions
diff --git a/pw_protobuf/py/pw_protobuf/codegen_pwpb.py b/pw_protobuf/py/pw_protobuf/codegen_pwpb.py
index dd862645b..261960787 100644
--- a/pw_protobuf/py/pw_protobuf/codegen_pwpb.py
+++ b/pw_protobuf/py/pw_protobuf/codegen_pwpb.py
@@ -1,4 +1,4 @@
-# Copyright 2020 The Pigweed Authors
+# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
@@ -21,7 +21,7 @@ from graphlib import CycleError, TopologicalSorter # type: ignore
from itertools import takewhile
import os
import sys
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Type
from typing import cast
from google.protobuf import descriptor_pb2
@@ -346,6 +346,47 @@ class PackedReadVectorMethod(ReadMethod):
return [('::pw::Vector<{}>&'.format(self._result_type()), 'out')]
+class FindMethod(ReadMethod):
+ def name(self) -> str:
+ return 'Find{}'.format(self._field.name())
+
+ def params(self) -> List[Tuple[str, str]]:
+ return [('::pw::ConstByteSpan', 'message')]
+
+ def body(self) -> List[str]:
+ lines: List[str] = []
+ lines += [
+ f'return {PROTOBUF_NAMESPACE}::{self._find_fn()}'
+ f'(message, {self.field_cast()});'
+ ]
+ return lines
+
+ def _find_fn(self) -> str:
+ """The find function to call.
+
+ Defined in subclasses.
+
+ e.g. 'FindUint32', 'FindBytes', etc.
+ """
+ raise NotImplementedError()
+
+
+class FindStreamMethod(FindMethod):
+ def name(self) -> str:
+ return 'Find{}'.format(self._field.name())
+
+ def params(self) -> List[Tuple[str, str]]:
+ return [('::pw::stream::Reader&', 'message_stream')]
+
+ def body(self) -> List[str]:
+ lines: List[str] = []
+ lines += [
+ f'return {PROTOBUF_NAMESPACE}::{self._find_fn()}'
+ f'(message_stream, {self.field_cast()});'
+ ]
+ return lines
+
+
class MessageProperty(ProtoMember):
"""Base class for a C++ property for a field in a protobuf message."""
@@ -375,7 +416,7 @@ class MessageProperty(ProtoMember):
return False
@staticmethod
- def repeated_field_container(type_name: str, max_size: int) -> str:
+ def repeated_field_container(type_name: str, max_size: str) -> str:
"""Returns the container type used for repeated fields.
Defaults to ::pw::Vector<type, max_size>. String fields use
@@ -423,42 +464,36 @@ class MessageProperty(ProtoMember):
def sub_table(self) -> str: # pylint: disable=no-self-use
return '{}'
- def struct_member(self, from_root: bool = False) -> Tuple[str, str]:
- """Returns the structure member."""
+ def struct_member_type(self, from_root: bool = False) -> str:
+ """Returns the structure member type."""
if self.use_callback():
return (
- f'{PROTOBUF_NAMESPACE}::Callback'
- '<StreamEncoder, StreamDecoder>',
- self.name(),
+ f'{PROTOBUF_NAMESPACE}::Callback<StreamEncoder, StreamDecoder>'
)
# Optional fields are wrapped in std::optional
if self.is_optional():
- return (
- 'std::optional<{}>'.format(self.type_name(from_root)),
- self.name(),
- )
+ return 'std::optional<{}>'.format(self.type_name(from_root))
# Non-repeated fields have a member of just the type name.
max_size = self.max_size()
if max_size == 0:
- return (self.type_name(from_root), self.name())
+ return self.type_name(from_root)
# Fixed size fields use std::array.
if self.is_fixed_size():
- return (
- 'std::array<{}, {}>'.format(
- self.type_name(from_root), max_size
- ),
- self.name(),
+ return 'std::array<{}, {}>'.format(
+ self.type_name(from_root), self.max_size_constant_name()
)
# Otherwise prefer pw::Vector for repeated fields.
- return (
- self.repeated_field_container(self.type_name(from_root), max_size),
- self.name(),
+ return self.repeated_field_container(
+ self.type_name(from_root), self.max_size_constant_name()
)
+ def max_size_constant_name(self) -> str:
+ return f'k{self._field.name()}MaxSize'
+
def _varint_type_table_entry(self) -> str:
if self.wire_type() == 'kVarint':
return '{}::VarintType::{}'.format(
@@ -574,6 +609,16 @@ class SubMessageDecoderMethod(ReadMethod):
return False
+class SubMessageFindMethod(FindMethod):
+ """Method which reads a proto submessage."""
+
+ def _result_type(self) -> str:
+ return '::pw::ConstByteSpan'
+
+ def _find_fn(self) -> str:
+ return 'FindBytes'
+
+
class SubMessageProperty(MessageProperty):
"""Property which contains a sub-message."""
@@ -716,6 +761,26 @@ class PackedDoubleReadVectorMethod(PackedReadVectorMethod):
return 'ReadRepeatedDouble'
+class DoubleFindMethod(FindMethod):
+ """Method which reads a proto double value."""
+
+ def _result_type(self) -> str:
+ return 'double'
+
+ def _find_fn(self) -> str:
+ return 'FindDouble'
+
+
+class DoubleFindStreamMethod(FindStreamMethod):
+ """Method which reads a proto double value."""
+
+ def _result_type(self) -> str:
+ return 'double'
+
+ def _find_fn(self) -> str:
+ return 'FindDouble'
+
+
class DoubleProperty(MessageProperty):
"""Property which holds a proto double value."""
@@ -789,6 +854,26 @@ class PackedFloatReadVectorMethod(PackedReadVectorMethod):
return 'ReadRepeatedFloat'
+class FloatFindMethod(FindMethod):
+ """Method which reads a proto float value."""
+
+ def _result_type(self) -> str:
+ return 'float'
+
+ def _find_fn(self) -> str:
+ return 'FindFloat'
+
+
+class FloatFindStreamMethod(FindStreamMethod):
+ """Method which reads a proto float value."""
+
+ def _result_type(self) -> str:
+ return 'float'
+
+ def _find_fn(self) -> str:
+ return 'FindFloat'
+
+
class FloatProperty(MessageProperty):
"""Property which holds a proto float value."""
@@ -862,6 +947,26 @@ class PackedInt32ReadVectorMethod(PackedReadVectorMethod):
return 'ReadRepeatedInt32'
+class Int32FindMethod(FindMethod):
+ """Method which reads a proto int32 value."""
+
+ def _result_type(self) -> str:
+ return 'int32_t'
+
+ def _find_fn(self) -> str:
+ return 'FindInt32'
+
+
+class Int32FindStreamMethod(FindStreamMethod):
+ """Method which reads a proto int32 value."""
+
+ def _result_type(self) -> str:
+ return 'int32_t'
+
+ def _find_fn(self) -> str:
+ return 'FindInt32'
+
+
class Int32Property(MessageProperty):
"""Property which holds a proto int32 value."""
@@ -938,6 +1043,26 @@ class PackedSint32ReadVectorMethod(PackedReadVectorMethod):
return 'ReadRepeatedSint32'
+class Sint32FindMethod(FindMethod):
+ """Method which reads a proto sint32 value."""
+
+ def _result_type(self) -> str:
+ return 'int32_t'
+
+ def _find_fn(self) -> str:
+ return 'FindSint32'
+
+
+class Sint32FindStreamMethod(FindStreamMethod):
+ """Method which reads a proto sint32 value."""
+
+ def _result_type(self) -> str:
+ return 'int32_t'
+
+ def _find_fn(self) -> str:
+ return 'FindSint32'
+
+
class Sint32Property(MessageProperty):
"""Property which holds a proto sint32 value."""
@@ -1014,6 +1139,26 @@ class PackedSfixed32ReadVectorMethod(PackedReadVectorMethod):
return 'ReadRepeatedSfixed32'
+class Sfixed32FindMethod(FindMethod):
+ """Method which reads a proto sfixed32 value."""
+
+ def _result_type(self) -> str:
+ return 'int32_t'
+
+ def _find_fn(self) -> str:
+ return 'FindSfixed32'
+
+
+class Sfixed32FindStreamMethod(FindStreamMethod):
+ """Method which reads a proto sfixed32 value."""
+
+ def _result_type(self) -> str:
+ return 'int32_t'
+
+ def _find_fn(self) -> str:
+ return 'FindSfixed32'
+
+
class Sfixed32Property(MessageProperty):
"""Property which holds a proto sfixed32 value."""
@@ -1087,6 +1232,26 @@ class PackedInt64ReadVectorMethod(PackedReadVectorMethod):
return 'ReadRepeatedInt64'
+class Int64FindMethod(FindMethod):
+ """Method which reads a proto int64 value."""
+
+ def _result_type(self) -> str:
+ return 'int64_t'
+
+ def _find_fn(self) -> str:
+ return 'FindInt64'
+
+
+class Int64FindStreamMethod(FindStreamMethod):
+ """Method which reads a proto int64 value."""
+
+ def _result_type(self) -> str:
+ return 'int64_t'
+
+ def _find_fn(self) -> str:
+ return 'FindInt64'
+
+
class Int64Property(MessageProperty):
"""Property which holds a proto int64 value."""
@@ -1163,6 +1328,26 @@ class PackedSint64ReadVectorMethod(PackedReadVectorMethod):
return 'ReadRepeatedSint64'
+class Sint64FindMethod(FindMethod):
+ """Method which reads a proto sint64 value."""
+
+ def _result_type(self) -> str:
+ return 'int64_t'
+
+ def _find_fn(self) -> str:
+ return 'FindSint64'
+
+
+class Sint64FindStreamMethod(FindStreamMethod):
+ """Method which reads a proto sint64 value."""
+
+ def _result_type(self) -> str:
+ return 'int64_t'
+
+ def _find_fn(self) -> str:
+ return 'FindSint64'
+
+
class Sint64Property(MessageProperty):
"""Property which holds a proto sint64 value."""
@@ -1239,6 +1424,26 @@ class PackedSfixed64ReadVectorMethod(PackedReadVectorMethod):
return 'ReadRepeatedSfixed64'
+class Sfixed64FindMethod(FindMethod):
+ """Method which reads a proto sfixed64 value."""
+
+ def _result_type(self) -> str:
+ return 'int64_t'
+
+ def _find_fn(self) -> str:
+ return 'FindSfixed64'
+
+
+class Sfixed64FindStreamMethod(FindStreamMethod):
+ """Method which reads a proto sfixed64 value."""
+
+ def _result_type(self) -> str:
+ return 'int64_t'
+
+ def _find_fn(self) -> str:
+ return 'FindSfixed64'
+
+
class Sfixed64Property(MessageProperty):
"""Property which holds a proto sfixed64 value."""
@@ -1312,6 +1517,26 @@ class PackedUint32ReadVectorMethod(PackedReadVectorMethod):
return 'ReadRepeatedUint32'
+class Uint32FindMethod(FindMethod):
+ """Method which finds a proto uint32 value."""
+
+ def _result_type(self) -> str:
+ return 'uint32_t'
+
+ def _find_fn(self) -> str:
+ return 'FindUint32'
+
+
+class Uint32FindStreamMethod(FindStreamMethod):
+ """Method which finds a proto uint32 value."""
+
+ def _result_type(self) -> str:
+ return 'uint32_t'
+
+ def _find_fn(self) -> str:
+ return 'FindUint32'
+
+
class Uint32Property(MessageProperty):
"""Property which holds a proto uint32 value."""
@@ -1388,6 +1613,26 @@ class PackedFixed32ReadVectorMethod(PackedReadVectorMethod):
return 'ReadRepeatedFixed32'
+class Fixed32FindMethod(FindMethod):
+ """Method which finds a proto fixed32 value."""
+
+ def _result_type(self) -> str:
+ return 'uint32_t'
+
+ def _find_fn(self) -> str:
+ return 'FindFixed32'
+
+
+class Fixed32FindStreamMethod(FindStreamMethod):
+ """Method which finds a proto fixed32 value."""
+
+ def _result_type(self) -> str:
+ return 'uint32_t'
+
+ def _find_fn(self) -> str:
+ return 'FindFixed32'
+
+
class Fixed32Property(MessageProperty):
"""Property which holds a proto fixed32 value."""
@@ -1461,6 +1706,26 @@ class PackedUint64ReadVectorMethod(PackedReadVectorMethod):
return 'ReadRepeatedUint64'
+class Uint64FindMethod(FindMethod):
+ """Method which finds a proto uint64 value."""
+
+ def _result_type(self) -> str:
+ return 'uint64_t'
+
+ def _find_fn(self) -> str:
+ return 'FindUint64'
+
+
+class Uint64FindStreamMethod(FindStreamMethod):
+ """Method which finds a proto uint64 value."""
+
+ def _result_type(self) -> str:
+ return 'uint64_t'
+
+ def _find_fn(self) -> str:
+ return 'FindUint64'
+
+
class Uint64Property(MessageProperty):
"""Property which holds a proto uint64 value."""
@@ -1537,6 +1802,26 @@ class PackedFixed64ReadVectorMethod(PackedReadVectorMethod):
return 'ReadRepeatedFixed64'
+class Fixed64FindMethod(FindMethod):
+ """Method which finds a proto fixed64 value."""
+
+ def _result_type(self) -> str:
+ return 'uint64_t'
+
+ def _find_fn(self) -> str:
+ return 'FindFixed64'
+
+
+class Fixed64FindStreamMethod(FindStreamMethod):
+ """Method which finds a proto fixed64 value."""
+
+ def _result_type(self) -> str:
+ return 'uint64_t'
+
+ def _find_fn(self) -> str:
+ return 'FindFixed64'
+
+
class Fixed64Property(MessageProperty):
"""Property which holds a proto fixed64 value."""
@@ -1600,6 +1885,26 @@ class PackedBoolReadMethod(PackedReadMethod):
return 'ReadPackedBool'
+class BoolFindMethod(FindMethod):
+ """Method which finds a proto bool value."""
+
+ def _result_type(self) -> str:
+ return 'bool'
+
+ def _find_fn(self) -> str:
+ return 'FindBool'
+
+
+class BoolFindStreamMethod(FindStreamMethod):
+ """Method which finds a proto bool value."""
+
+ def _result_type(self) -> str:
+ return 'bool'
+
+ def _find_fn(self) -> str:
+ return 'FindBool'
+
+
class BoolProperty(MessageProperty):
"""Property which holds a proto bool value."""
@@ -1639,6 +1944,40 @@ class BytesReadMethod(ReadMethod):
return 'ReadBytes'
+class BytesFindMethod(FindMethod):
+ """Method which reads a proto bytes value."""
+
+ def _result_type(self) -> str:
+ return '::pw::ConstByteSpan'
+
+ def _find_fn(self) -> str:
+ return 'FindBytes'
+
+
+class BytesFindStreamMethod(FindStreamMethod):
+ """Method which reads a proto bytes value."""
+
+ def return_type(self, from_root: bool = False) -> str:
+ return '::pw::StatusWithSize'
+
+ def params(self) -> List[Tuple[str, str]]:
+ return [
+ ('::pw::stream::Reader&', 'message_stream'),
+ ('::pw::ByteSpan', 'out'),
+ ]
+
+ def body(self) -> List[str]:
+ lines: List[str] = []
+ lines += [
+ f'return {PROTOBUF_NAMESPACE}::{self._find_fn()}'
+ f'(message_stream, {self.field_cast()}, out);'
+ ]
+ return lines
+
+ def _find_fn(self) -> str:
+ return 'FindBytes'
+
+
class BytesProperty(MessageProperty):
"""Property which holds a proto bytes value."""
@@ -1676,7 +2015,7 @@ class BytesProperty(MessageProperty):
def _size_length(self) -> Optional[str]:
if self.use_callback():
return None
- return f'{self.max_size()}'
+ return self.max_size_constant_name()
class StringLenWriteMethod(WriteMethod):
@@ -1712,6 +2051,64 @@ class StringReadMethod(ReadMethod):
return 'ReadString'
+class StringFindMethod(FindMethod):
+ """Method which reads a proto string value."""
+
+ def _result_type(self) -> str:
+ return 'std::string_view'
+
+ def _find_fn(self) -> str:
+ return 'FindString'
+
+
+class StringFindStreamMethod(FindStreamMethod):
+ """Method which reads a proto string value."""
+
+ def return_type(self, from_root: bool = False) -> str:
+ return '::pw::StatusWithSize'
+
+ def params(self) -> List[Tuple[str, str]]:
+ return [
+ ('::pw::stream::Reader&', 'message_stream'),
+ ('::pw::span<char>', 'out'),
+ ]
+
+ def body(self) -> List[str]:
+ lines: List[str] = []
+ lines += [
+ f'return {PROTOBUF_NAMESPACE}::{self._find_fn()}'
+ f'(message_stream, {self.field_cast()}, out);'
+ ]
+ return lines
+
+ def _find_fn(self) -> str:
+ return 'FindString'
+
+
+class StringFindStreamMethodInlineString(FindStreamMethod):
+ """Method which reads a proto string value to an InlineString."""
+
+ def return_type(self, from_root: bool = False) -> str:
+ return '::pw::StatusWithSize'
+
+ def params(self) -> List[Tuple[str, str]]:
+ return [
+ ('::pw::stream::Reader&', 'message_stream'),
+ ('::pw::InlineString<>&', 'out'),
+ ]
+
+ def body(self) -> List[str]:
+ lines: List[str] = []
+ lines += [
+ f'return {PROTOBUF_NAMESPACE}::{self._find_fn()}'
+ f'(message_stream, {self.field_cast()}, out);'
+ ]
+ return lines
+
+ def _find_fn(self) -> str:
+ return 'FindString'
+
+
class StringProperty(MessageProperty):
"""Property which holds a proto string value."""
@@ -1739,7 +2136,7 @@ class StringProperty(MessageProperty):
return True
@staticmethod
- def repeated_field_container(type_name: str, max_size: int) -> str:
+ def repeated_field_container(type_name: str, max_size: str) -> str:
return f'::pw::InlineBasicString<{type_name}, {max_size}>'
def _size_fn(self) -> str:
@@ -1751,7 +2148,7 @@ class StringProperty(MessageProperty):
def _size_length(self) -> Optional[str]:
if self.use_callback():
return None
- return f'{self.max_size()}'
+ return self.max_size_constant_name()
class EnumWriteMethod(WriteMethod):
@@ -1863,6 +2260,52 @@ class PackedEnumReadVectorMethod(PackedReadVectorMethod):
]
+class EnumFindMethod(FindMethod):
+ """Method which finds a proto enum value."""
+
+ def _result_type(self) -> str:
+ return self._relative_type_namespace()
+
+ def body(self) -> List[str]:
+ lines: List[str] = []
+ lines += [
+ '::pw::Result<uint32_t> result = '
+ f'{PROTOBUF_NAMESPACE}::{self._find_fn()}'
+ f'(message, {self.field_cast()});',
+ 'if (!result.ok()) {',
+ ' return result.status();',
+ '}',
+ f'return static_cast<{self._result_type()}>(result.value());',
+ ]
+ return lines
+
+ def _find_fn(self) -> str:
+ return 'FindUint32'
+
+
+class EnumFindStreamMethod(FindStreamMethod):
+ """Method which finds a proto enum value."""
+
+ def _result_type(self) -> str:
+ return self._relative_type_namespace()
+
+ def body(self) -> List[str]:
+ lines: List[str] = []
+ lines += [
+ '::pw::Result<uint32_t> result = '
+ f'{PROTOBUF_NAMESPACE}::{self._find_fn()}'
+ f'(message_stream, {self.field_cast()});',
+ 'if (!result.ok()) {',
+ ' return result.status();',
+ '}',
+ f'return static_cast<{self._result_type()}>(result.value());',
+ ]
+ return lines
+
+ def _find_fn(self) -> str:
+ return 'FindUint32'
+
+
class EnumProperty(MessageProperty):
"""Property which holds a proto enum value."""
@@ -2040,27 +2483,121 @@ PROTO_FIELD_READ_METHODS: Dict[int, List] = {
],
}
-PROTO_FIELD_PROPERTIES: Dict[int, List] = {
- descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE: [DoubleProperty],
- descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT: [FloatProperty],
- descriptor_pb2.FieldDescriptorProto.TYPE_INT32: [Int32Property],
- descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: [Sint32Property],
- descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32: [Sfixed32Property],
- descriptor_pb2.FieldDescriptorProto.TYPE_INT64: [Int64Property],
- descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: [Sint64Property],
- descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64: [Sfixed32Property],
- descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: [Uint32Property],
- descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32: [Fixed32Property],
- descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: [Uint64Property],
- descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64: [Fixed64Property],
- descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: [BoolProperty],
- descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: [BytesProperty],
- descriptor_pb2.FieldDescriptorProto.TYPE_STRING: [StringProperty],
- descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE: [SubMessageProperty],
- descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: [EnumProperty],
+PROTO_FIELD_FIND_METHODS: Dict[int, List] = {
+ descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE: [
+ DoubleFindMethod,
+ DoubleFindStreamMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT: [
+ FloatFindMethod,
+ FloatFindStreamMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_INT32: [
+ Int32FindMethod,
+ Int32FindStreamMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: [
+ Sint32FindMethod,
+ Sint32FindStreamMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32: [
+ Sfixed32FindMethod,
+ Sfixed32FindStreamMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_INT64: [
+ Int64FindMethod,
+ Int64FindStreamMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: [
+ Sint64FindMethod,
+ Sint64FindStreamMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64: [
+ Sfixed64FindMethod,
+ Sfixed64FindStreamMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: [
+ Uint32FindMethod,
+ Uint32FindStreamMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32: [
+ Fixed32FindMethod,
+ Fixed32FindStreamMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: [
+ Uint64FindMethod,
+ Uint64FindStreamMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64: [
+ Fixed64FindMethod,
+ Fixed64FindStreamMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: [
+ BoolFindMethod,
+ BoolFindStreamMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: [
+ BytesFindMethod,
+ BytesFindStreamMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_STRING: [
+ StringFindMethod,
+ StringFindStreamMethod,
+ StringFindStreamMethodInlineString,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE: [
+ SubMessageFindMethod,
+ ],
+ descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: [
+ EnumFindMethod,
+ EnumFindStreamMethod,
+ ],
+}
+
+PROTO_FIELD_PROPERTIES: Dict[int, Type[MessageProperty]] = {
+ descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE: DoubleProperty,
+ descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT: FloatProperty,
+ descriptor_pb2.FieldDescriptorProto.TYPE_INT32: Int32Property,
+ descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: Sint32Property,
+ descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32: Sfixed32Property,
+ descriptor_pb2.FieldDescriptorProto.TYPE_INT64: Int64Property,
+ descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: Sint64Property,
+ descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64: Sfixed32Property,
+ descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: Uint32Property,
+ descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32: Fixed32Property,
+ descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: Uint64Property,
+ descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64: Fixed64Property,
+ descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: BoolProperty,
+ descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: BytesProperty,
+ descriptor_pb2.FieldDescriptorProto.TYPE_STRING: StringProperty,
+ descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE: SubMessageProperty,
+ descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: EnumProperty,
}
+def proto_message_field_props(
+ message: ProtoMessage,
+ root: ProtoNode,
+) -> Iterable[MessageProperty]:
+ """Yields a MessageProperty for each field in a ProtoMessage.
+
+ Only properties which should_appear() is True are returned.
+
+ Args:
+ message: The ProtoMessage whose fields are iterated.
+ root: The root ProtoNode of the tree.
+
+ Yields:
+ An appropriately-typed MessageProperty object for each field
+ in the message, to which the property refers.
+ """
+ for field in message.fields():
+ property_class = PROTO_FIELD_PROPERTIES[field.type()]
+ prop = property_class(field, message, root)
+ if prop.should_appear():
+ yield prop
+
+
def proto_field_methods(class_type: ClassType, field_type: int) -> List:
return (
PROTO_FIELD_WRITE_METHODS[field_type]
@@ -2316,31 +2853,41 @@ def generate_to_string_for_enum(
def forward_declare(
- node: ProtoMessage,
+ message: ProtoMessage,
root: ProtoNode,
output: OutputFile,
exclude_legacy_snake_case_field_name_enums: bool,
) -> None:
"""Generates code forward-declaring entities in a message's namespace."""
- namespace = node.cpp_namespace(root=root)
+ namespace = message.cpp_namespace(root=root)
output.write_line()
output.write_line(f'namespace {namespace} {{')
# Define an enum defining each of the message's fields and their numbers.
output.write_line('enum class Fields : uint32_t {')
with output.indent():
- for field in node.fields():
+ for field in message.fields():
output.write_line(f'{field.enum_name()} = {field.number()},')
# Migration support from SNAKE_CASE to kConstantCase.
if not exclude_legacy_snake_case_field_name_enums:
- for field in node.fields():
+ for field in message.fields():
output.write_line(
f'{field.legacy_enum_name()} = {field.number()},'
)
output.write_line('};')
+ # Define constants for fixed-size fields.
+ output.write_line()
+ for prop in proto_message_field_props(message, root):
+ max_size = prop.max_size()
+ if max_size:
+ output.write_line(
+ f'static constexpr size_t {prop.max_size_constant_name()} '
+ f'= {max_size};'
+ )
+
# Declare the message's message struct.
output.write_line()
output.write_line('struct Message;')
@@ -2355,14 +2902,14 @@ def forward_declare(
output.write_line('class StreamDecoder;')
# Declare the message's enums.
- for child in node.children():
+ for child in message.children():
if child.type() == ProtoNode.Type.ENUM:
output.write_line()
- generate_code_for_enum(cast(ProtoEnum, child), node, output)
+ generate_code_for_enum(cast(ProtoEnum, child), message, output)
output.write_line()
- generate_function_for_enum(cast(ProtoEnum, child), node, output)
+ generate_function_for_enum(cast(ProtoEnum, child), message, output)
output.write_line()
- generate_to_string_for_enum(cast(ProtoEnum, child), node, output)
+ generate_to_string_for_enum(cast(ProtoEnum, child), message, output)
output.write_line(f'}} // namespace {namespace}')
@@ -2378,17 +2925,13 @@ def generate_struct_for_message(
# Generate members for each of the message's fields.
with output.indent():
cmp: List[str] = []
- for field in message.fields():
- for property_class in PROTO_FIELD_PROPERTIES[field.type()]:
- prop = property_class(field, message, root)
- if not prop.should_appear():
- continue
-
- (type_name, name) = prop.struct_member()
- output.write_line(f'{type_name} {name};')
+ for prop in proto_message_field_props(message, root):
+ type_name = prop.struct_member_type()
+ name = prop.name()
+ output.write_line(f'{type_name} {name};')
- if not prop.use_callback():
- cmp.append(f'{name} == other.{name}')
+ if not prop.use_callback():
+ cmp.append(f'{name} == other.{name}')
# Equality operator
output.write_line()
@@ -2417,12 +2960,7 @@ def generate_table_for_message(
namespace = message.cpp_namespace(root=root)
output.write_line(f'namespace {namespace} {{')
- properties = []
- for field in message.fields():
- for property_class in PROTO_FIELD_PROPERTIES[field.type()]:
- prop = property_class(field, message, root)
- if prop.should_appear():
- properties.append(prop)
+ properties = list(proto_message_field_props(message, root))
output.write_line('PW_MODIFY_DIAGNOSTICS_PUSH();')
output.write_line('PW_MODIFY_DIAGNOSTIC(ignored, "-Winvalid-offsetof");')
@@ -2469,7 +3007,7 @@ def generate_table_for_message(
)
member_list = ', '.join(
- [f'message.{prop.struct_member()[1]}' for prop in properties]
+ [f'message.{prop.name()}' for prop in properties]
)
# Generate std::tuple for Message fields.
@@ -2505,15 +3043,10 @@ def generate_sizes_for_message(
property_sizes: List[str] = []
scratch_sizes: List[str] = []
- for field in message.fields():
- for property_class in PROTO_FIELD_PROPERTIES[field.type()]:
- prop = property_class(field, message, root)
- if not prop.should_appear():
- continue
-
- property_sizes.append(prop.max_encoded_size())
- if prop.include_in_scratch_size():
- scratch_sizes.append(prop.max_encoded_size())
+ for prop in proto_message_field_props(message, root):
+ property_sizes.append(prop.max_encoded_size())
+ if prop.include_in_scratch_size():
+ scratch_sizes.append(prop.max_encoded_size())
output.write_line('inline constexpr size_t kMaxEncodedSizeBytes =')
with output.indent():
@@ -2540,19 +3073,53 @@ def generate_sizes_for_message(
output.write_line(f'}} // namespace {namespace}')
-def generate_is_trivially_comparable_specialization(
+def generate_find_functions_for_message(
message: ProtoMessage, root: ProtoNode, output: OutputFile
) -> None:
- is_trivially_comparable = True
+ """Creates C++ constants for the encoded sizes of a protobuf message."""
+ assert message.type() == ProtoNode.Type.MESSAGE
+
+ namespace = message.cpp_namespace(root=root)
+ output.write_line(f'namespace {namespace} {{')
+
for field in message.fields():
- for property_class in PROTO_FIELD_PROPERTIES[field.type()]:
- prop = property_class(field, message, root)
- if not prop.should_appear():
- continue
+ if field.is_repeated():
+ # Find methods don't account for repeated field semantics, so
+ # ignore them to avoid confusion.
+ continue
+
+ try:
+ methods = PROTO_FIELD_FIND_METHODS[field.type()]
+ except KeyError:
+ continue
+
+ for cls in methods:
+ method = cls(field, message, root, '')
+ method_signature = (
+ f'inline {method.return_type()} '
+ f'{method.name()}({method.param_string()})'
+ )
+
+ output.write_line()
+ output.write_line(f'{method_signature} {{')
+
+ with output.indent():
+ for line in method.body():
+ output.write_line(line)
+
+ output.write_line('}')
+
+ output.write_line(f'}} // namespace {namespace}')
+
- if prop.use_callback():
- is_trivially_comparable = False
- break
+def generate_is_trivially_comparable_specialization(
+ message: ProtoMessage, root: ProtoNode, output: OutputFile
+) -> None:
+ is_trivially_comparable = True
+ for prop in proto_message_field_props(message, root):
+ if prop.use_callback():
+ is_trivially_comparable = False
+ break
qualified_message = f'{message.cpp_namespace()}::Message'
@@ -2626,6 +3193,7 @@ def generate_code_for_package(
output.write_line('#include "pw_containers/vector.h"')
output.write_line('#include "pw_preprocessor/compiler.h"')
output.write_line('#include "pw_protobuf/encoder.h"')
+ output.write_line('#include "pw_protobuf/find.h"')
output.write_line('#include "pw_protobuf/internal/codegen.h"')
output.write_line('#include "pw_protobuf/serialized_size.h"')
output.write_line('#include "pw_protobuf/stream_decoder.h"')
@@ -2675,6 +3243,8 @@ def generate_code_for_package(
output.write_line()
generate_sizes_for_message(message, package, output)
output.write_line()
+ generate_find_functions_for_message(message, package, output)
+ output.write_line()
generate_class_for_message(
message, package, output, ClassType.STREAMING_ENCODER
)
@@ -2720,7 +3290,7 @@ def generate_code_for_package(
output.write_line(f'using namespace ::{package.cpp_namespace()};')
output.write_line(f'}} // namespace {legacy_namespace}')
- # TODO(b/250945489) Remove this if possible
+ # TODO: b/250945489 - Remove this if possible
output.write_line()
output.write_line(
'// Codegen implementation detail; do not use this namespace!'