diff options
Diffstat (limited to 'pw_protobuf/py/pw_protobuf/codegen_pwpb.py')
-rw-r--r-- | pw_protobuf/py/pw_protobuf/codegen_pwpb.py | 742 |
1 files changed, 656 insertions, 86 deletions
diff --git a/pw_protobuf/py/pw_protobuf/codegen_pwpb.py b/pw_protobuf/py/pw_protobuf/codegen_pwpb.py index dd862645b..261960787 100644 --- a/pw_protobuf/py/pw_protobuf/codegen_pwpb.py +++ b/pw_protobuf/py/pw_protobuf/codegen_pwpb.py @@ -1,4 +1,4 @@ -# Copyright 2020 The Pigweed Authors +# Copyright 2023 The Pigweed Authors # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of @@ -21,7 +21,7 @@ from graphlib import CycleError, TopologicalSorter # type: ignore from itertools import takewhile import os import sys -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, Type from typing import cast from google.protobuf import descriptor_pb2 @@ -346,6 +346,47 @@ class PackedReadVectorMethod(ReadMethod): return [('::pw::Vector<{}>&'.format(self._result_type()), 'out')] +class FindMethod(ReadMethod): + def name(self) -> str: + return 'Find{}'.format(self._field.name()) + + def params(self) -> List[Tuple[str, str]]: + return [('::pw::ConstByteSpan', 'message')] + + def body(self) -> List[str]: + lines: List[str] = [] + lines += [ + f'return {PROTOBUF_NAMESPACE}::{self._find_fn()}' + f'(message, {self.field_cast()});' + ] + return lines + + def _find_fn(self) -> str: + """The find function to call. + + Defined in subclasses. + + e.g. 'FindUint32', 'FindBytes', etc. + """ + raise NotImplementedError() + + +class FindStreamMethod(FindMethod): + def name(self) -> str: + return 'Find{}'.format(self._field.name()) + + def params(self) -> List[Tuple[str, str]]: + return [('::pw::stream::Reader&', 'message_stream')] + + def body(self) -> List[str]: + lines: List[str] = [] + lines += [ + f'return {PROTOBUF_NAMESPACE}::{self._find_fn()}' + f'(message_stream, {self.field_cast()});' + ] + return lines + + class MessageProperty(ProtoMember): """Base class for a C++ property for a field in a protobuf message.""" @@ -375,7 +416,7 @@ class MessageProperty(ProtoMember): return False @staticmethod - def repeated_field_container(type_name: str, max_size: int) -> str: + def repeated_field_container(type_name: str, max_size: str) -> str: """Returns the container type used for repeated fields. Defaults to ::pw::Vector<type, max_size>. String fields use @@ -423,42 +464,36 @@ class MessageProperty(ProtoMember): def sub_table(self) -> str: # pylint: disable=no-self-use return '{}' - def struct_member(self, from_root: bool = False) -> Tuple[str, str]: - """Returns the structure member.""" + def struct_member_type(self, from_root: bool = False) -> str: + """Returns the structure member type.""" if self.use_callback(): return ( - f'{PROTOBUF_NAMESPACE}::Callback' - '<StreamEncoder, StreamDecoder>', - self.name(), + f'{PROTOBUF_NAMESPACE}::Callback<StreamEncoder, StreamDecoder>' ) # Optional fields are wrapped in std::optional if self.is_optional(): - return ( - 'std::optional<{}>'.format(self.type_name(from_root)), - self.name(), - ) + return 'std::optional<{}>'.format(self.type_name(from_root)) # Non-repeated fields have a member of just the type name. max_size = self.max_size() if max_size == 0: - return (self.type_name(from_root), self.name()) + return self.type_name(from_root) # Fixed size fields use std::array. if self.is_fixed_size(): - return ( - 'std::array<{}, {}>'.format( - self.type_name(from_root), max_size - ), - self.name(), + return 'std::array<{}, {}>'.format( + self.type_name(from_root), self.max_size_constant_name() ) # Otherwise prefer pw::Vector for repeated fields. - return ( - self.repeated_field_container(self.type_name(from_root), max_size), - self.name(), + return self.repeated_field_container( + self.type_name(from_root), self.max_size_constant_name() ) + def max_size_constant_name(self) -> str: + return f'k{self._field.name()}MaxSize' + def _varint_type_table_entry(self) -> str: if self.wire_type() == 'kVarint': return '{}::VarintType::{}'.format( @@ -574,6 +609,16 @@ class SubMessageDecoderMethod(ReadMethod): return False +class SubMessageFindMethod(FindMethod): + """Method which reads a proto submessage.""" + + def _result_type(self) -> str: + return '::pw::ConstByteSpan' + + def _find_fn(self) -> str: + return 'FindBytes' + + class SubMessageProperty(MessageProperty): """Property which contains a sub-message.""" @@ -716,6 +761,26 @@ class PackedDoubleReadVectorMethod(PackedReadVectorMethod): return 'ReadRepeatedDouble' +class DoubleFindMethod(FindMethod): + """Method which reads a proto double value.""" + + def _result_type(self) -> str: + return 'double' + + def _find_fn(self) -> str: + return 'FindDouble' + + +class DoubleFindStreamMethod(FindStreamMethod): + """Method which reads a proto double value.""" + + def _result_type(self) -> str: + return 'double' + + def _find_fn(self) -> str: + return 'FindDouble' + + class DoubleProperty(MessageProperty): """Property which holds a proto double value.""" @@ -789,6 +854,26 @@ class PackedFloatReadVectorMethod(PackedReadVectorMethod): return 'ReadRepeatedFloat' +class FloatFindMethod(FindMethod): + """Method which reads a proto float value.""" + + def _result_type(self) -> str: + return 'float' + + def _find_fn(self) -> str: + return 'FindFloat' + + +class FloatFindStreamMethod(FindStreamMethod): + """Method which reads a proto float value.""" + + def _result_type(self) -> str: + return 'float' + + def _find_fn(self) -> str: + return 'FindFloat' + + class FloatProperty(MessageProperty): """Property which holds a proto float value.""" @@ -862,6 +947,26 @@ class PackedInt32ReadVectorMethod(PackedReadVectorMethod): return 'ReadRepeatedInt32' +class Int32FindMethod(FindMethod): + """Method which reads a proto int32 value.""" + + def _result_type(self) -> str: + return 'int32_t' + + def _find_fn(self) -> str: + return 'FindInt32' + + +class Int32FindStreamMethod(FindStreamMethod): + """Method which reads a proto int32 value.""" + + def _result_type(self) -> str: + return 'int32_t' + + def _find_fn(self) -> str: + return 'FindInt32' + + class Int32Property(MessageProperty): """Property which holds a proto int32 value.""" @@ -938,6 +1043,26 @@ class PackedSint32ReadVectorMethod(PackedReadVectorMethod): return 'ReadRepeatedSint32' +class Sint32FindMethod(FindMethod): + """Method which reads a proto sint32 value.""" + + def _result_type(self) -> str: + return 'int32_t' + + def _find_fn(self) -> str: + return 'FindSint32' + + +class Sint32FindStreamMethod(FindStreamMethod): + """Method which reads a proto sint32 value.""" + + def _result_type(self) -> str: + return 'int32_t' + + def _find_fn(self) -> str: + return 'FindSint32' + + class Sint32Property(MessageProperty): """Property which holds a proto sint32 value.""" @@ -1014,6 +1139,26 @@ class PackedSfixed32ReadVectorMethod(PackedReadVectorMethod): return 'ReadRepeatedSfixed32' +class Sfixed32FindMethod(FindMethod): + """Method which reads a proto sfixed32 value.""" + + def _result_type(self) -> str: + return 'int32_t' + + def _find_fn(self) -> str: + return 'FindSfixed32' + + +class Sfixed32FindStreamMethod(FindStreamMethod): + """Method which reads a proto sfixed32 value.""" + + def _result_type(self) -> str: + return 'int32_t' + + def _find_fn(self) -> str: + return 'FindSfixed32' + + class Sfixed32Property(MessageProperty): """Property which holds a proto sfixed32 value.""" @@ -1087,6 +1232,26 @@ class PackedInt64ReadVectorMethod(PackedReadVectorMethod): return 'ReadRepeatedInt64' +class Int64FindMethod(FindMethod): + """Method which reads a proto int64 value.""" + + def _result_type(self) -> str: + return 'int64_t' + + def _find_fn(self) -> str: + return 'FindInt64' + + +class Int64FindStreamMethod(FindStreamMethod): + """Method which reads a proto int64 value.""" + + def _result_type(self) -> str: + return 'int64_t' + + def _find_fn(self) -> str: + return 'FindInt64' + + class Int64Property(MessageProperty): """Property which holds a proto int64 value.""" @@ -1163,6 +1328,26 @@ class PackedSint64ReadVectorMethod(PackedReadVectorMethod): return 'ReadRepeatedSint64' +class Sint64FindMethod(FindMethod): + """Method which reads a proto sint64 value.""" + + def _result_type(self) -> str: + return 'int64_t' + + def _find_fn(self) -> str: + return 'FindSint64' + + +class Sint64FindStreamMethod(FindStreamMethod): + """Method which reads a proto sint64 value.""" + + def _result_type(self) -> str: + return 'int64_t' + + def _find_fn(self) -> str: + return 'FindSint64' + + class Sint64Property(MessageProperty): """Property which holds a proto sint64 value.""" @@ -1239,6 +1424,26 @@ class PackedSfixed64ReadVectorMethod(PackedReadVectorMethod): return 'ReadRepeatedSfixed64' +class Sfixed64FindMethod(FindMethod): + """Method which reads a proto sfixed64 value.""" + + def _result_type(self) -> str: + return 'int64_t' + + def _find_fn(self) -> str: + return 'FindSfixed64' + + +class Sfixed64FindStreamMethod(FindStreamMethod): + """Method which reads a proto sfixed64 value.""" + + def _result_type(self) -> str: + return 'int64_t' + + def _find_fn(self) -> str: + return 'FindSfixed64' + + class Sfixed64Property(MessageProperty): """Property which holds a proto sfixed64 value.""" @@ -1312,6 +1517,26 @@ class PackedUint32ReadVectorMethod(PackedReadVectorMethod): return 'ReadRepeatedUint32' +class Uint32FindMethod(FindMethod): + """Method which finds a proto uint32 value.""" + + def _result_type(self) -> str: + return 'uint32_t' + + def _find_fn(self) -> str: + return 'FindUint32' + + +class Uint32FindStreamMethod(FindStreamMethod): + """Method which finds a proto uint32 value.""" + + def _result_type(self) -> str: + return 'uint32_t' + + def _find_fn(self) -> str: + return 'FindUint32' + + class Uint32Property(MessageProperty): """Property which holds a proto uint32 value.""" @@ -1388,6 +1613,26 @@ class PackedFixed32ReadVectorMethod(PackedReadVectorMethod): return 'ReadRepeatedFixed32' +class Fixed32FindMethod(FindMethod): + """Method which finds a proto fixed32 value.""" + + def _result_type(self) -> str: + return 'uint32_t' + + def _find_fn(self) -> str: + return 'FindFixed32' + + +class Fixed32FindStreamMethod(FindStreamMethod): + """Method which finds a proto fixed32 value.""" + + def _result_type(self) -> str: + return 'uint32_t' + + def _find_fn(self) -> str: + return 'FindFixed32' + + class Fixed32Property(MessageProperty): """Property which holds a proto fixed32 value.""" @@ -1461,6 +1706,26 @@ class PackedUint64ReadVectorMethod(PackedReadVectorMethod): return 'ReadRepeatedUint64' +class Uint64FindMethod(FindMethod): + """Method which finds a proto uint64 value.""" + + def _result_type(self) -> str: + return 'uint64_t' + + def _find_fn(self) -> str: + return 'FindUint64' + + +class Uint64FindStreamMethod(FindStreamMethod): + """Method which finds a proto uint64 value.""" + + def _result_type(self) -> str: + return 'uint64_t' + + def _find_fn(self) -> str: + return 'FindUint64' + + class Uint64Property(MessageProperty): """Property which holds a proto uint64 value.""" @@ -1537,6 +1802,26 @@ class PackedFixed64ReadVectorMethod(PackedReadVectorMethod): return 'ReadRepeatedFixed64' +class Fixed64FindMethod(FindMethod): + """Method which finds a proto fixed64 value.""" + + def _result_type(self) -> str: + return 'uint64_t' + + def _find_fn(self) -> str: + return 'FindFixed64' + + +class Fixed64FindStreamMethod(FindStreamMethod): + """Method which finds a proto fixed64 value.""" + + def _result_type(self) -> str: + return 'uint64_t' + + def _find_fn(self) -> str: + return 'FindFixed64' + + class Fixed64Property(MessageProperty): """Property which holds a proto fixed64 value.""" @@ -1600,6 +1885,26 @@ class PackedBoolReadMethod(PackedReadMethod): return 'ReadPackedBool' +class BoolFindMethod(FindMethod): + """Method which finds a proto bool value.""" + + def _result_type(self) -> str: + return 'bool' + + def _find_fn(self) -> str: + return 'FindBool' + + +class BoolFindStreamMethod(FindStreamMethod): + """Method which finds a proto bool value.""" + + def _result_type(self) -> str: + return 'bool' + + def _find_fn(self) -> str: + return 'FindBool' + + class BoolProperty(MessageProperty): """Property which holds a proto bool value.""" @@ -1639,6 +1944,40 @@ class BytesReadMethod(ReadMethod): return 'ReadBytes' +class BytesFindMethod(FindMethod): + """Method which reads a proto bytes value.""" + + def _result_type(self) -> str: + return '::pw::ConstByteSpan' + + def _find_fn(self) -> str: + return 'FindBytes' + + +class BytesFindStreamMethod(FindStreamMethod): + """Method which reads a proto bytes value.""" + + def return_type(self, from_root: bool = False) -> str: + return '::pw::StatusWithSize' + + def params(self) -> List[Tuple[str, str]]: + return [ + ('::pw::stream::Reader&', 'message_stream'), + ('::pw::ByteSpan', 'out'), + ] + + def body(self) -> List[str]: + lines: List[str] = [] + lines += [ + f'return {PROTOBUF_NAMESPACE}::{self._find_fn()}' + f'(message_stream, {self.field_cast()}, out);' + ] + return lines + + def _find_fn(self) -> str: + return 'FindBytes' + + class BytesProperty(MessageProperty): """Property which holds a proto bytes value.""" @@ -1676,7 +2015,7 @@ class BytesProperty(MessageProperty): def _size_length(self) -> Optional[str]: if self.use_callback(): return None - return f'{self.max_size()}' + return self.max_size_constant_name() class StringLenWriteMethod(WriteMethod): @@ -1712,6 +2051,64 @@ class StringReadMethod(ReadMethod): return 'ReadString' +class StringFindMethod(FindMethod): + """Method which reads a proto string value.""" + + def _result_type(self) -> str: + return 'std::string_view' + + def _find_fn(self) -> str: + return 'FindString' + + +class StringFindStreamMethod(FindStreamMethod): + """Method which reads a proto string value.""" + + def return_type(self, from_root: bool = False) -> str: + return '::pw::StatusWithSize' + + def params(self) -> List[Tuple[str, str]]: + return [ + ('::pw::stream::Reader&', 'message_stream'), + ('::pw::span<char>', 'out'), + ] + + def body(self) -> List[str]: + lines: List[str] = [] + lines += [ + f'return {PROTOBUF_NAMESPACE}::{self._find_fn()}' + f'(message_stream, {self.field_cast()}, out);' + ] + return lines + + def _find_fn(self) -> str: + return 'FindString' + + +class StringFindStreamMethodInlineString(FindStreamMethod): + """Method which reads a proto string value to an InlineString.""" + + def return_type(self, from_root: bool = False) -> str: + return '::pw::StatusWithSize' + + def params(self) -> List[Tuple[str, str]]: + return [ + ('::pw::stream::Reader&', 'message_stream'), + ('::pw::InlineString<>&', 'out'), + ] + + def body(self) -> List[str]: + lines: List[str] = [] + lines += [ + f'return {PROTOBUF_NAMESPACE}::{self._find_fn()}' + f'(message_stream, {self.field_cast()}, out);' + ] + return lines + + def _find_fn(self) -> str: + return 'FindString' + + class StringProperty(MessageProperty): """Property which holds a proto string value.""" @@ -1739,7 +2136,7 @@ class StringProperty(MessageProperty): return True @staticmethod - def repeated_field_container(type_name: str, max_size: int) -> str: + def repeated_field_container(type_name: str, max_size: str) -> str: return f'::pw::InlineBasicString<{type_name}, {max_size}>' def _size_fn(self) -> str: @@ -1751,7 +2148,7 @@ class StringProperty(MessageProperty): def _size_length(self) -> Optional[str]: if self.use_callback(): return None - return f'{self.max_size()}' + return self.max_size_constant_name() class EnumWriteMethod(WriteMethod): @@ -1863,6 +2260,52 @@ class PackedEnumReadVectorMethod(PackedReadVectorMethod): ] +class EnumFindMethod(FindMethod): + """Method which finds a proto enum value.""" + + def _result_type(self) -> str: + return self._relative_type_namespace() + + def body(self) -> List[str]: + lines: List[str] = [] + lines += [ + '::pw::Result<uint32_t> result = ' + f'{PROTOBUF_NAMESPACE}::{self._find_fn()}' + f'(message, {self.field_cast()});', + 'if (!result.ok()) {', + ' return result.status();', + '}', + f'return static_cast<{self._result_type()}>(result.value());', + ] + return lines + + def _find_fn(self) -> str: + return 'FindUint32' + + +class EnumFindStreamMethod(FindStreamMethod): + """Method which finds a proto enum value.""" + + def _result_type(self) -> str: + return self._relative_type_namespace() + + def body(self) -> List[str]: + lines: List[str] = [] + lines += [ + '::pw::Result<uint32_t> result = ' + f'{PROTOBUF_NAMESPACE}::{self._find_fn()}' + f'(message_stream, {self.field_cast()});', + 'if (!result.ok()) {', + ' return result.status();', + '}', + f'return static_cast<{self._result_type()}>(result.value());', + ] + return lines + + def _find_fn(self) -> str: + return 'FindUint32' + + class EnumProperty(MessageProperty): """Property which holds a proto enum value.""" @@ -2040,27 +2483,121 @@ PROTO_FIELD_READ_METHODS: Dict[int, List] = { ], } -PROTO_FIELD_PROPERTIES: Dict[int, List] = { - descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE: [DoubleProperty], - descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT: [FloatProperty], - descriptor_pb2.FieldDescriptorProto.TYPE_INT32: [Int32Property], - descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: [Sint32Property], - descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32: [Sfixed32Property], - descriptor_pb2.FieldDescriptorProto.TYPE_INT64: [Int64Property], - descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: [Sint64Property], - descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64: [Sfixed32Property], - descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: [Uint32Property], - descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32: [Fixed32Property], - descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: [Uint64Property], - descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64: [Fixed64Property], - descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: [BoolProperty], - descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: [BytesProperty], - descriptor_pb2.FieldDescriptorProto.TYPE_STRING: [StringProperty], - descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE: [SubMessageProperty], - descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: [EnumProperty], +PROTO_FIELD_FIND_METHODS: Dict[int, List] = { + descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE: [ + DoubleFindMethod, + DoubleFindStreamMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT: [ + FloatFindMethod, + FloatFindStreamMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_INT32: [ + Int32FindMethod, + Int32FindStreamMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: [ + Sint32FindMethod, + Sint32FindStreamMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32: [ + Sfixed32FindMethod, + Sfixed32FindStreamMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_INT64: [ + Int64FindMethod, + Int64FindStreamMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: [ + Sint64FindMethod, + Sint64FindStreamMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64: [ + Sfixed64FindMethod, + Sfixed64FindStreamMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: [ + Uint32FindMethod, + Uint32FindStreamMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32: [ + Fixed32FindMethod, + Fixed32FindStreamMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: [ + Uint64FindMethod, + Uint64FindStreamMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64: [ + Fixed64FindMethod, + Fixed64FindStreamMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: [ + BoolFindMethod, + BoolFindStreamMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: [ + BytesFindMethod, + BytesFindStreamMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_STRING: [ + StringFindMethod, + StringFindStreamMethod, + StringFindStreamMethodInlineString, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE: [ + SubMessageFindMethod, + ], + descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: [ + EnumFindMethod, + EnumFindStreamMethod, + ], +} + +PROTO_FIELD_PROPERTIES: Dict[int, Type[MessageProperty]] = { + descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE: DoubleProperty, + descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT: FloatProperty, + descriptor_pb2.FieldDescriptorProto.TYPE_INT32: Int32Property, + descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: Sint32Property, + descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32: Sfixed32Property, + descriptor_pb2.FieldDescriptorProto.TYPE_INT64: Int64Property, + descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: Sint64Property, + descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64: Sfixed32Property, + descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: Uint32Property, + descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32: Fixed32Property, + descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: Uint64Property, + descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64: Fixed64Property, + descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: BoolProperty, + descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: BytesProperty, + descriptor_pb2.FieldDescriptorProto.TYPE_STRING: StringProperty, + descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE: SubMessageProperty, + descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: EnumProperty, } +def proto_message_field_props( + message: ProtoMessage, + root: ProtoNode, +) -> Iterable[MessageProperty]: + """Yields a MessageProperty for each field in a ProtoMessage. + + Only properties which should_appear() is True are returned. + + Args: + message: The ProtoMessage whose fields are iterated. + root: The root ProtoNode of the tree. + + Yields: + An appropriately-typed MessageProperty object for each field + in the message, to which the property refers. + """ + for field in message.fields(): + property_class = PROTO_FIELD_PROPERTIES[field.type()] + prop = property_class(field, message, root) + if prop.should_appear(): + yield prop + + def proto_field_methods(class_type: ClassType, field_type: int) -> List: return ( PROTO_FIELD_WRITE_METHODS[field_type] @@ -2316,31 +2853,41 @@ def generate_to_string_for_enum( def forward_declare( - node: ProtoMessage, + message: ProtoMessage, root: ProtoNode, output: OutputFile, exclude_legacy_snake_case_field_name_enums: bool, ) -> None: """Generates code forward-declaring entities in a message's namespace.""" - namespace = node.cpp_namespace(root=root) + namespace = message.cpp_namespace(root=root) output.write_line() output.write_line(f'namespace {namespace} {{') # Define an enum defining each of the message's fields and their numbers. output.write_line('enum class Fields : uint32_t {') with output.indent(): - for field in node.fields(): + for field in message.fields(): output.write_line(f'{field.enum_name()} = {field.number()},') # Migration support from SNAKE_CASE to kConstantCase. if not exclude_legacy_snake_case_field_name_enums: - for field in node.fields(): + for field in message.fields(): output.write_line( f'{field.legacy_enum_name()} = {field.number()},' ) output.write_line('};') + # Define constants for fixed-size fields. + output.write_line() + for prop in proto_message_field_props(message, root): + max_size = prop.max_size() + if max_size: + output.write_line( + f'static constexpr size_t {prop.max_size_constant_name()} ' + f'= {max_size};' + ) + # Declare the message's message struct. output.write_line() output.write_line('struct Message;') @@ -2355,14 +2902,14 @@ def forward_declare( output.write_line('class StreamDecoder;') # Declare the message's enums. - for child in node.children(): + for child in message.children(): if child.type() == ProtoNode.Type.ENUM: output.write_line() - generate_code_for_enum(cast(ProtoEnum, child), node, output) + generate_code_for_enum(cast(ProtoEnum, child), message, output) output.write_line() - generate_function_for_enum(cast(ProtoEnum, child), node, output) + generate_function_for_enum(cast(ProtoEnum, child), message, output) output.write_line() - generate_to_string_for_enum(cast(ProtoEnum, child), node, output) + generate_to_string_for_enum(cast(ProtoEnum, child), message, output) output.write_line(f'}} // namespace {namespace}') @@ -2378,17 +2925,13 @@ def generate_struct_for_message( # Generate members for each of the message's fields. with output.indent(): cmp: List[str] = [] - for field in message.fields(): - for property_class in PROTO_FIELD_PROPERTIES[field.type()]: - prop = property_class(field, message, root) - if not prop.should_appear(): - continue - - (type_name, name) = prop.struct_member() - output.write_line(f'{type_name} {name};') + for prop in proto_message_field_props(message, root): + type_name = prop.struct_member_type() + name = prop.name() + output.write_line(f'{type_name} {name};') - if not prop.use_callback(): - cmp.append(f'{name} == other.{name}') + if not prop.use_callback(): + cmp.append(f'{name} == other.{name}') # Equality operator output.write_line() @@ -2417,12 +2960,7 @@ def generate_table_for_message( namespace = message.cpp_namespace(root=root) output.write_line(f'namespace {namespace} {{') - properties = [] - for field in message.fields(): - for property_class in PROTO_FIELD_PROPERTIES[field.type()]: - prop = property_class(field, message, root) - if prop.should_appear(): - properties.append(prop) + properties = list(proto_message_field_props(message, root)) output.write_line('PW_MODIFY_DIAGNOSTICS_PUSH();') output.write_line('PW_MODIFY_DIAGNOSTIC(ignored, "-Winvalid-offsetof");') @@ -2469,7 +3007,7 @@ def generate_table_for_message( ) member_list = ', '.join( - [f'message.{prop.struct_member()[1]}' for prop in properties] + [f'message.{prop.name()}' for prop in properties] ) # Generate std::tuple for Message fields. @@ -2505,15 +3043,10 @@ def generate_sizes_for_message( property_sizes: List[str] = [] scratch_sizes: List[str] = [] - for field in message.fields(): - for property_class in PROTO_FIELD_PROPERTIES[field.type()]: - prop = property_class(field, message, root) - if not prop.should_appear(): - continue - - property_sizes.append(prop.max_encoded_size()) - if prop.include_in_scratch_size(): - scratch_sizes.append(prop.max_encoded_size()) + for prop in proto_message_field_props(message, root): + property_sizes.append(prop.max_encoded_size()) + if prop.include_in_scratch_size(): + scratch_sizes.append(prop.max_encoded_size()) output.write_line('inline constexpr size_t kMaxEncodedSizeBytes =') with output.indent(): @@ -2540,19 +3073,53 @@ def generate_sizes_for_message( output.write_line(f'}} // namespace {namespace}') -def generate_is_trivially_comparable_specialization( +def generate_find_functions_for_message( message: ProtoMessage, root: ProtoNode, output: OutputFile ) -> None: - is_trivially_comparable = True + """Creates C++ constants for the encoded sizes of a protobuf message.""" + assert message.type() == ProtoNode.Type.MESSAGE + + namespace = message.cpp_namespace(root=root) + output.write_line(f'namespace {namespace} {{') + for field in message.fields(): - for property_class in PROTO_FIELD_PROPERTIES[field.type()]: - prop = property_class(field, message, root) - if not prop.should_appear(): - continue + if field.is_repeated(): + # Find methods don't account for repeated field semantics, so + # ignore them to avoid confusion. + continue + + try: + methods = PROTO_FIELD_FIND_METHODS[field.type()] + except KeyError: + continue + + for cls in methods: + method = cls(field, message, root, '') + method_signature = ( + f'inline {method.return_type()} ' + f'{method.name()}({method.param_string()})' + ) + + output.write_line() + output.write_line(f'{method_signature} {{') + + with output.indent(): + for line in method.body(): + output.write_line(line) + + output.write_line('}') + + output.write_line(f'}} // namespace {namespace}') + - if prop.use_callback(): - is_trivially_comparable = False - break +def generate_is_trivially_comparable_specialization( + message: ProtoMessage, root: ProtoNode, output: OutputFile +) -> None: + is_trivially_comparable = True + for prop in proto_message_field_props(message, root): + if prop.use_callback(): + is_trivially_comparable = False + break qualified_message = f'{message.cpp_namespace()}::Message' @@ -2626,6 +3193,7 @@ def generate_code_for_package( output.write_line('#include "pw_containers/vector.h"') output.write_line('#include "pw_preprocessor/compiler.h"') output.write_line('#include "pw_protobuf/encoder.h"') + output.write_line('#include "pw_protobuf/find.h"') output.write_line('#include "pw_protobuf/internal/codegen.h"') output.write_line('#include "pw_protobuf/serialized_size.h"') output.write_line('#include "pw_protobuf/stream_decoder.h"') @@ -2675,6 +3243,8 @@ def generate_code_for_package( output.write_line() generate_sizes_for_message(message, package, output) output.write_line() + generate_find_functions_for_message(message, package, output) + output.write_line() generate_class_for_message( message, package, output, ClassType.STREAMING_ENCODER ) @@ -2720,7 +3290,7 @@ def generate_code_for_package( output.write_line(f'using namespace ::{package.cpp_namespace()};') output.write_line(f'}} // namespace {legacy_namespace}') - # TODO(b/250945489) Remove this if possible + # TODO: b/250945489 - Remove this if possible output.write_line() output.write_line( '// Codegen implementation detail; do not use this namespace!' |