diff options
author | Vsevolod Tolstopyatov <qwwdfsad@gmail.com> | 2022-06-24 16:36:32 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-24 17:36:32 +0300 |
commit | 93a06df4bee36dc43b7d906ca395b0ac0d3229f3 (patch) | |
tree | 537cb66847cd8733779cbe63c50cab155f71d4a3 | |
parent | bb18d6243a524a1512da48f897804564e147af1f (diff) | |
download | kotlinx.serialization-93a06df4bee36dc43b7d906ca395b0ac0d3229f3.tar.gz |
Do not use tree-based decoding for fast-path polymorphism (#1919)
Do not use tree-based decoding for fast-path polymorphism; instead, optimistically try to read the class discriminator as the very first key of the object and then silently skip it while decoding
Fixes #1839
15 files changed, 233 insertions, 28 deletions
diff --git a/benchmark/build.gradle b/benchmark/build.gradle index 8e0e4927..0935e5a3 100644 --- a/benchmark/build.gradle +++ b/benchmark/build.gradle @@ -6,13 +6,12 @@ apply plugin: 'java' apply plugin: 'kotlin' apply plugin: 'kotlinx-serialization' apply plugin: 'idea' -apply plugin: 'net.ltgt.apt' apply plugin: 'com.github.johnrengelman.shadow' -apply plugin: 'me.champeau.gradle.jmh' +apply plugin: 'me.champeau.jmh' sourceCompatibility = 1.8 targetCompatibility = 1.8 -jmh.jmhVersion = 1.22 +jmh.jmhVersion = "1.22" jmhJar { baseName 'benchmarks' diff --git a/benchmark/src/jmh/kotlin/kotlinx/benchmarks/json/PolymorphismOverheadBenchmark.kt b/benchmark/src/jmh/kotlin/kotlinx/benchmarks/json/PolymorphismOverheadBenchmark.kt new file mode 100644 index 00000000..b272bae6 --- /dev/null +++ b/benchmark/src/jmh/kotlin/kotlinx/benchmarks/json/PolymorphismOverheadBenchmark.kt @@ -0,0 +1,54 @@ +package kotlinx.benchmarks.json + +import kotlinx.serialization.* +import kotlinx.serialization.json.* +import kotlinx.serialization.modules.* +import org.openjdk.jmh.annotations.* +import java.util.concurrent.* + +@Warmup(iterations = 7, time = 1) +@Measurement(iterations = 5, time = 1) +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +@Fork(1) +open class PolymorphismOverheadBenchmark { + + @Serializable + @JsonClassDiscriminator("poly") + data class PolymorphicWrapper(val i: @Polymorphic Poly, val i2: Impl) // amortize the cost a bit + + @Serializable + data class BaseWrapper(val i: Impl, val i2: Impl) + + @JsonClassDiscriminator("poly") + interface Poly + + @Serializable + @JsonClassDiscriminator("poly") + class Impl(val a: Int, val b: String) : Poly + + private val impl = Impl(239, "average_size_string") + private val module = SerializersModule { + polymorphic(Poly::class) { + subclass(Impl.serializer()) + } + } + + private val json = Json { serializersModule = module } + private val implString = json.encodeToString(impl) + 
private val polyString = json.encodeToString<Poly>(impl) + private val serializer = serializer<Poly>() + + // 5000 + @Benchmark + fun base() = json.decodeFromString(Impl.serializer(), implString) + + // As of 1.3.x + // Baseline -- 1500 + // v1, no skip -- 2000 + // v2, with skip -- 3000 [withdrawn] + @Benchmark + fun poly() = json.decodeFromString(serializer, polyString) + +} diff --git a/build.gradle b/build.gradle index 69aa68dd..60b7e273 100644 --- a/build.gradle +++ b/build.gradle @@ -74,8 +74,7 @@ buildscript { // Various benchmarking stuff classpath "com.github.jengelman.gradle.plugins:shadow:4.0.2" - classpath "me.champeau.gradle:jmh-gradle-plugin:0.5.3" - classpath "net.ltgt.gradle:gradle-apt-plugin:0.21" + classpath "me.champeau.jmh:jmh-gradle-plugin:0.6.6" } } diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/Json.kt b/formats/json/commonMain/src/kotlinx/serialization/json/Json.kt index 4afe9e74..8f1f02fd 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/Json.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/Json.kt @@ -96,7 +96,7 @@ public sealed class Json( */ public final override fun <T> decodeFromString(deserializer: DeserializationStrategy<T>, string: String): T { val lexer = StringJsonLexer(string) - val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor) + val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor, null) val result = input.decodeSerializableValue(deserializer) lexer.expectEof() return result diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/JsonPath.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/JsonPath.kt index 4e055b23..14e70a42 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/JsonPath.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/JsonPath.kt @@ -24,7 +24,7 @@ internal class JsonPath { // Tombstone indicates that we are 
within a map, but the map key is currently being decoded. // It is also used to overwrite a previous map key to avoid memory leaks and misattribution. - object Tombstone + private object Tombstone /* * Serial descriptor, map key or the tombstone for map key diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/Polymorphic.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/Polymorphic.kt index ea65c48a..c1c91264 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/Polymorphic.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/Polymorphic.kt @@ -9,6 +9,7 @@ import kotlinx.serialization.* import kotlinx.serialization.descriptors.* import kotlinx.serialization.internal.* import kotlinx.serialization.json.* +import kotlin.jvm.* @Suppress("UNCHECKED_CAST") internal inline fun <T> JsonEncoder.encodePolymorphically( @@ -55,12 +56,13 @@ internal fun checkKind(kind: SerialKind) { } internal fun <T> JsonDecoder.decodeSerializableValuePolymorphic(deserializer: DeserializationStrategy<T>): T { + // NB: changes in this method should be reflected in StreamingJsonDecoder#decodeSerializableValue if (deserializer !is AbstractPolymorphicSerializer<*> || json.configuration.useArrayPolymorphism) { return deserializer.deserialize(this) } + val discriminator = deserializer.descriptor.classDiscriminator(json) val jsonTree = cast<JsonObject>(decodeJsonElement(), deserializer.descriptor) - val discriminator = deserializer.descriptor.classDiscriminator(json) val type = jsonTree[discriminator]?.jsonPrimitive?.content val actualSerializer = deserializer.findPolymorphicSerializerOrNull(this, type) ?: throwSerializerNotFound(type, jsonTree) @@ -69,7 +71,8 @@ internal fun <T> JsonDecoder.decodeSerializableValuePolymorphic(deserializer: De return json.readPolymorphicJson(discriminator, jsonTree, actualSerializer as DeserializationStrategy<T>) } -private fun throwSerializerNotFound(type: String?, jsonTree: 
JsonObject): Nothing { +@JvmName("throwSerializerNotFound") +internal fun throwSerializerNotFound(type: String?, jsonTree: JsonObject): Nothing { val suffix = if (type == null) "missing class discriminator ('null')" else "class discriminator '$type'" diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonDecoder.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonDecoder.kt index bf229044..403e90de 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonDecoder.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonDecoder.kt @@ -9,6 +9,7 @@ import kotlinx.serialization.descriptors.* import kotlinx.serialization.encoding.* import kotlinx.serialization.encoding.CompositeDecoder.Companion.DECODE_DONE import kotlinx.serialization.encoding.CompositeDecoder.Companion.UNKNOWN_NAME +import kotlinx.serialization.internal.* import kotlinx.serialization.json.* import kotlinx.serialization.modules.* import kotlin.jvm.* @@ -21,11 +22,27 @@ internal open class StreamingJsonDecoder( final override val json: Json, private val mode: WriteMode, @JvmField internal val lexer: AbstractJsonLexer, - descriptor: SerialDescriptor + descriptor: SerialDescriptor, + discriminatorHolder: DiscriminatorHolder? ) : JsonDecoder, AbstractDecoder() { + // A mutable reference to the discriminator that have to be skipped when in optimistic phase + // of polymorphic serialization, see `decodeSerializableValue` + internal class DiscriminatorHolder(@JvmField var discriminatorToSkip: String?) + + private fun DiscriminatorHolder?.trySkip(unknownKey: String): Boolean { + if (this == null) return false + if (discriminatorToSkip == unknownKey) { + discriminatorToSkip = null + return true + } + return false + } + + override val serializersModule: SerializersModule = json.serializersModule private var currentIndex = -1 + private var discriminatorHolder: DiscriminatorHolder? 
= discriminatorHolder private val configuration = json.configuration private val elementMarker: JsonElementMarker? = if (configuration.explicitNulls) null else JsonElementMarker(descriptor) @@ -35,7 +52,40 @@ internal open class StreamingJsonDecoder( @Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T { try { - return decodeSerializableValuePolymorphic(deserializer) + /* + * This is an optimized path over decodeSerializableValuePolymorphic(deserializer): + * dSVP reads the very next JSON tree into a memory as JsonElement and then runs TreeJsonDecoder over it + * in order to deal with an arbitrary order of keys, but with the price of additional memory pressure + * and CPU consumption. + * We would like to provide best possible performance for data produced by kotlinx.serialization + * itself, for that we do the following optimistic optimization: + * + * 0) Remember current position in the string + * 1) Read the very next key of JSON structure + * 2) If it matches* the descriminator key, read the value, remember current position + * 3) Return the value, recover an initial position + * (*) -- if it doesn't match, fallback to dSVP method. + */ + if (deserializer !is AbstractPolymorphicSerializer<*> || json.configuration.useArrayPolymorphism) { + return deserializer.deserialize(this) + } + + val discriminator = deserializer.descriptor.classDiscriminator(json) + val type = lexer.consumeLeadingMatchingValue(discriminator, configuration.isLenient) + var actualSerializer: DeserializationStrategy<out Any>? 
= null + if (type != null) { + actualSerializer = deserializer.findPolymorphicSerializerOrNull(this, type) + } + if (actualSerializer == null) { + // Fallback if we haven't found discriminator or serializer + return decodeSerializableValuePolymorphic<T>(deserializer as DeserializationStrategy<T>) + } + + discriminatorHolder = DiscriminatorHolder(discriminator) + @Suppress("UNCHECKED_CAST") + val result = actualSerializer.deserialize(this) as T + return result + } catch (e: MissingFieldException) { throw MissingFieldException(e.message + " at path: " + lexer.path.getPath(), e) } @@ -52,12 +102,13 @@ internal open class StreamingJsonDecoder( json, newMode, lexer, - descriptor + descriptor, + discriminatorHolder ) else -> if (mode == newMode && json.configuration.explicitNulls) { this } else { - StreamingJsonDecoder(json, newMode, lexer, descriptor) + StreamingJsonDecoder(json, newMode, lexer, descriptor, discriminatorHolder) } } } @@ -193,7 +244,7 @@ internal open class StreamingJsonDecoder( } private fun handleUnknown(key: String): Boolean { - if (configuration.ignoreUnknownKeys) { + if (configuration.ignoreUnknownKeys || discriminatorHolder.trySkip(key)) { lexer.skipElement(configuration.isLenient) } else { // Here we cannot properly update json path indicies diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/AbstractJsonLexer.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/AbstractJsonLexer.kt index 173e54a8..977347a5 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/AbstractJsonLexer.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/AbstractJsonLexer.kt @@ -283,6 +283,8 @@ internal abstract class AbstractJsonLexer { return current } + abstract fun consumeLeadingMatchingValue(keyToMatch: String, isLenient: Boolean): String? + fun peekString(isLenient: Boolean): String? 
{ val token = peekNextToken() val string = if (isLenient) { diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/StringJsonLexer.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/StringJsonLexer.kt index 0ff980a2..9ccfbcc1 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/StringJsonLexer.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/StringJsonLexer.kt @@ -78,10 +78,10 @@ internal class StringJsonLexer(override val source: String) : AbstractJsonLexer( override fun consumeKeyString(): String { /* - * For strings we assume that escaped symbols are rather an exception, so firstly - * we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf', - * than do our pessimistic check for backslash and fallback to slow-path if necessary. - */ + * For strings we assume that escaped symbols are rather an exception, so firstly + * we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf', + * than do our pessimistic check for backslash and fallback to slow-path if necessary. + */ consumeNextToken(STRING) val current = currentPosition val closingQuote = source.indexOf('"', current) @@ -96,4 +96,22 @@ internal class StringJsonLexer(override val source: String) : AbstractJsonLexer( this.currentPosition = closingQuote + 1 return source.substring(current, closingQuote) } + + override fun consumeLeadingMatchingValue(keyToMatch: String, isLenient: Boolean): String? 
{ + val positionSnapshot = currentPosition + try { + // Malformed JSON, bailout + if (consumeNextToken() != TC_BEGIN_OBJ) return null + val firstKey = if (isLenient) consumeKeyString() else consumeStringLenientNotNull() + if (firstKey == keyToMatch) { + if (consumeNextToken() != TC_COLON) return null + val result = if (isLenient) consumeString() else consumeStringLenientNotNull() + return result + } + return null + } finally { + // Restore the position + currentPosition = positionSnapshot + } + } } diff --git a/formats/json/commonTest/src/kotlinx/serialization/features/DefaultPolymorphicSerializerTest.kt b/formats/json/commonTest/src/kotlinx/serialization/features/DefaultPolymorphicSerializerTest.kt new file mode 100644 index 00000000..d2f09f06 --- /dev/null +++ b/formats/json/commonTest/src/kotlinx/serialization/features/DefaultPolymorphicSerializerTest.kt @@ -0,0 +1,35 @@ +/* + * Copyright 2017-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ +package kotlinx.serialization.features + +import kotlinx.serialization.* +import kotlinx.serialization.json.* +import kotlinx.serialization.modules.* +import kotlin.test.* + +class DefaultPolymorphicSerializerTest : JsonTestBase() { + + @Serializable + abstract class Project { + abstract val name: String + } + + @Serializable + data class DefaultProject(override val name: String, val type: String): Project() + + val module = SerializersModule { + polymorphic(Project::class) { + defaultDeserializer { DefaultProject.serializer() } + } + } + + private val json = Json { serializersModule = module } + + @Test + fun test() = parametrizedTest { + assertEquals(DefaultProject("example", "unknown"), + json.decodeFromString<Project>(""" {"type":"unknown","name":"example"}""", it)) + } + +} diff --git a/formats/json/commonTest/src/kotlinx/serialization/json/JsonTestBase.kt b/formats/json/commonTest/src/kotlinx/serialization/json/JsonTestBase.kt index a88e264f..4352aa6b 100644 --- 
a/formats/json/commonTest/src/kotlinx/serialization/json/JsonTestBase.kt +++ b/formats/json/commonTest/src/kotlinx/serialization/json/JsonTestBase.kt @@ -67,7 +67,7 @@ abstract class JsonTestBase { } JsonTestingMode.TREE -> { val lexer = StringJsonLexer(source) - val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor) + val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor, null) val tree = input.decodeJsonElement() lexer.expectEof() readJson(tree, deserializer) diff --git a/formats/json/jvmMain/src/kotlinx/serialization/json/JvmStreams.kt b/formats/json/jvmMain/src/kotlinx/serialization/json/JvmStreams.kt index 3b83299c..be3a64db 100644 --- a/formats/json/jvmMain/src/kotlinx/serialization/json/JvmStreams.kt +++ b/formats/json/jvmMain/src/kotlinx/serialization/json/JvmStreams.kt @@ -61,7 +61,7 @@ public fun <T> Json.decodeFromStream( stream: InputStream ): T { val lexer = ReaderJsonLexer(stream) - val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor) + val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor, null) val result = input.decodeSerializableValue(deserializer) lexer.expectEof() return result diff --git a/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonIterator.kt b/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonIterator.kt index 79003082..3929c840 100644 --- a/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonIterator.kt +++ b/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonIterator.kt @@ -56,7 +56,7 @@ private class JsonIteratorWsSeparated<T>( private val deserializer: DeserializationStrategy<T> ) : Iterator<T> { override fun next(): T = - StreamingJsonDecoder(json, WriteMode.OBJ, lexer, deserializer.descriptor) + StreamingJsonDecoder(json, WriteMode.OBJ, lexer, deserializer.descriptor, null) .decodeSerializableValue(deserializer) override fun hasNext(): Boolean = 
lexer.isNotEof() @@ -75,7 +75,7 @@ private class JsonIteratorArrayWrapped<T>( } else { lexer.consumeNextToken(COMMA) } - val input = StreamingJsonDecoder(json, WriteMode.OBJ, lexer, deserializer.descriptor) + val input = StreamingJsonDecoder(json, WriteMode.OBJ, lexer, deserializer.descriptor, null) return input.decodeSerializableValue(deserializer) } diff --git a/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonLexerJvm.kt b/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonLexerJvm.kt index eabfd088..28ec2cfc 100644 --- a/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonLexerJvm.kt +++ b/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonLexerJvm.kt @@ -133,10 +133,10 @@ internal class ReaderJsonLexer( override fun consumeKeyString(): String { /* - * For strings we assume that escaped symbols are rather an exception, so firstly - * we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf', - * than do our pessimistic check for backslash and fallback to slow-path if necessary. - */ + * For strings we assume that escaped symbols are rather an exception, so firstly + * we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf', + * than do our pessimistic check for backslash and fallback to slow-path if necessary. + */ consumeNextToken(STRING) var current = currentPosition val closingQuote = indexOf('"', current) @@ -174,4 +174,7 @@ internal class ReaderJsonLexer( override fun appendRange(fromIndex: Int, toIndex: Int) { escapedString.append(_source, fromIndex, toIndex - fromIndex) } + + // Can be carefully implemented but postponed for now + override fun consumeLeadingMatchingValue(keyToMatch: String, isLenient: Boolean): String? 
= null } diff --git a/formats/json/jvmTest/src/kotlinx/serialization/features/JsonJvmStreamsTest.kt b/formats/json/jvmTest/src/kotlinx/serialization/features/JsonJvmStreamsTest.kt index b576a2c1..0de89d9c 100644 --- a/formats/json/jvmTest/src/kotlinx/serialization/features/JsonJvmStreamsTest.kt +++ b/formats/json/jvmTest/src/kotlinx/serialization/features/JsonJvmStreamsTest.kt @@ -4,11 +4,11 @@ package kotlinx.serialization.features -import kotlinx.serialization.SerializationException -import kotlinx.serialization.StringData +import kotlinx.serialization.* import kotlinx.serialization.builtins.serializer import kotlinx.serialization.json.* import kotlinx.serialization.json.internal.BATCH_SIZE +import kotlinx.serialization.modules.* import kotlinx.serialization.test.* import org.junit.Test import java.io.ByteArrayInputStream @@ -85,4 +85,45 @@ class JsonJvmStreamsTest { } } + interface Poly + + @Serializable + @SerialName("Impl") + data class Impl(val str: String) : Poly + + @Test + fun testPolymorphismWhenCrossingBatchSizeNonLeadingKey() { + val json = Json { + serializersModule = SerializersModule { + polymorphic(Poly::class) { + subclass(Impl::class, Impl.serializer()) + } + } + } + + val longString = "a".repeat(BATCH_SIZE - 5) + val string = """{"str":"$longString", "type":"Impl"}""" + val golden = Impl(longString) + + val deserialized = json.decodeViaStream(serializer<Poly>(), string) + assertEquals(golden, deserialized as Impl) + } + + @Test + fun testPolymorphismWhenCrossingBatchSize() { + val json = Json { + serializersModule = SerializersModule { + polymorphic(Poly::class) { + subclass(Impl::class, Impl.serializer()) + } + } + } + + val aLotOfWhiteSpaces = " ".repeat(BATCH_SIZE - 5) + val string = """{$aLotOfWhiteSpaces"type":"Impl", "str":"value"}""" + val golden = Impl("value") + + val deserialized = json.decodeViaStream(serializer<Poly>(), string) + assertEquals(golden, deserialized as Impl) + } } |