summaryrefslogtreecommitdiff
path: root/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/schema/ProtoBufSchemaGenerator.kt
blob: e54370fbaaff5f1210aa11896a28c3c412a4038a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
@file:OptIn(ExperimentalSerializationApi::class)

package kotlinx.serialization.protobuf.schema

import kotlinx.serialization.*
import kotlinx.serialization.builtins.*
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.protobuf.*
import kotlinx.serialization.protobuf.internal.*

/**
 * Experimental generator of ProtoBuf schema that is compatible with [serializable][Serializable] Kotlin classes
 * and data encoded and decoded by [ProtoBuf] format.
 *
 * The schema is generated based on provided [SerialDescriptor] and is compatible with proto2 schema definition.
 * An arbitrary Kotlin class represents a much wider object domain than the ProtoBuf specification, thus the schema
 * generator has the following list of restrictions:
 *
 *  * Serial name of the class and all its fields should be a valid Proto [identifier](https://developers.google.com/protocol-buffers/docs/reference/proto2-spec)
 *  * Nullable values are allowed only for Kotlin [nullable][SerialDescriptor.isNullable] types, but not [optional][SerialDescriptor.isElementOptional]
 *    in order to properly distinguish "default" and "absent" values.
 *  * The name of the type without the package directive uniquely identifies the proto message type and two or more fields with the same serial name
 *    are considered to have the same type. The schema generator allows specifying a separate package directive for the pack of classes in order
 *    to mitigate this limitation.
 *  * Nested collections, e.g. `List<List<Int>>`, are represented using an artificial wrapper message in order to distinguish
 *    repeated fields boundaries.
 *  * Default Kotlin values are not representable in proto schema. A special commentary is generated for properties with default values.
 *  * Empty nullable collections are not supported by the generated schema and will be prohibited in [ProtoBuf] as well
 *    due to their ambiguous nature.
 *
 * Temporary restrictions:
 *  * [Contextual] data is represented as `bytes` type
 *  * [Polymorphic] data is represented as an artificial `KotlinxSerializationPolymorphic` message.
 *
 * Other types are mapped according to their specification: primitives as primitives, lists as 'repeated' fields,
 * maps with `Int`, `Long`, `Boolean` or `String` keys as `map<K, V>` fields, and maps with any other key type as
 * 'repeated' artificial entry messages.
 *
 * The name of messages and enums is extracted from [SerialDescriptor.serialName] in [SerialDescriptor] without the package directive,
 * as substring after the last dot character, the `'?'` character is also removed if it is present at the end of the string.
 */
@ExperimentalSerializationApi
public object ProtoBufSchemaGenerator {

    /**
     * Generates text of protocol buffers schema version 2 for the given [rootDescriptor].
     * The resulting schema will contain all types referred by [rootDescriptor].
     *
     * [packageName] defines a common protobuf package for all messages and enums in the schema. It may contain
     * `'a'`..`'z'` letters in upper and lower case, decimal digits, `'.'` or `'_'` chars. Every dot-separated
     * segment must start with a letter, so the name can neither start nor end with a dot.
     *
     * [options] defines values for protobuf options. An option value (map value) can be any string; an option name
     * (map key) should have the same format as [packageName].
     *
     * The method throws [IllegalArgumentException] if any of the restrictions imposed by [ProtoBufSchemaGenerator] is violated.
     */
    @ExperimentalSerializationApi
    public fun generateSchemaText(
        rootDescriptor: SerialDescriptor,
        packageName: String? = null,
        options: Map<String, String> = emptyMap()
    ): String = generateSchemaText(listOf(rootDescriptor), packageName, options)

