aboutsummaryrefslogtreecommitdiff
path: root/atomicfu/src/jvmMain/kotlin/kotlinx/atomicfu/LockFreedomTestEnvironment.kt
blob: 0208febc48bfdfda5df7dbde5ad050c163184531 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
/*
 * Copyright 2017-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */

@file:Suppress("RedundantVisibilityModifier")

package kotlinx.atomicfu

import java.util.*
import java.util.concurrent.atomic.*
import java.util.concurrent.locks.*
import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*

/** At every intercepted atomic operation a test thread pauses with probability `1 / PAUSE_EVERY_N_STEPS`. */
private const val PAUSE_EVERY_N_STEPS = 1000

/**
 * A thread that has not completed a lock-free operation for this long is reported as stalled.
 * Also bounds how long shutdown waits for test threads to terminate.
 */
private const val STALL_LIMIT_MS = 15_000L // 15s

/** Polling interval while waiting for test threads to terminate during shutdown. */
private const val SHUTDOWN_CHECK_MS = 10L // 10ms

/** Value of the environment status once the test is done; no new pauses start after it is set. */
private const val STATUS_DONE = Int.MAX_VALUE

/** A paused thread parks for at most 1ms at a time, just in case an unpark signal is lost. */
private const val MAX_PARK_NANOS = 1_000_000L

/**
 * Environment for performing lock-freedom tests for lock-free data structures
 * that are written with [atomic] variables.
 *
 * Usage: create an instance, define at least `2 + allowSuspendedThreads` threads with [testThread]
 * (optionally registering [onCompletion] blocks), then call [performTest] exactly once.
 *
 * While the test runs, each intercepted atomic update/set/read-modify-write performed by a test
 * thread may park that thread (with probability `1 / PAUSE_EVERY_N_STEPS`; at most one thread is
 * paused at a time). A paused thread is resumed only after every other thread has completed a
 * lock-free operation that started during the pause. If the others cannot make progress while one
 * thread is stopped mid-operation, the test fails as stalled.
 *
 * @param name name of the test; printed in the log and used as a prefix of test-thread names.
 * @param allowSuspendedThreads how many threads may be waiting on a suspended operation at the same
 *   time while still being counted as making progress.
 */
