summaryrefslogtreecommitdiff
path: root/common/testutils/devicetests/com/android/testutils/TestableNetworkStatsProvider.kt
blob: 4a7b35134699e4e2347b1fd90b1c23d331e201dd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.testutils

import android.net.netstats.provider.NetworkStatsProvider
import android.util.Log
import com.android.net.module.util.ArrayTrackRecord
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import kotlin.test.fail

/** Timeout, in milliseconds, used by [TestableNetworkStatsProvider] when none is supplied. */
private const val DEFAULT_TIMEOUT_MS: Long = 200

/**
 * Wildcard token: passing it to [TestableNetworkStatsProvider.expectOnRequestStatsUpdate]
 * skips the token equality check.
 */
const val TOKEN_ANY: Int = -1

/**
 * A [NetworkStatsProvider] for tests. It logs every provider callback it receives and appends
 * it to [history]. The expect* / assert* helpers then poll [history] to verify which
 * callbacks arrived.
 *
 * @param defaultTimeoutMs timeout, in milliseconds, used by the polling helpers when the
 *        caller does not pass one explicitly.
 */
open class TestableNetworkStatsProvider(
    val defaultTimeoutMs: Long = DEFAULT_TIMEOUT_MS
) : NetworkStatsProvider() {
    /** One recorded provider callback, carrying the arguments it was invoked with. */
    sealed class CallbackType {
        data class OnRequestStatsUpdate(val token: Int) : CallbackType()
        data class OnSetWarningAndLimit(
            val iface: String,
            val warningBytes: Long,
            val limitBytes: Long
        ) : CallbackType()
        data class OnSetLimit(val iface: String, val limitBytes: Long) : CallbackType() {
            // Add getter for backward compatibility since old tests do not recognize limitBytes.
            val quotaBytes: Long
                get() = limitBytes
        }
        data class OnSetAlert(val quotaBytes: Long) : CallbackType()
    }

    private val TAG = this::class.simpleName

    /** Read head over every callback received so far; all polling helpers consume from it. */
    val history = ArrayTrackRecord<CallbackType>().newReadHead()
    // See ReadHead#mark
    val mark get() = history.mark

    /** Logs the request and records it as [CallbackType.OnRequestStatsUpdate]. */
    override fun onRequestStatsUpdate(token: Int) {
        Log.d(TAG, "onRequestStatsUpdate $token")
        history.add(CallbackType.OnRequestStatsUpdate(token))
    }

    /** Logs the call and records it as [CallbackType.OnSetWarningAndLimit]. */
    override fun onSetWarningAndLimit(iface: String, warningBytes: Long, limitBytes: Long) {
        Log.d(TAG, "onSetWarningAndLimit $iface $warningBytes $limitBytes")
        history.add(CallbackType.OnSetWarningAndLimit(iface, warningBytes, limitBytes))
    }

    /** Logs the call and records it as [CallbackType.OnSetLimit]. */
    override fun onSetLimit(iface: String, quotaBytes: Long) {
        Log.d(TAG, "onSetLimit $iface $quotaBytes")
        history.add(CallbackType.OnSetLimit(iface, quotaBytes))
    }

    /** Logs the call and records it as [CallbackType.OnSetAlert]. */
    override fun onSetAlert(quotaBytes: Long) {
        Log.d(TAG, "onSetAlert $quotaBytes")
        history.add(CallbackType.OnSetAlert(quotaBytes))
    }

    /**
     * Polls the next callback and asserts that it is a [CallbackType.OnRequestStatsUpdate].
     *
     * @param token the expected token, or [TOKEN_ANY] to accept any token.
     * @param timeout how long to wait for the callback, in milliseconds.
     * @return the token carried by the received callback.
     */
    fun expectOnRequestStatsUpdate(token: Int, timeout: Long = defaultTimeoutMs): Int {
        // Going through pollForNextCallback gives a timeout-specific message when nothing arrives,
        // and the explicit type check reports the actual callback instead of "expected true".
        val event = pollForNextCallback(timeout)
        if (event !is CallbackType.OnRequestStatsUpdate) {
            fail("Expected OnRequestStatsUpdate but got $event")
        }
        if (token != TOKEN_ANY) {
            assertEquals(token, event.token)
        }
        return event.token
    }

    /** Asserts that the next callback is exactly `OnSetLimit(iface, quotaBytes)`. */
    fun expectOnSetLimit(iface: String, quotaBytes: Long, timeout: Long = defaultTimeoutMs) {
        assertEquals(CallbackType.OnSetLimit(iface, quotaBytes), history.poll(timeout))
    }

    /** Asserts that the next callback is exactly `OnSetAlert(quotaBytes)`. */
    fun expectOnSetAlert(quotaBytes: Long, timeout: Long = defaultTimeoutMs) {
        assertEquals(CallbackType.OnSetAlert(quotaBytes), history.poll(timeout))
    }

    /** Returns the next callback, failing if none arrives within [timeout] milliseconds. */
    fun pollForNextCallback(timeout: Long = defaultTimeoutMs): CallbackType =
        history.poll(timeout) ?: fail("Did not receive callback after ${timeout}ms")

    /**
     * Polls the next callback and asserts that it is a [T] satisfying [predicate].
     *
     * @return the received callback, typed as [T].
     */
    inline fun <reified T : CallbackType> expectCallback(
        timeout: Long = defaultTimeoutMs,
        predicate: (T) -> Boolean = { true }
    ): T {
        return pollForNextCallback(timeout).also {
            assertTrue(
                it is T && predicate(it),
                "Expected ${T::class.simpleName} matching predicate but got $it"
            )
        } as T
    }

    // Expects a callback of the specified type matching the predicate within the timeout.
    // Any callback that doesn't match the predicate will be skipped. Fails only if
    // no matching callback is received within the timeout.
    // TODO : factorize the code for this with the identical call in TestableNetworkCallback.
    // There should be a common superclass doing this generically.
    // TODO : in fact, completely removing this method and have clients use
    // history.poll(timeout, index, predicate) directly might be simpler.
    inline fun <reified T : CallbackType> eventuallyExpect(
        timeoutMs: Long = defaultTimeoutMs,
        from: Int = mark,
        crossinline predicate: (T) -> Boolean = { true }
    ): T {
        // Fail with an assertion naming the expected type, rather than letting the cast of a
        // null poll result to non-null T throw a cast/null-pointer exception.
        val event = history.poll(timeoutMs, from) { it is T && predicate(it) }
            ?: fail("Did not receive a ${T::class.simpleName} callback matching the predicate " +
                    "after ${timeoutMs}ms")
        return event as T
    }

    /** Moves the read mark to the end of [history], so already-recorded callbacks are skipped. */
    fun drainCallbacks() {
        history.mark = history.size
    }

    /** Fails if any callback is received within [timeout] milliseconds. */
    @JvmOverloads
    fun assertNoCallback(timeout: Long = defaultTimeoutMs) {
        val cb = history.poll(timeout)
        cb?.let { fail("Expected no callback but got $cb") }
    }
}