    /**
     * Generates text of protocol buffers schema version 2 for the given serializable [descriptors].
     * The resulting schema will contain all types referred by [descriptors].
     *
     * [packageName] defines a common protobuf package for all messages and enums in the schema. It may contain
     * `'a'`..`'z'` letters in upper and lower case, decimal digits, `'.'` or `'_'` chars. Every dot-separated
     * segment must start with a letter, so the name can neither start nor end with a dot.
     *
     * [options] defines values for protobuf options. An option value (map value) can be any string; an option name
     * (map key) should have the same format as [packageName].
     *
     * The method throws [IllegalArgumentException] if any of the restrictions imposed by [ProtoBufSchemaGenerator] is violated.
     */
    @ExperimentalSerializationApi
    public fun generateSchemaText(
        descriptors: List<SerialDescriptor>,
        packageName: String? = null,
        options: Map<String, String> = emptyMap()
    ): String {
        packageName?.let { p -> p.checkIsValidFullIdentifier { "Incorrect protobuf package name '$it'" } }
        checkDoubles(descriptors)
        val builder = StringBuilder()
        builder.generateProto2SchemaText(descriptors, packageName, options)
        return builder.toString()
    }

    /**
     * Ensures that no two root [descriptors] resolve to the same message/enum name ([messageOrEnumName]).
     *
     * @throws IllegalArgumentException listing every duplicated name.
     */
    private fun checkDoubles(descriptors: List<SerialDescriptor>) {
        val rootTypesNames = mutableSetOf<String>()
        val duplicates = mutableListOf<String>()

        for (name in descriptors.map { it.messageOrEnumName }) {
            if (!rootTypesNames.add(name)) {
                duplicates += name
            }
        }
        if (duplicates.isNotEmpty()) {
            throw IllegalArgumentException("Serial names of the following types are duplicated: $duplicates")
        }
    }

    /**
     * Appends the complete proto2 schema: syntax line, optional package, options, then every message and enum
     * reachable from [descriptors].
     *
     * Types are visited breadth-first. A type is emitted only the first time its [messageOrEnumName] is seen, so
     * distinct descriptors sharing a name produce a single definition.
     */
    private fun StringBuilder.generateProto2SchemaText(
        descriptors: List<SerialDescriptor>,
        packageName: String?,
        options: Map<String, String>
    ) {
        appendLine("""syntax = "proto2";""").appendLine()

        packageName?.let { append("package ").append(it).appendLine(';') }

        for ((optionName, optionValue) in options) {
            val safeOptionName = removeLineBreaks(optionName)
            // The value is written between double quotes of a proto string literal, so backslashes and quotes
            // must be escaped to keep the generated schema syntactically valid.
            val safeOptionValue = escapeProtoStringLiteral(removeLineBreaks(optionValue))
            safeOptionName.checkIsValidFullIdentifier { "Invalid option name '$it'" }
            append("option ").append(safeOptionName).append(" = \"").append(safeOptionValue).appendLine("\";")
        }

        val generatedTypes = mutableSetOf<String>()
        val queue = ArrayDeque(descriptors.map { TypeDefinition(it) })

        while (queue.isNotEmpty()) {
            val type = queue.removeFirst()
            val descriptor = type.descriptor
            val name = descriptor.messageOrEnumName
            if (!generatedTypes.add(name)) {
                continue
            }

            appendLine()
            when {
                descriptor.isProtobufMessage -> queue.addAll(generateMessage(type))
                descriptor.isProtobufEnum -> generateEnum(type)
                else -> throw IllegalStateException(
                    "Unrecognized custom type with serial name "
                            + "'${descriptor.serialName}' and kind '${descriptor.kind}'"
                )
            }
        }
    }