public open class LockFreedomTestEnvironment(
    private val name: String,
    private val allowSuspendedThreads: Int = 0
) {
    private val interceptor = Interceptor()
    private val threads = mutableListOf<TestThread>()
    private val performedOps = LongAdder()
    private val uncaughtException = AtomicReference<Throwable?>()
    private var started = false

    // Written by test threads in resumeImpl() and read by the main thread in resumeStr();
    // @Volatile makes the progress reports see an up-to-date value.
    @Volatile
    private var performedResumes = 0

    @Volatile
    private var completed = false
    private val onCompletion = mutableListOf<() -> Unit>()

    // Records the first uncaught exception of any test thread: performTest() stops its loop once it is set,
    // and shutdown() rethrows it.
    private val ueh = Thread.UncaughtExceptionHandler { t, e ->
        synchronized(System.out) {
            println("Uncaught exception in thread $t")
            e.printStackTrace(System.out)
            uncaughtException.compareAndSet(null, e)
        }
    }

    // status < 0             - inv() of the index of the currently paused thread
    // status >= 0            - no. of performed resumes so far (== last epoch)
    // status == STATUS_DONE  - done working
    private val status = AtomicInteger()
    // Number of threads credited with progress in the current pause epoch (see TestThread.makeProgress).
    private val globalPauseProgress = AtomicInteger()
    // Threads currently waiting for their suspended operation to be resumed (see addSuspended/removeSuspended).
    private val suspendedThreads = ArrayList<TestThread>()

    // Cleared by complete() so that test threads leave their run() loop.
    @Volatile
    private var isActive = true

    // ---------- API ----------

    /**
     * Starts lock-freedom test for a given duration in seconds,
     * invoking [progress] every second (`seconds + 1` times while the test runs)
     * and once more after all test threads have been shut down.
     *
     * Can be called at most once per instance, and only after at least `2 + allowSuspendedThreads`
     * threads were defined with [testThread]. The test stops early if a test thread throws an
     * uncaught exception; that exception is rethrown after shutdown.
     *
     * @throws IllegalStateException if called more than once, if too few threads were defined,
     *   if some thread stopped making progress, or if threads failed to shut down in time.
     */
    public fun performTest(seconds: Int, progress: () -> Unit = {}) {
        check(isActive) { "Can perform test at most once on this instance" }
        println("=== $name")
        val minThreads = 2 + allowSuspendedThreads
        check(threads.size >= minThreads) { "Must define at least $minThreads test threads" }
        lockAndSetInterceptor(interceptor)
        started = true
        var nextTime = System.currentTimeMillis()
        threads.forEach { thread ->
            thread.setUncaughtExceptionHandler(ueh)
            thread.lastOpTime = nextTime
            thread.start()
        }
        try {
            var second = 0
            while (uncaughtException.get() == null) {
                waitUntil(nextTime)
                println("--- $second: Performed ${performedOps.sum()} operations${resumeStr()}")
                progress()
                checkStalled()
                if (++second > seconds) break
                nextTime += 1000L
            }
        } finally {
            complete()
        }
        println("------ Done with ${performedOps.sum()} operations${resumeStr()}")
        progress()
    }

    /**
     * Stops the test. Runs [onCompletion] blocks, signals all threads to stop, and then waits up to
     * [STALL_LIMIT_MS] for every thread that is not paused to terminate. Always ends with [shutdown].
     *
     * @throws IllegalStateException if some threads failed to terminate in time.
     */
    private fun complete() {
        val activeNonPausedThreads: MutableMap<TestThread, Array<StackTraceElement>> = mutableMapOf()
        val shutdownDeadline = System.currentTimeMillis() + STALL_LIMIT_MS
        try {
            completed = true
            // perform custom completion blocks. For testing of things like channels, these custom completion
            // blocks close all the channels, so that all suspended coroutines shall get resumed.
            onCompletion.forEach { it() }
            // signal shutdown to all threads (non-paused threads will terminate)
            isActive = false
            // wait for threads to terminate
            while (System.currentTimeMillis() < shutdownDeadline) {
                // Check all threads while shutting down:
                // All terminated threads are considered to make progress for the purpose of resuming stalled ones
                activeNonPausedThreads.clear()
                for (t in threads) {
                    when {
                        !t.isAlive -> t.makeProgress(getPausedEpoch()) // not alive - makes progress
                        t.index.inv() == status.get() -> {} // active, paused -- skip
                        else -> {
                            val stackTrace = t.stackTrace
                            if (t.isAlive) activeNonPausedThreads[t] = stackTrace
                        }
                    }
                }
                if (activeNonPausedThreads.isEmpty()) break
                checkStalled()
                Thread.sleep(SHUTDOWN_CHECK_MS)
            }
            activeNonPausedThreads.forEach { (t, stackTrack) ->
                println("=== $t had failed to shutdown in time")
                stackTrack.forEach { println("\tat $it") }
            }
        } finally {
            shutdown(shutdownDeadline)
        }
        // if no other exception was thrown & we had threads that did not shut down -- still fails
        if (activeNonPausedThreads.isNotEmpty()) error("Some threads had failed to shutdown in time")
    }

    /**
     * Final shutdown step. It does the following, in order:
     * - moves the status to [STATUS_DONE] and unparks the paused thread, if any;
     * - joins all threads until [shutdownDeadline];
     * - aborts any operations that are still suspended;
     * - releases the interceptor.
     *
     * Rethrows the first uncaught exception of a test thread. Fails if some thread is still alive.
     */
    private fun shutdown(shutdownDeadline: Long) {
        // forcefully unpause paused threads to shut them down (if any left)
        val curStatus = status.getAndSet(STATUS_DONE)
        if (curStatus < 0) LockSupport.unpark(threads[curStatus.inv()])
        threads.forEach {
            val remaining = shutdownDeadline - System.currentTimeMillis()
            if (remaining > 0) it.join(remaining)
        }
        // abort waiting threads (if still any left)
        threads.forEach { it.abortWait() }
        // cleanup & be done
        unlockAndResetInterceptor(interceptor)
        uncaughtException.get()?.let { throw it }
        threads.find { it.isAlive }?.let { dumpThreadsError("A thread is still alive: $it")}
    }

    /** Fails with a thread dump if some thread has not completed an operation within [STALL_LIMIT_MS]. */
    private fun checkStalled() {
        val stallLimit = System.currentTimeMillis() - STALL_LIMIT_MS
        val stalled = threads.filter { it.lastOpTime < stallLimit }
        if (stalled.isNotEmpty()) dumpThreadsError("Progress stalled in threads ${stalled.map { it.name }}")
    }

    /** Suffix for progress messages with the number of pause/resume cycles; empty when there were none. */
    private fun resumeStr(): String {
        val resumes = performedResumes
        return if (resumes == 0) "" else " (pause/resumes $resumes)"
    }

    /** Sleeps until [System.currentTimeMillis] reaches [nextTime]. */
    private fun waitUntil(nextTime: Long) {
        while (true) {
            val curTime = System.currentTimeMillis()
            if (curTime >= nextTime) break
            Thread.sleep(nextTime - curTime)
        }
    }

    /** Prints [message] and the stack traces of all test threads that have one, then fails with [message]. */
    private fun dumpThreadsError(message: String) : Nothing {
        val traces = threads.associate { it to it.stackTrace }
        println("!!! $message")
        println("=== Dumping live thread stack traces")
        for ((thread, trace) in traces) {
            if (trace.isEmpty()) continue
            println("Thread \"${thread.name}\" ${thread.state}")
            for (t in trace) println("\tat ${t.className}.${t.methodName}(${t.fileName}:${t.lineNumber})")
            println()
        }
        println("===")
        error(message)
    }

    /**
     * Returns true when test was completed.
     * Sets to true before calling [onCompletion] blocks.
     */
    public val isCompleted: Boolean get() = completed

    /**
     * Performs a given block of code on test's completion.
     * Blocks run on the thread that called [performTest], before test threads are told to stop.
     * For example, closing channels here lets suspended operations resume.
     */
    public fun onCompletion(block: () -> Unit) {
        onCompletion += block
    }

    /**
     * Creates a new test thread in this environment that executes a given lock-free [operation]
     * in a loop until the test is shut down. Must be called before [performTest].
     */
    public fun testThread(name: String? = null, operation: suspend TestThread.() -> Unit): TestThread =
        TestThread(name, operation)

    /**
     * Test thread that repeatedly runs its operation until the environment is shut down.
     * Instances are created with [testThread] before [performTest] is called.
     */
    @Suppress("LeakingThis")
    public inner class TestThread internal constructor(
        name: String?,
        private val operation: suspend TestThread.() -> Unit
    ) : Thread(composeThreadName(name)) {
        // Position of this thread in `threads`; `status` holds its inv() while this thread is paused.
        internal val index: Int

        // Wall-clock time of the last completed lock-free operation; used by checkStalled().
        internal @Volatile var lastOpTime = 0L
        // Epoch of this thread's current (or last) pause; `status` is set to it when the thread is resumed.
        internal @Volatile var pausedEpoch = -1

        private val random = Random()

        // thread-local stuff: used by this thread only, except that complete() calls makeProgress()
        // for this thread after it has terminated
        private var operationEpoch = -1
        private var progressEpoch = -1
        private var sink = 0

        init {
            check(!started)
            index = threads.size
            threads += this
        }

        public override fun run() {
            while (isActive) {
                callOperation()
            }
        }

        /**
         * Use it to insert an arbitrary intermission between lock-free operations.
         * The part of the operation before this call counts as one completed lock-free operation,
         * and a new one starts after [block] returns (or throws).
         */
        public inline fun <T> intermission(block: () -> T): T {
            afterLockFreeOperation()
            return try { block() }
                finally { beforeLockFreeOperation() }
        }

        /** Remembers the pause epoch in effect when a lock-free operation starts. */
        @PublishedApi
        internal fun beforeLockFreeOperation() {
            operationEpoch = getPausedEpoch()
        }

        /**
         * Records that a lock-free operation has completed. It credits progress for the epoch captured at
         * the operation's start, updates [lastOpTime] and counts the operation.
         */
        @PublishedApi
        internal fun afterLockFreeOperation() {
            makeProgress(operationEpoch)
            lastOpTime = System.currentTimeMillis()
            performedOps.add(1)
        }

        /**
         * Credits this thread with progress in pause [epoch], at most once per epoch.
         * Epoch -1 means "no pause" and is never credited, because progressEpoch starts at -1.
         * When every thread except the paused one has been credited, the paused thread is resumed.
         */
        internal fun makeProgress(epoch: Int) {
            if (epoch <= progressEpoch) return
            progressEpoch = epoch
            val total = globalPauseProgress.incrementAndGet()
            if (total >= threads.size - 1) {
                check(total == threads.size - 1)
                check(globalPauseProgress.compareAndSet(threads.size - 1, 0))
                resumeImpl()
            }
        }

        /**
         * Inserts random spin wait between multiple lock-free operations in [operation].
         */
        public fun randomSpinWaitIntermission() {
            intermission {
                if (random.nextInt(100) < 95) return // be quick, no wait 95% of time
                do {
                    val x = random.nextInt(100)
                    repeat(x) { sink += it }
                } while (x >= 90)
            }
        }

        /** Called on each intercepted atomic operation; pauses with probability `1 / PAUSE_EVERY_N_STEPS`. */
        internal fun stepImpl() {
            if (random.nextInt(PAUSE_EVERY_N_STEPS) == 0) pauseImpl()
        }

        /**
         * Pauses this thread unless another thread is already paused or the test is done.
         * While paused, `status` holds `index.inv()`. The thread parks in short slices until some other
         * thread changes `status` (resumeImpl or shutdown).
         */
        internal fun pauseImpl() {
            while (true) {
                val curStatus = status.get()
                if (curStatus < 0 || curStatus == STATUS_DONE) return // some other thread paused or done
                pausedEpoch = curStatus + 1
                val newStatus = index.inv()
                if (status.compareAndSet(curStatus, newStatus)) {
                    while (status.get() == newStatus) LockSupport.parkNanos(MAX_PARK_NANOS) // wait
                    return
                }
            }
        }

        // ----- Lightweight support for suspending operations -----

        /**
         * Runs one invocation of [operation] as a coroutine on this thread. If the operation suspends,
         * this thread blocks in waitUntilCompletion() and resumes the coroutine itself until it finishes.
         */
        private fun callOperation() {
            beforeLockFreeOperation()
            beginRunningOperation()
            val result = operation.startCoroutineUninterceptedOrReturn(this, completion)
            when {
                result === Unit -> afterLockFreeOperation() // operation completed w/o suspension -- done
                result === COROUTINE_SUSPENDED -> waitUntilCompletion() // operation had suspended
                else -> error("Unexpected result of operation: $result")
            }
            try {
                doneRunningOperation()
            } catch(e: IllegalStateException) {
                throw IllegalStateException("${e.message}; original start result=$result", e)
            }
        }

        // Hand-off state between resumeWith() (called by whoever resumes the coroutine) and this thread;
        // read and written mostly in the @Synchronized methods below.
        private var runningOperation = false
        private var result: Result<Any?>? = null
        private var continuation: Continuation<Any?>? = null

        /**
         * Drives a suspended operation to completion on this thread. Waits for each delivered result and
         * resumes the matching continuation here. Returns once the final result arrives through
         * `completion`, which has a null continuation.
         */
        private fun waitUntilCompletion() {
            try {
                while (true) {
                    afterLockFreeOperation()
                    val result: Result<Any?> = waitForResult()
                    val continuation = takeContinuation()
                    if (continuation == null) { // done
                        check(result.getOrThrow() === Unit)
                        return
                    }
                    removeSuspended(this)
                    beforeLockFreeOperation()
                    continuation.resumeWith(result)
                }
            } finally {
                removeSuspended(this)
            }
        }

        private fun beginRunningOperation() {
            runningOperation = true
            result = null
            continuation = null
        }

        @Synchronized
        private fun doneRunningOperation() {
            check(runningOperation) { "Should be running operation" }
            check(result == null && continuation == null) {
                "Callback invoked with result=$result, continuation=$continuation"
            }
            runningOperation = false
        }

        /**
         * Hands [result] and the [continuation] to resume over to this thread and wakes it up.
         * Called when the operation's coroutine is resumed, possibly from another thread.
         * A null [continuation] means the whole operation completed.
         */
        @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
        @Synchronized
        private fun resumeWith(result: Result<Any?>, continuation: Continuation<Any?>?) {
            check(runningOperation) { "Should be running operation" }
            check(this.result == null && this.continuation == null) {
                "Resumed again with result=$result, continuation=$continuation, when this: result=${this.result}, continuation=${this.continuation}"
            }
            this.result = result
            this.continuation = continuation
            (this as Object).notifyAll()
        }

        /**
         * Blocks until a result is delivered. While waiting, registers this thread as suspended.
         * If it is among the first `allowSuspendedThreads` suspended threads, it is credited with progress
         * in the current pause epoch; this is re-checked every 10 ms.
         */
        @Suppress("RESULT_CLASS_IN_RETURN_TYPE", "PLATFORM_CLASS_MAPPED_TO_KOTLIN")
        @Synchronized
        private fun waitForResult(): Result<Any?> {
            while (true) {
                val result = this.result
                if (result != null) return result
                val index = addSuspended(this)
                if (index < allowSuspendedThreads) {
                    // This suspension was permitted, so assume progress is happening while it is suspended
                    makeProgress(getPausedEpoch())
                }
                (this as Object).wait(10) // at most 10 ms
            }
        }

        /** Returns the continuation to resume (null when the operation completed) and clears the hand-off state. */
        @Synchronized
        private fun takeContinuation(): Continuation<Any?>? =
            continuation.also {
                this.result = null
                this.continuation = null
            }

        /** Unblocks waitForResult() with a failure; shutdown() calls this for operations still suspended. */
        @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
        @Synchronized
        fun abortWait() {
            this.result = Result.failure(IllegalStateException("Aborted at the end of test"))
            (this as Object).notifyAll()
        }

        // Instead of resuming a continuation directly, hands it over to this thread via resumeWith(),
        // so that the actual resumption happens in waitUntilCompletion().
        private val interceptor: CoroutineContext = object : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {
            override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
                Continuation<T>(this) {
                    @Suppress("UNCHECKED_CAST")
                    resumeWith(it, continuation as Continuation<Any?>)
                }
        }

        // Completion of the whole operation: delivers its final result with no continuation to resume.
        private val completion = Continuation<Unit>(interceptor) {
            resumeWith(it, null)
        }
    }

    // ---------- Implementation ----------

    /** Registers [thread] as suspended (if it is not yet) and returns its position among suspended threads. */
    @Synchronized
    private fun addSuspended(thread: TestThread): Int {
        val index = suspendedThreads.indexOf(thread)
        if (index >= 0) return index
        suspendedThreads.add(thread)
        return suspendedThreads.size - 1
    }

    @Synchronized
    private fun removeSuspended(thread: TestThread) {
        suspendedThreads.remove(thread)
    }

    /**
     * Returns the pause epoch of the currently paused thread, or -1 when no thread is paused.
     * `status` is re-read so that the returned epoch is known to belong to the same pause.
     */
    private fun getPausedEpoch(): Int {
        while (true) {
            val curStatus = status.get()
            if (curStatus >= 0) return -1 // not paused
            val thread = threads[curStatus.inv()]
            val pausedEpoch = thread.pausedEpoch
            if (curStatus == status.get()) return pausedEpoch
        }
    }

    /** Called by the interceptor on every atomic operation; has an effect only on test threads. */
    internal fun step() {
        val thread = Thread.currentThread() as? TestThread ?: return
        thread.stepImpl()
    }

    /** Resumes the currently paused thread by moving `status` to its pause epoch; no-op once done. */
    private fun resumeImpl() {
        while (true) {
            val curStatus = status.get()
            if (curStatus == STATUS_DONE) return // done
            check(curStatus < 0)
            val thread = threads[curStatus.inv()]
            performedResumes = thread.pausedEpoch
            if (status.compareAndSet(curStatus, thread.pausedEpoch)) {
                LockSupport.unpark(thread)
                return
            }
        }
    }

    /** Test thread name prefixed with the test name; unnamed threads are numbered from 1 in creation order. */
    private fun composeThreadName(threadName: String?): String {
        if (threadName != null) return "$name-$threadName"
        return name + "-${threads.size + 1}"
    }

    /** Routes every intercepted atomic update, set and read-modify-write to [step]. */
    private inner class Interceptor : AtomicOperationInterceptor() {
        override fun <T> beforeUpdate(ref: AtomicRef<T>) = step()
        override fun beforeUpdate(ref: AtomicInt) = step()
        override fun beforeUpdate(ref: AtomicLong) = step()
        override fun <T> afterSet(ref: AtomicRef<T>, newValue: T) = step()
        override fun afterSet(ref: AtomicInt, newValue: Int) = step()
        override fun afterSet(ref: AtomicLong, newValue: Long) = step()
        override fun <T> afterRMW(ref: AtomicRef<T>, oldValue: T, newValue: T) = step()
        override fun afterRMW(ref: AtomicInt, oldValue: Int, newValue: Int) = step()
        override fun afterRMW(ref: AtomicLong, oldValue: Long, newValue: Long) = step()
        override fun toString(): String = "LockFreedomTestEnvironment($name)"
    }
}

/**
 * Manually pauses the on-going lock-free operation at the point of the call.
 * Intended for targeted debugging of specific places in code. When the calling thread
 * is not a [LockFreedomTestEnvironment.TestThread], the call does nothing. It also has
 * no effect when another test thread is already paused or the test is done.
 *
 * **Don't use it in production code.**
 */
public fun pauseLockFreeOp() {
    (Thread.currentThread() as? LockFreedomTestEnvironment.TestThread)?.pauseImpl()
}