summaryrefslogtreecommitdiff
path: root/common/testutils/hostdevice/com/android/net/module/util/TrackRecord.kt
blob: f24e4f184d1ca925f69c28694ea4396b3478b2af (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.net.module.util

import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.Condition
import java.util.concurrent.locks.ReentrantLock
import java.util.concurrent.locks.StampedLock
import kotlin.concurrent.withLock

/**
 * An append-only [List]. On top of the read-only [List] operations, elements can be appended
 * with [add]. [poll] retrieves an element at or after a given index, optionally blocking
 * until a suitable one has been appended.
 */
interface TrackRecord<E> : List<E> {
    /**
     * Appends [e] to this record and wakes up the threads waiting for a new element.
     *
     * @return always `true`, as required by the List contract for `add`.
     */
    fun add(e: E): Boolean

    /**
     * Returns the first element at index [pos] or later that matches [predicate]. If there is
     * none yet, this waits (possibly blocking) until one is available or until the timeout
     * expires.
     *
     * @param timeoutMs the maximum time to wait, in milliseconds (best-effort approximation).
     * @param pos the index at which to start looking.
     * @param predicate a filter on the elements that may be returned; by default every
     *        element is accepted.
     * @return the matching element, or null if none could be found within the timeout.
     */
    fun poll(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean = { _ -> true }): E?
}

/**
 * A thread-safe implementation of [TrackRecord] that is backed by an [ArrayList].
 *
 * Every [List] accessor, [add] and [poll] run while holding an internal [ReentrantLock].
 * [add] wakes up all threads blocked in [poll].
 *
 * This class also supports the creation of a read-head for easier single-thread access.
 * Refer to the documentation of [ArrayTrackRecord.ReadHead].
 */
class ArrayTrackRecord<E> : TrackRecord<E> {
    private val lock = ReentrantLock()
    private val condition = lock.newCondition()
    // Backing store. This stores the elements in this ArrayTrackRecord.
    private val elements = ArrayList<E>()

    /**
     * A [ListIterator] over the elements of [list] whose indices are in `start until end`.
     *
     * The iterators returned by [ArrayTrackRecord.listIterator] iterate over a snapshot of the
     * record: `end` is the size of the record at the time the iterator was created. A
     * TrackRecord is only ever mutated by appending, so the elements in that range never
     * change. Note that this iterator reads [list] without taking the record's lock.
     */
    class ArrayTrackRecordIterator<E>(
        private val list: ArrayList<E>,
        start: Int,
        private val end: Int
    ) : ListIterator<E> {
        var index = start
        override fun hasNext() = index < end
        override fun next(): E {
            // ListIterator contract: fail with NoSuchElementException past the end, instead of
            // reading elements appended after this snapshot was taken.
            if (!hasNext()) throw NoSuchElementException("No element at index $index (end $end)")
            return list[index++]
        }
        override fun hasPrevious() = index > 0
        // The index of the element that the next call to next() would return.
        override fun nextIndex() = index
        override fun previous(): E {
            if (!hasPrevious()) throw NoSuchElementException("No element before index 0")
            return list[--index]
        }
        override fun previousIndex() = index - 1
    }

    // List<E> implementation. Each accessor takes the lock so that it reads the backing list
    // consistently even while other threads are appending to it.
    override val size get() = lock.withLock { elements.size }
    override fun contains(element: E) = lock.withLock { elements.contains(element) }
    override fun containsAll(elements: Collection<E>) = lock.withLock {
        this.elements.containsAll(elements)
    }
    override operator fun get(index: Int) = lock.withLock { elements[index] }
    override fun indexOf(element: E): Int = lock.withLock { elements.indexOf(element) }
    override fun lastIndexOf(element: E): Int = lock.withLock { elements.lastIndexOf(element) }
    override fun isEmpty() = lock.withLock { elements.isEmpty() }

    /**
     * Returns an iterator over the elements present when this is called, positioned at [index].
     *
     * @throws IndexOutOfBoundsException if [index] is negative or greater than [size].
     */
    override fun listIterator(index: Int) = lock.withLock {
        val end = elements.size
        if (index < 0 || index > end) {
            throw IndexOutOfBoundsException("Index: $index, Size: $end")
        }
        ArrayTrackRecordIterator(elements, index, end)
    }
    override fun listIterator() = listIterator(0)
    override fun iterator() = listIterator()

    /**
     * Returns a copy of the elements whose indices are in `fromIndex until toIndex`.
     *
     * A copy is returned rather than a view. A view would be read later without the lock.
     * ArrayList sub-list views also throw [ConcurrentModificationException] on any access once
     * the backing list has been appended to.
     */
    override fun subList(fromIndex: Int, toIndex: Int): List<E> = lock.withLock {
        ArrayList(elements.subList(fromIndex, toIndex))
    }

    // TrackRecord<E> implementation

    /** Appends [e] and wakes up every thread blocked in [poll]. Always returns true. */
    override fun add(e: E): Boolean {
        lock.withLock {
            elements.add(e)
            condition.signalAll()
        }
        return true
    }

    /**
     * Returns the first element at or after [pos] matching [predicate]. This waits for up to
     * [timeoutMs] milliseconds for one to be added, and returns null if none shows up.
     */
    override fun poll(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean) = lock.withLock {
        elements.getOrNull(pollForIndexReadLocked(timeoutMs, pos, predicate))
    }

    /**
     * Returns the first element at or after [pos] matching [predicate]. Returns null if there
     * is no such element, or if [pos] is negative or greater than [size]. This never blocks.
     */
    fun getOrNull(pos: Int, predicate: (E) -> Boolean) = lock.withLock {
        if (pos < 0 || pos > elements.size) {
            null
        } else {
            elements.subList(pos, elements.size).find(predicate)
        }
    }

    // Returns the index of the next element whose position is >= pos matching the predicate, if
    // necessary waiting until such a time that such an element is available, with a timeout.
    // If no such element is found within the timeout -1 is returned.
    // Must be called with `lock` held.
    private fun pollForIndexReadLocked(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean): Int {
        // Track a relative budget on the monotonic clock instead of a wall-clock deadline.
        // Computing "now + timeoutMs" overflows for huge timeouts (e.g. Long.MAX_VALUE).
        // The wall clock can also jump. toNanos saturates instead of overflowing.
        var remainingNs = TimeUnit.MILLISECONDS.toNanos(timeoutMs)
        var index = pos
        while (true) {
            while (index < elements.size) {
                if (predicate(elements[index])) return index
                ++index
            }
            // Stop only after a scan that followed the last wait, so an element added right as
            // the timeout expired is still found.
            if (remainingNs <= 0L) break
            // awaitNanos releases the lock while waiting and returns the remaining budget. It
            // returns a value <= 0 once the time is up.
            remainingNs = condition.awaitNanos(remainingNs)
        }
        return -1
    }

    /**
     * Returns a new [ReadHead] over this ArrayTrackRecord, with its mark at index 0.
     *
     * Create it on the thread that will use it, or make sure there is a clear synchronization
     * point between creation and use (see [ReadHead]).
     */
    fun newReadHead() = ReadHead()

    /**
     * ReadHead is an object that helps users of ArrayTrackRecord keep track of how far
     * they have read in the ArrayTrackRecord. A ReadHead is always associated with
     * a single instance of ArrayTrackRecord. Multiple ReadHeads can be created and used
     * on the same instance of ArrayTrackRecord concurrently, and the ArrayTrackRecord
     * instance can also be used concurrently. ReadHead maintains the current index that is
     * the next to be read, and calls this the "mark".
     *
     * In a ReadHead, `poll(timeoutMs, predicate)` works similarly to a LinkedBlockingQueue.
     * It can be called repeatedly and will return the elements as they arrive.
     *
     * Intended usage looks something like this :
     * val record = ArrayTrackRecord<MyObject>().newReadHead()
     * Thread {
     *   // do stuff
     *   record.add(something)
     *   // do stuff
     * }.start()
     *
     * val obj1 = record.poll(timeout)
     * // do something with obj1
     * val obj2 = record.poll(timeout)
     * // do something with obj2
     *
     * The point is that the caller does not have to track the mark like it would have to if
     * it was using ArrayTrackRecord directly.
     *
     * Thread safety :
     * A ReadHead delegates all TrackRecord methods to its associated ArrayTrackRecord, and
     * inherits its thread-safe properties for all the TrackRecord methods.
     *
     * poll() operates under its own set of rules: it may be used on multiple threads only within
     * constrained boundaries, and never concurrently or pseudo-concurrently. Concurrent calls
     * to poll() fundamentally do not make sense. poll() moves the mark according to what
     * events remained to be read by this read head. If multiple threads called poll()
     * concurrently on the same ReadHead, the resulting mark and return values would not be
     * useful: there would be no way to guarantee that no objects are skipped, nor to guarantee
     * the mark position when poll() returns. This is even more true in the presence of a
     * predicate to filter returned elements, because one thread might be filtering out the
     * events the other is interested in. For this reason, this class will fail fast with a
     * ConcurrentModificationException if any concurrent access is detected.
     * It is possible to use poll() on different threads as long as the following can be
     * guaranteed : one thread must call poll() for the last time, then execute a write barrier,
     * then the other thread must execute a read barrier before calling poll() for the first time.
     * In particular, this allows calling poll in @Before and @After methods in JUnit unit tests,
     * because JUnit will enforce those barriers by creating the testing thread after executing
     * @Before and joining the thread after executing @After.
     *
     * peek() can be used by multiple threads concurrently, but only if no thread is calling
     * poll() outside of the boundaries above. For simplicity, it can be considered that peek()
     * is safe to call only when poll() is safe to call.
     *
     * Polling concurrently from the same ArrayTrackRecord is supported by creating multiple
     * ReadHeads on the same instance of ArrayTrackRecord (or of course by using ArrayTrackRecord
     * directly). Each ReadHead is then guaranteed to always see all events, and guarantees are
     * made on the value of the mark upon return; see `poll(timeoutMs, predicate)` for
     * details. Be careful to create each ReadHead on the thread it is meant to be used on, or
     * to have a clear synchronization point between creation and use.
     *
     * Users of a ReadHead can ask for the current position of the mark at any time, on a thread
     * where it's safe to call peek(). This mark can be used later to replay the history of events
     * either on this ReadHead, on the associated ArrayTrackRecord or on another ReadHead
     * associated with the same ArrayTrackRecord. It might look like this in the reader thread :
     *
     * val markAtStart = record.mark
     * // Start processing interesting events
     * while (true) {
     *   val element = record.poll(timeout) { it.isInteresting() } ?: break
     *   // Do something with element
     * }
     * // Look for stuff that happened while searching for interesting events
     * val firstElementReceived = record.getOrNull(markAtStart)
     * val firstSpecialElement = record.subList(markAtStart, record.size).find { it.isSpecial() }
     * // Get the first special element since markAtStart, possibly blocking until one is available
     * val specialElement = record.poll(timeout, markAtStart) { it.isSpecial() }
     */
    inner class ReadHead : TrackRecord<E> by this@ArrayTrackRecord {
        // This lock only controls access to the readHead member below. The ArrayTrackRecord
        // object has its own synchronization following different (and more usual) semantics.
        // See the comment on the ReadHead class for details.
        private val slock = StampedLock()
        private var readHead = 0

        // A special mark used to track the start of the last poll() operation.
        private var pollMark = 0

        /**
         * The current value of the mark, i.e. the index of the next element this ReadHead
         * will look at. Reading it follows the same threading rules as [peek]. Assigning it is
         * the same as calling [rewind].
         */
        var mark
            get() = checkThread { readHead }
            set(v: Int) = rewind(v)

        /**
         * Moves the mark to [v], and also resets the start of the [backtrace] to [v].
         *
         * @throws ConcurrentModificationException if this ReadHead is being used concurrently.
         */
        fun rewind(v: Int) {
            val stamp = slock.tryWriteLock()
            if (0L == stamp) concurrentAccessDetected()
            try {
                readHead = v
                pollMark = v
            } finally {
                slock.unlockWrite(stamp)
            }
        }

        // Runs r with an optimistic read stamp, and throws if a writer (poll/rewind) was active
        // at any point between the start of r and the validation.
        private fun <T> checkThread(r: (Long) -> T): T {
            // tryOptimisticRead is a read barrier, guarantees writes from other threads are visible
            // after it
            val stamp = slock.tryOptimisticRead()
            val result = r(stamp)
            // validate also performs a read barrier, guaranteeing that if validate returns true,
            // then any change either happens-before tryOptimisticRead, or happens-after validate.
            if (!slock.validate(stamp)) concurrentAccessDetected()
            return result
        }

        private fun concurrentAccessDetected(): Nothing {
            throw ConcurrentModificationException(
                    "ReadHeads can't be used concurrently. Check your threading model.")
        }

        /**
         * Returns the first element at or after the mark that matches [predicate]. If none is
         * available yet, this blocks until one is added or until the timeout expires, and then
         * returns null.
         *
         * Upon return the mark will be set to immediately after the returned element, or after
         * the last element in the queue if null is returned. This means this method will always
         * skip elements that do not match the predicate, even if it returns null.
         *
         * This method must not be called concurrently with poll(), rewind() or backtrace() on
         * the same ReadHead; see the class documentation for the exact threading rules.
         *
         * @param timeoutMs how long, in milliseconds, to wait at most (best effort approximation).
         * @param predicate an optional predicate to filter elements to be returned.
         * @return an element matching the predicate, or null if timeout.
         * @throws ConcurrentModificationException if concurrent use of this ReadHead is detected.
         */
        fun poll(timeoutMs: Long, predicate: (E) -> Boolean = { true }): E? {
            val stamp = slock.tryWriteLock()
            if (0L == stamp) concurrentAccessDetected()
            try {
                pollMark = readHead
                lock.withLock {
                    val index = pollForIndexReadLocked(timeoutMs, readHead, predicate)
                    readHead = if (index < 0) elements.size else index + 1
                    return elements.getOrNull(index)
                }
            } finally {
                slock.unlockWrite(stamp)
            }
        }

        /**
         * Returns a list of events that were observed since the last time poll() was called on this
         * ReadHead: the elements from the mark at the start of that poll() up to the current mark.
         *
         * @return a copy of the events since poll() was called.
         * @throws ConcurrentModificationException if a poll() or rewind() is in progress.
         */
        fun backtrace(): List<E> {
            val stamp = slock.tryReadLock()
            if (0L == stamp) concurrentAccessDetected()
            try {
                lock.withLock {
                    // While the read stamp is held no writer can change readHead, so it is the
                    // current mark.
                    return ArrayList(elements.subList(pollMark, readHead))
                }
            } finally {
                slock.unlockRead(stamp)
            }
        }

        /**
         * Returns the element at the mark, or null if there is none yet. This never blocks.
         *
         * This method is subject to threading restrictions. It can be used concurrently on
         * multiple threads but not if any other thread might be executing poll() at the same
         * time. See the class comment for details.
         */
        fun peek(): E? = checkThread { lock.withLock { elements.getOrNull(readHead) } }
    }
}

// Private helper: waits on this Condition for at most timeoutMs milliseconds.
// Like Condition.await(long, TimeUnit), returns false if the waiting time elapsed.
private fun Condition.await(timeoutMs: Long): Boolean {
    return await(timeoutMs, TimeUnit.MILLISECONDS)
}