    /**
     * Appends the `message` definition for [messageType] and returns the types referenced by its fields. The
     * caller still has to generate those.
     *
     * Synthetic wrapper messages are named by their raw serial name and preceded by a comment explaining their
     * origin. Regular messages are named by [messageOrEnumName], which must be a valid proto identifier.
     *
     * Field numbers come from [ProtoNumber] when present, otherwise from the 1-based element index.
     *
     * @throws IllegalArgumentException if the message or a field name is not a valid proto identifier, or if two
     * fields resolve to the same field number.
     */
    private fun StringBuilder.generateMessage(messageType: TypeDefinition): List<TypeDefinition> {
        val messageDescriptor = messageType.descriptor
        val messageName: String
        if (messageType.isSynthetic) {
            append("// This message was generated to support ").append(messageType.ability)
                .appendLine(" and does not present in Kotlin.")

            messageName = messageDescriptor.serialName
            if (messageType.containingMessageName != null) {
                append("// Containing message '").append(messageType.containingMessageName).append("', field '")
                    .append(messageType.fieldName).appendLine('\'')
            }
        } else {
            messageName = messageDescriptor.messageOrEnumName
            messageName.checkIsValidIdentifier {
                "Invalid name for the message in protobuf schema '$messageName'. " +
                        "Serial name of the class '${messageDescriptor.serialName}'"
            }
            val safeSerialName = removeLineBreaks(messageDescriptor.serialName)
            if (safeSerialName != messageName) {
                append("// serial name '").append(safeSerialName).appendLine('\'')
            }
        }

        append("message ").append(messageName).appendLine(" {")

        val usedNumbers = mutableSetOf<Int>()
        val nestedTypes = mutableListOf<TypeDefinition>()
        for (index in 0 until messageDescriptor.elementsCount) {
            val fieldName = messageDescriptor.getElementName(index)
            fieldName.checkIsValidIdentifier {
                "Invalid name of the field '$fieldName' in message '$messageName' for class with serial " +
                        "name '${messageDescriptor.serialName}'"
            }

            val fieldDescriptor = messageDescriptor.getElementDescriptor(index)

            // True both for lists and for "legacy" maps whose key type cannot be a proto map key.
            val isRepeated = fieldDescriptor.isProtobufRepeated

            // Each generator appends any warning comment lines followed by the field label and type, leaving the
            // line open for the field name and number appended below.
            nestedTypes += when {
                fieldDescriptor.isProtobufNamedType -> generateNamedType(messageType, index)
                isRepeated -> generateListType(messageType, index)
                fieldDescriptor.isProtobufMap -> generateMapType(messageType, index)
                else -> throw IllegalStateException(
                    "Unprocessed message field type with serial name " +
                            "'${fieldDescriptor.serialName}' and kind '${fieldDescriptor.kind}'"
                )
            }

            val annotations = messageDescriptor.getElementAnnotations(index)
            val number = annotations.filterIsInstance<ProtoNumber>().singleOrNull()?.number ?: (index + 1)
            if (!usedNumbers.add(number)) {
                throw IllegalArgumentException("Field number $number is repeated in the class with serial name ${messageDescriptor.serialName}")
            }

            append(' ').append(fieldName).append(" = ").append(number)

            val isPackRequested = annotations.filterIsInstance<ProtoPacked>().singleOrNull() != null
            // `packed` is only meaningful on repeated scalar fields. Legacy maps are also "repeated", but they are
            // emitted as repeated synthetic entry messages. Their element 0 is the map key, not the emitted type,
            // so they must never be marked packed.
            val isPacked = isPackRequested &&
                    isRepeated &&
                    fieldDescriptor.kind == StructureKind.LIST &&
                    fieldDescriptor.getElementDescriptor(0).isPackable
            appendLine(if (isPacked) " [packed=true];" else ";")
        }
        appendLine('}')

        return nestedTypes
    }

    /**
     * Appends the label and type of the non-collection field at [index] of [messageType]. The field may be a
     * scalar, an enum, a message, or polymorphic data.
     *
     * The label is `optional` when the field is nullable or has a default value, and `required` otherwise. Returns
     * the message/enum types that the field refers to.
     */
    private fun StringBuilder.generateNamedType(messageType: TypeDefinition, index: Int): List<TypeDefinition> {
        val messageDescriptor = messageType.descriptor

        val fieldDescriptor = messageDescriptor.getElementDescriptor(index)
        val nestedTypes: List<TypeDefinition>
        val typeName: String = when {
            // The second element of a sealed class descriptor holds the subclasses as its element descriptors;
            // list them in a comment and render the field itself via scalarTypeName().
            messageDescriptor.isSealedPolymorphic && index == 1 -> {
                appendLine("  // decoded as message with one of these types:")
                nestedTypes = fieldDescriptor.elementDescriptors.map { TypeDefinition(it) }
                nestedTypes.forEach { childType ->
                    append("  //   message ").append(childType.descriptor.messageOrEnumName).append(", serial name '")
                        .append(removeLineBreaks(childType.descriptor.serialName)).appendLine('\'')
                }
                fieldDescriptor.scalarTypeName()
            }
            fieldDescriptor.isProtobufScalar -> {
                nestedTypes = emptyList()
                fieldDescriptor.scalarTypeName(messageDescriptor.getElementAnnotations(index))
            }
            fieldDescriptor.isOpenPolymorphic -> {
                nestedTypes = listOf(SyntheticPolymorphicType)
                SyntheticPolymorphicType.descriptor.serialName
            }
            else -> {
                // enum or regular message
                nestedTypes = listOf(TypeDefinition(fieldDescriptor))
                fieldDescriptor.messageOrEnumName
            }
        }

        if (messageDescriptor.isElementOptional(index)) {
            appendLine("  // WARNING: a default value decoded when value is missing")
        }
        val optional = fieldDescriptor.isNullable || messageDescriptor.isElementOptional(index)

        append("  ").append(if (optional) "optional " else "required ").append(typeName)

        return nestedTypes
    }

    /**
     * Appends a `map<K, V>` field for the map at [index] of [messageType]; the key type is a valid proto map key.
     *
     * A collection-typed value is replaced by a synthetic wrapper message. Returns the value type when it is a
     * message or an enum, so that it gets generated too.
     */
    private fun StringBuilder.generateMapType(messageType: TypeDefinition, index: Int): List<TypeDefinition> {
        val messageDescriptor = messageType.descriptor
        val mapDescriptor = messageDescriptor.getElementDescriptor(index)
        val originalMapValueDescriptor = mapDescriptor.getElementDescriptor(1)
        val valueType = if (originalMapValueDescriptor.isProtobufCollection) {
            createNestedCollectionType(messageType, index, originalMapValueDescriptor, "nested collection in map value")
        } else {
            TypeDefinition(originalMapValueDescriptor)
        }
        val valueDescriptor = valueType.descriptor

        if (originalMapValueDescriptor.isNullable) {
            appendLine("  // WARNING: nullable map values can not be represented in protobuf")
        }
        generateCollectionAbsenceComment(messageDescriptor, mapDescriptor, index)

        val keyTypeName = mapDescriptor.getElementDescriptor(0).scalarTypeName(mapDescriptor.getElementAnnotations(0))
        val valueTypeName = valueDescriptor.protobufTypeName(mapDescriptor.getElementAnnotations(1))
        append("  map<").append(keyTypeName).append(", ").append(valueTypeName).append(">")

        return if (valueDescriptor.isProtobufMessageOrEnum) listOf(valueType) else emptyList()
    }

    /**
     * Appends a `repeated` field for the collection at [index] of [messageType].
     *
     * For lists, a collection-typed element is replaced by a synthetic wrapper message. For legacy maps (key type
     * not allowed in `map<K, V>`), each entry becomes a synthetic key/value message. Returns the element type
     * when it is a message or an enum.
     */
    private fun StringBuilder.generateListType(messageType: TypeDefinition, index: Int): List<TypeDefinition> {
        val messageDescriptor = messageType.descriptor
        val collectionDescriptor = messageDescriptor.getElementDescriptor(index)
        val originalElementDescriptor = collectionDescriptor.getElementDescriptor(0)
        val elementType = when {
            collectionDescriptor.kind != StructureKind.LIST -> createLegacyMapType(messageType, index, "legacy map")
            originalElementDescriptor.isProtobufCollection ->
                createNestedCollectionType(messageType, index, originalElementDescriptor, "nested collection in list")
            else -> TypeDefinition(originalElementDescriptor)
        }

        val elementDescriptor = elementType.descriptor

        if (elementDescriptor.isNullable) {
            appendLine("  // WARNING: nullable elements of collections can not be represented in protobuf")
        }
        generateCollectionAbsenceComment(messageDescriptor, collectionDescriptor, index)

        // Field-level annotations (e.g. ProtoType) apply to the element type of the repeated field.
        val typeName = elementDescriptor.protobufTypeName(messageDescriptor.getElementAnnotations(index))
        append("  repeated ").append(typeName)

        return if (elementDescriptor.isProtobufMessageOrEnum) listOf(elementType) else emptyList()
    }

    /**
     * Appends the `enum` definition for [enumType]. Values are numbered in declaration order starting from 0.
     *
     * @throws IllegalArgumentException if the enum name or any element name is not a valid proto identifier.
     */
    private fun StringBuilder.generateEnum(enumType: TypeDefinition) {
        val enumDescriptor = enumType.descriptor
        val enumName = enumDescriptor.messageOrEnumName
        enumName.checkIsValidIdentifier {
            "Invalid name for the enum in protobuf schema '$enumName'. Serial name of the enum " +
                    "class '${enumDescriptor.serialName}'"
        }
        val safeSerialName = removeLineBreaks(enumDescriptor.serialName)
        if (safeSerialName != enumName) {
            // Record the original serial name (not the shortened enum name), as generateMessage does.
            append("// serial name '").append(safeSerialName).appendLine('\'')
        }

        append("enum ").append(enumName).appendLine(" {")

        enumDescriptor.elementDescriptors.forEachIndexed { number, element ->
            val elementName = element.protobufEnumElementName
            elementName.checkIsValidIdentifier {
                "The enum element name '$elementName' is invalid in the " +
                        "protobuf schema. Serial name of the enum class '${enumDescriptor.serialName}'"
            }
            append("  ").append(elementName).append(" = ").append(number).appendLine(';')
        }
        appendLine('}')
    }

    private val SerialDescriptor.isOpenPolymorphic: Boolean
        get() = kind == PolymorphicKind.OPEN

    private val SerialDescriptor.isSealedPolymorphic: Boolean
        get() = kind == PolymorphicKind.SEALED

    /** A field rendered as a single named type: scalar, enum or message. */
    private val SerialDescriptor.isProtobufNamedType: Boolean
        get() = isProtobufMessageOrEnum || isProtobufScalar

    /** Primitives, byte lists (rendered as `bytes`) and contextual data (also rendered as `bytes`). */
    private val SerialDescriptor.isProtobufScalar: Boolean
        get() = (kind is PrimitiveKind)
                || (kind is StructureKind.LIST && getElementDescriptor(0).kind === PrimitiveKind.BYTE)
                || kind == SerialKind.CONTEXTUAL

    private val SerialDescriptor.isProtobufMessageOrEnum: Boolean
        get() = isProtobufMessage || isProtobufEnum

    private val SerialDescriptor.isProtobufMessage: Boolean
        get() = kind == StructureKind.CLASS || kind == StructureKind.OBJECT || kind == PolymorphicKind.SEALED || kind == PolymorphicKind.OPEN

    private val SerialDescriptor.isProtobufCollection: Boolean
        get() = isProtobufRepeated || isProtobufMap

    /** Non-byte lists, plus "legacy" maps whose key type cannot be used in a proto `map<K, V>`. */
    private val SerialDescriptor.isProtobufRepeated: Boolean
        get() = (kind == StructureKind.LIST && getElementDescriptor(0).kind != PrimitiveKind.BYTE)
                || (kind == StructureKind.MAP && !getElementDescriptor(0).isValidMapKey)

    /** Maps whose key type can be used in a proto `map<K, V>` field. */
    private val SerialDescriptor.isProtobufMap: Boolean
        get() = kind == StructureKind.MAP && getElementDescriptor(0).isValidMapKey

    private val SerialDescriptor.isProtobufEnum: Boolean
        get() = kind == SerialKind.ENUM

    private val SerialDescriptor.isValidMapKey: Boolean
        get() = kind == PrimitiveKind.INT || kind == PrimitiveKind.LONG || kind == PrimitiveKind.BOOLEAN || kind == PrimitiveKind.STRING

    /** Serial name after the last dot, without a trailing `'?'`. */
    private val SerialDescriptor.messageOrEnumName: String
        get() = (serialName.substringAfterLast('.', serialName)).removeSuffix("?")

    /** Scalar type name (honouring [annotations]) for scalars, otherwise the message/enum name. */
    private fun SerialDescriptor.protobufTypeName(annotations: List<Annotation> = emptyList()): String =
        if (isProtobufScalar) scalarTypeName(annotations) else messageOrEnumName

    private val SerialDescriptor.protobufEnumElementName: String
        get() = serialName.substringAfterLast('.', serialName)

    /**
     * Proto scalar type name for a scalar descriptor.
     *
     * Contextual data and byte lists map to `bytes`. Integer kinds pick their encoding from the first [ProtoType]
     * in [annotations], falling back to [ProtoIntegerType.DEFAULT].
     */
    private fun SerialDescriptor.scalarTypeName(annotations: List<Annotation> = emptyList()): String {
        if (kind == SerialKind.CONTEXTUAL) {
            return "bytes"
        }

        if (kind is StructureKind.LIST && getElementDescriptor(0).kind == PrimitiveKind.BYTE) {
            return "bytes"
        }

        val integerType = annotations.filterIsInstance<ProtoType>().firstOrNull()?.type ?: ProtoIntegerType.DEFAULT

        return when (kind as PrimitiveKind) {
            PrimitiveKind.BOOLEAN -> "bool"
            PrimitiveKind.BYTE, PrimitiveKind.CHAR, PrimitiveKind.SHORT, PrimitiveKind.INT ->
                when (integerType) {
                    ProtoIntegerType.DEFAULT -> "int32"
                    ProtoIntegerType.SIGNED -> "sint32"
                    ProtoIntegerType.FIXED -> "fixed32"
                }
            PrimitiveKind.LONG ->
                when (integerType) {
                    ProtoIntegerType.DEFAULT -> "int64"
                    ProtoIntegerType.SIGNED -> "sint64"
                    ProtoIntegerType.FIXED -> "fixed64"
                }
            PrimitiveKind.FLOAT -> "float"
            PrimitiveKind.DOUBLE -> "double"
            PrimitiveKind.STRING -> "string"
        }
    }

    /**
     * A type scheduled for generation.
     *
     * For synthetic types, [ability] describes why the type was created. [containingMessageName] and [fieldName]
     * identify the field it was created for.
     */
    private data class TypeDefinition(
        val descriptor: SerialDescriptor,
        val isSynthetic: Boolean = false,
        val ability: String? = null,
        val containingMessageName: String? = null,
        val fieldName: String? = null
    )

    /** Artificial message used for open polymorphic fields: a `string` type name plus `bytes` value. */
    private val SyntheticPolymorphicType = TypeDefinition(
        buildClassSerialDescriptor("KotlinxSerializationPolymorphic") {
            element("type", PrimitiveSerialDescriptor("typeDescriptor", PrimitiveKind.STRING))
            element("value", buildSerialDescriptor("valueDescriptor", StructureKind.LIST) {
                element("0", Byte.serializer().descriptor)
            })
        },
        true,
        "polymorphic types"
    )

    /** Delegating view of [original] that reports itself as non-nullable. */
    private class NotNullSerialDescriptor(val original: SerialDescriptor) : SerialDescriptor by original {
        override val isNullable = false
    }

    private val SerialDescriptor.notNull get() = NotNullSerialDescriptor(this)

    /**
     * Appends a warning comment describing what is decoded when the collection field at [index] is absent from
     * the input.
     *
     * A field with a default value gets the default-value warning. A nullable field without a default gets the
     * empty-collection warning. Otherwise nothing is appended.
     */
    private fun StringBuilder.generateCollectionAbsenceComment(
        messageDescriptor: SerialDescriptor,
        collectionDescriptor: SerialDescriptor,
        index: Int
    ) {
        when {
            messageDescriptor.isElementOptional(index) ->
                appendLine("  // WARNING: a default value decoded when value is missing")
            collectionDescriptor.isNullable ->
                appendLine("  // WARNING: an empty collection decoded when a value is missing")
        }
    }

    /**
     * Creates the synthetic `<Message>_<field>` entry message (`key`, `value`) for a map field at [index] whose
     * key type cannot be used in a proto `map<K, V>`.
     */
    private fun createLegacyMapType(
        messageType: TypeDefinition,
        index: Int,
        description: String
    ): TypeDefinition {
        val messageDescriptor = messageType.descriptor
        val fieldDescriptor = messageDescriptor.getElementDescriptor(index)
        val fieldName = messageDescriptor.getElementName(index)
        val messageName = messageDescriptor.messageOrEnumName

        val wrapperName = "${messageName}_${fieldName}"
        val wrapperDescriptor = buildClassSerialDescriptor(wrapperName) {
            element("key", fieldDescriptor.getElementDescriptor(0).notNull)
            element("value", fieldDescriptor.getElementDescriptor(1).notNull)
        }

        return TypeDefinition(
            wrapperDescriptor,
            true,
            description,
            messageType.containingMessageName ?: messageName,
            messageType.fieldName ?: fieldName
        )
    }

    /**
     * Creates the synthetic `<Message>_<field>` message with a single `value` element. It wraps a collection
     * ([elementDescriptor]) nested inside the collection field at [index].
     */
    private fun createNestedCollectionType(
        messageType: TypeDefinition,
        index: Int,
        elementDescriptor: SerialDescriptor,
        description: String
    ): TypeDefinition {
        val messageDescriptor = messageType.descriptor
        val fieldName = messageDescriptor.getElementName(index)
        val messageName = messageDescriptor.messageOrEnumName

        val wrapperName = "${messageName}_${fieldName}"
        val wrapperDescriptor = buildClassSerialDescriptor(wrapperName) {
            element("value", elementDescriptor.notNull)
        }

        return TypeDefinition(
            wrapperDescriptor,
            true,
            description,
            messageType.containingMessageName ?: messageName,
            messageType.fieldName ?: fieldName
        )
    }

    /** Replaces `\n` and `\r` with spaces so user text cannot break out of a single-line schema comment/literal. */
    private fun removeLineBreaks(text: String): String {
        return text.replace('\n', ' ').replace('\r', ' ')
    }

    /** Escapes `\` and `"` so [text] can be placed between the double quotes of a proto string literal. */
    private fun escapeProtoStringLiteral(text: String): String =
        text.replace("\\", "\\\\").replace("\"", "\\\"")

    private val IDENTIFIER_REGEX = Regex("[A-Za-z][A-Za-z0-9_]*")

    /** Checks that every dot-separated segment is a valid proto identifier; throws [IllegalArgumentException] otherwise. */
    private fun String.checkIsValidFullIdentifier(messageSupplier: (String) -> String) {
        if (split('.').any { !it.matches(IDENTIFIER_REGEX) }) {
            throw IllegalArgumentException(messageSupplier.invoke(this))
        }
    }

    /** Checks that the string is a valid proto identifier; throws [IllegalArgumentException] otherwise. */
    private fun String.checkIsValidIdentifier(messageSupplier: () -> String) {
        if (!matches(IDENTIFIER_REGEX)) {
            throw IllegalArgumentException(messageSupplier.invoke())
        }
    }
}