aboutsummaryrefslogtreecommitdiff
path: root/agent/src/main/java/com/code_intelligence/jazzer/instrumentor/HookMethodVisitor.kt
blob: 1694be58b8cc2331ad97ad23c7aec2493cff32da (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
// Copyright 2021 Code Intelligence GmbH
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.code_intelligence.jazzer.instrumentor

import com.code_intelligence.jazzer.api.HookType
import org.objectweb.asm.Handle
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes
import org.objectweb.asm.Type
import org.objectweb.asm.commons.LocalVariablesSorter
import java.util.concurrent.atomic.AtomicBoolean

/**
 * Creates the visitor that instruments calls to hooked methods inside a single method body.
 *
 * The returned visitor is the [LocalVariablesSorter] wrapping a [HookMethodVisitor], so that the
 * locals allocated for hook calls are kept apart from the locals of the original method.
 *
 * @param access access flags of the method being visited.
 * @param descriptor descriptor of the method being visited.
 * @param methodVisitor downstream visitor receiving the rewritten instructions.
 * @param hooks hooks that may apply to call sites in the method.
 * @param java6Mode whether the class file cannot contain MethodHandle constants.
 * @param random source of the ids passed to the hooks to identify call sites.
 */
internal fun makeHookMethodVisitor(
    access: Int,
    descriptor: String?,
    methodVisitor: MethodVisitor?,
    hooks: Iterable<Hook>,
    java6Mode: Boolean,
    random: DeterministicRandom,
): MethodVisitor = HookMethodVisitor(
    access = access,
    descriptor = descriptor,
    methodVisitor = methodVisitor,
    hooks = hooks,
    java6Mode = java6Mode,
    random = random,
).lvs

/**
 * A [MethodVisitor] that rewrites invocations of hooked methods so that the matching hook methods
 * are called before (BEFORE), instead of (REPLACE) or after (AFTER) the original call.
 *
 * Locals needed to call hooks are allocated through [lvs], a [LocalVariablesSorter] that wraps
 * this visitor.
 *
 * @param access access flags of the visited method, forwarded to the [LocalVariablesSorter].
 * @param descriptor descriptor of the visited method, forwarded to the [LocalVariablesSorter].
 * @param methodVisitor downstream visitor that receives the rewritten instructions.
 * @param hooks all hooks that may apply to call sites in the visited method.
 * @param java6Mode if true, `null` is passed to hooks instead of a MethodHandle constant and
 *   REPLACE hooks are skipped.
 * @param random source of the id passed to every hook invocation to identify its call site.
 */
private class HookMethodVisitor(
    access: Int,
    descriptor: String?,
    methodVisitor: MethodVisitor?,
    hooks: Iterable<Hook>,
    private val java6Mode: Boolean,
    private val random: DeterministicRandom,
) : MethodVisitor(Instrumentor.ASM_API_VERSION, methodVisitor) {

    companion object {
        // Shared by all instances: the warning about skipped REPLACE hooks is printed only once.
        private val showUnsupportedHookWarning = AtomicBoolean(true)

        // Opcodes that denote a regular method invocation. Kept as a constant so that no collection
        // is allocated for every visited call instruction.
        private val METHOD_INVOCATION_OPCODES = setOf(
            Opcodes.INVOKEVIRTUAL,
            Opcodes.INVOKEINTERFACE,
            Opcodes.INVOKESTATIC,
            Opcodes.INVOKESPECIAL,
        )
    }

    /**
     * Wrapper around this visitor that allocates the new local variables used for hook calls
     * (via [LocalVariablesSorter.newLocal]).
     */
    val lvs = object : LocalVariablesSorter(Instrumentor.ASM_API_VERSION, access, descriptor, this) {
        override fun updateNewLocals(newLocals: Array<Any>) {
            // The local variables involved in calling hooks do not need to outlive the current
            // basic block and should thus not appear in stack map frames. By requesting the
            // LocalVariablesSorter to fill their entries in stack map frames with TOP, they will
            // be treated like an unused local variable slot.
            newLocals.fill(Opcodes.TOP)
        }
    }

    // Hooks grouped by "<hookType>#<targetClass>#<targetMethod>[#<targetDescriptor>]". This is the key
    // format queried by findMatchingHooks. That function looks up both the key with and the key without
    // the call's descriptor, so a hook without a target descriptor matches every overload.
    private val hooks = hooks.groupBy { hook ->
        var hookKey = "${hook.hookType}#${hook.targetInternalClassName}#${hook.targetMethodName}"
        if (hook.targetMethodDescriptor != null)
            hookKey += "#${hook.targetMethodDescriptor}"
        hookKey
    }

    /**
     * Forwards non-invocation method instructions unchanged and routes invocations through
     * [handleMethodInsn].
     */
    override fun visitMethodInsn(
        opcode: Int,
        owner: String,
        methodName: String,
        methodDescriptor: String,
        isInterface: Boolean,
    ) {
        if (!isMethodInvocationOp(opcode)) {
            mv.visitMethodInsn(opcode, owner, methodName, methodDescriptor, isInterface)
            return
        }
        handleMethodInsn(opcode, owner, methodName, methodDescriptor, isInterface)
    }

    /**
     * Emits the call to `owner.methodName` surrounded by, or replaced with, calls to all matching
     * hooks. If no hook matches, the original instruction is forwarded unchanged.
     *
     * The generated code first moves the call's arguments into a fresh `Object[]` local. It moves the
     * receiver (or `null` for static calls) into another local. From these locals the values can be
     * passed to each hook and pushed again for the original call.
     *
     * Every hook is invoked via INVOKESTATIC with the operands
     * `MethodHandle | owner | Object[] arguments | int hookId`. AFTER hooks additionally receive the
     * boxed return value of the original call, or `null` for void methods.
     */
    private fun handleMethodInsn(
        opcode: Int,
        owner: String,
        methodName: String,
        methodDescriptor: String,
        isInterface: Boolean,
    ) {
        val matchingHooks = findMatchingHooks(owner, methodName, methodDescriptor)

        if (matchingHooks.isEmpty()) {
            mv.visitMethodInsn(opcode, owner, methodName, methodDescriptor, isInterface)
            return
        }

        val paramDescriptors = extractParameterTypeDescriptors(methodDescriptor)
        val localObjArr = storeMethodArguments(paramDescriptors)
        // If the method we're hooking is not static there is now a reference to
        // the object the method was invoked on at the top of the stack.
        // If the method is static, that object is missing. We make up for it by pushing a null ref.
        if (opcode == Opcodes.INVOKESTATIC) {
            mv.visitInsn(Opcodes.ACONST_NULL)
        }

        // Save the owner object to a new local variable
        val ownerDescriptor = "L$owner;"
        val localOwnerObj = lvs.newLocal(Type.getType(ownerDescriptor))
        mv.visitVarInsn(Opcodes.ASTORE, localOwnerObj) // consume objectref
        // We now removed all values for the original method call from the operand stack
        // and saved them to local variables.

        val returnTypeDescriptor = extractReturnTypeDescriptor(methodDescriptor)
        // Create a local variable to store the (boxed) return value; only AFTER hooks use it.
        val localReturnObj = lvs.newLocal(Type.getType(getWrapperTypeDescriptor(returnTypeDescriptor)))

        matchingHooks.forEachIndexed { index, hook ->
            // The hookId is used to identify a call site.
            val hookId = random.nextInt()

            // Start to build the arguments for the hook method.
            if (methodName == "<init>") {
                // Constructor is invoked on an uninitialized object, and that's still on the stack.
                // In case of REPLACE pop it from the stack and replace it afterwards with the returned
                // one from the hook.
                if (hook.hookType == HookType.REPLACE) {
                    mv.visitInsn(Opcodes.POP)
                }
                // Special case for constructors:
                // We cannot create a MethodHandle for a constructor, so we push null instead.
                mv.visitInsn(Opcodes.ACONST_NULL) // push nullref
                // Only pass the this object if it has been initialized by the time the hook is invoked.
                if (hook.hookType == HookType.AFTER) {
                    mv.visitVarInsn(Opcodes.ALOAD, localOwnerObj)
                } else {
                    mv.visitInsn(Opcodes.ACONST_NULL) // push nullref
                }
            } else {
                // Push a MethodHandle representing the hooked method.
                val handleOpcode = when (opcode) {
                    Opcodes.INVOKEVIRTUAL -> Opcodes.H_INVOKEVIRTUAL
                    Opcodes.INVOKEINTERFACE -> Opcodes.H_INVOKEINTERFACE
                    Opcodes.INVOKESTATIC -> Opcodes.H_INVOKESTATIC
                    Opcodes.INVOKESPECIAL -> Opcodes.H_INVOKESPECIAL
                    else -> -1
                }
                if (java6Mode) {
                    // MethodHandle constants (type 15) are not supported in Java 6 class files (major version 50).
                    mv.visitInsn(Opcodes.ACONST_NULL) // push nullref
                } else {
                    mv.visitLdcInsn(
                        Handle(
                            handleOpcode,
                            owner,
                            methodName,
                            methodDescriptor,
                            isInterface
                        )
                    ) // push MethodHandle
                }
                // Stack layout: ... | MethodHandle (objectref)
                // Push the owner object again
                mv.visitVarInsn(Opcodes.ALOAD, localOwnerObj)
            }
            // Stack layout: ... | MethodHandle (objectref) | owner (objectref)
            // Push a reference to our object array with the saved arguments
            mv.visitVarInsn(Opcodes.ALOAD, localObjArr)
            // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref)
            // Push the hook id
            mv.visitLdcInsn(hookId)
            // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref) | hookId (int)
            // How we proceed depends on the type of hook we want to implement
            when (hook.hookType) {
                HookType.BEFORE -> {
                    // Call the hook method
                    mv.visitMethodInsn(
                        Opcodes.INVOKESTATIC,
                        hook.hookInternalClassName,
                        hook.hookMethodName,
                        hook.hookMethodDescriptor,
                        false
                    )

                    // Call the original method if this is the last hook overall. If not, the original
                    // method will be called by the next AFTER hook.
                    if (index == matchingHooks.lastIndex) {
                        // Stack layout: ...
                        // Push the values for the original method call onto the stack again
                        if (opcode != Opcodes.INVOKESTATIC) {
                            mv.visitVarInsn(Opcodes.ALOAD, localOwnerObj) // push owner object
                        }
                        loadMethodArguments(paramDescriptors, localObjArr) // push all method arguments
                        // Stack layout: ... | [owner (objectref)] | arg1 (primitive/objectref) | arg2 (primitive/objectref) | ...
                        mv.visitMethodInsn(opcode, owner, methodName, methodDescriptor, isInterface)
                    }
                }
                HookType.REPLACE -> {
                    // Call the hook method; its result takes the place of the original call's result.
                    mv.visitMethodInsn(
                        Opcodes.INVOKESTATIC,
                        hook.hookInternalClassName,
                        hook.hookMethodName,
                        hook.hookMethodDescriptor,
                        false
                    )
                    // Stack layout: ... | [return value (primitive/objectref)]
                    // Check if we need to process the return value
                    if (returnTypeDescriptor != "V") {
                        val hookMethodReturnType = extractReturnTypeDescriptor(hook.hookMethodDescriptor)
                        // if the hook method's return type is primitive we don't need to unwrap or cast it
                        if (!isPrimitiveType(hookMethodReturnType)) {
                            // Check if the returned object type is different than the one that should be returned
                            // If a primitive should be returned we check its wrapper type
                            val expectedType = getWrapperTypeDescriptor(returnTypeDescriptor)
                            if (expectedType != hookMethodReturnType) {
                                // Cast object
                                mv.visitTypeInsn(Opcodes.CHECKCAST, extractInternalClassName(expectedType))
                            }
                            // Check if we need to unwrap the returned object
                            unwrapTypeIfPrimitive(returnTypeDescriptor)
                        }
                    }
                }
                HookType.AFTER -> {
                    // Call the original method before the first AFTER hook
                    if (index == 0 || matchingHooks[index - 1].hookType != HookType.AFTER) {
                        // Push the values for the original method call again onto the stack
                        if (opcode != Opcodes.INVOKESTATIC) {
                            mv.visitVarInsn(Opcodes.ALOAD, localOwnerObj) // push owner object
                        }
                        loadMethodArguments(paramDescriptors, localObjArr) // push all method arguments
                        // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref) | hookId (int)
                        //                   | [owner (objectref)] | arg1 (primitive/objectref) | arg2 (primitive/objectref) | ...
                        mv.visitMethodInsn(opcode, owner, methodName, methodDescriptor, isInterface)
                        if (returnTypeDescriptor == "V") {
                            // If the method didn't return anything, we push a nullref as placeholder
                            mv.visitInsn(Opcodes.ACONST_NULL) // push nullref
                        }
                        // Wrap return value if it is a primitive type
                        wrapTypeIfPrimitive(returnTypeDescriptor)
                        // Save the (boxed) result so that every subsequent AFTER hook can receive it.
                        mv.visitVarInsn(Opcodes.ASTORE, localReturnObj) // consume objectref
                    }
                    mv.visitVarInsn(Opcodes.ALOAD, localReturnObj) // push objectref

                    // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref) | hookId (int)
                    //                   | return value (objectref)
                    // Call the hook method
                    mv.visitMethodInsn(
                        Opcodes.INVOKESTATIC,
                        hook.hookInternalClassName,
                        hook.hookMethodName,
                        hook.hookMethodDescriptor,
                        false
                    )
                    // Stack layout: ...
                    // Push the return value on the stack after the last AFTER hook if the original method returns a value
                    if (index == matchingHooks.lastIndex && returnTypeDescriptor != "V") {
                        // Push the return value again
                        mv.visitVarInsn(Opcodes.ALOAD, localReturnObj) // push objectref
                        // Unwrap it, if it was a primitive value
                        unwrapTypeIfPrimitive(returnTypeDescriptor)
                        // Stack layout: ... | return value (primitive/objectref)
                    }
                }
            }
        }
    }

    // Returns true for the opcodes of regular method invocations (not INVOKEDYNAMIC).
    private fun isMethodInvocationOp(opcode: Int) = opcode in METHOD_INVOCATION_OPCODES

    /**
     * Returns all hooks targeting `owner.name`. These are hooks registered for exactly [descriptor]
     * and hooks registered without a descriptor.
     *
     * @throws IllegalStateException if a REPLACE hook is combined with any other hook for the method.
     */
    private fun findMatchingHooks(owner: String, name: String, descriptor: String): List<Hook> {
        val result = HookType.values().flatMap { hookType ->
            val withoutDescriptorKey = "$hookType#$owner#$name"
            val withDescriptorKey = "$withoutDescriptorKey#$descriptor"
            hooks[withDescriptorKey].orEmpty() + hooks[withoutDescriptorKey].orEmpty()
        }.sortedBy { it.hookType }
        val replaceHookCount = result.count { it.hookType == HookType.REPLACE }
        check(
            replaceHookCount == 0 ||
                (replaceHookCount == 1 && result.size == 1)
        ) {
            "For a given method, You can either have a single REPLACE hook or BEFORE/AFTER hooks. Found:\n $result"
        }

        // REPLACE hooks are dropped in Java 6 mode.
        // NOTE(review): the final order is decided by Hook.toString() (the sort by hookType above only
        // acts as a tiebreaker). handleMethodInsn relies on all BEFORE hooks preceding all AFTER
        // hooks. Confirm that Hook.toString() preserves that ordering.
        return result
            .filter { !isReplaceHookInJava6mode(it) }
            .sortedByDescending { it.toString() }
    }

    // Returns true if [hook] is a REPLACE hook that cannot be applied because the class is processed
    // in Java 6 mode. The first time this happens, a warning is printed to stdout.
    private fun isReplaceHookInJava6mode(hook: Hook): Boolean {
        if (java6Mode && hook.hookType == HookType.REPLACE) {
            if (showUnsupportedHookWarning.getAndSet(false)) {
                println(
                    """WARN: Some hooks could not be applied to class files built for Java 7 or lower.
                      |WARN: Ensure that the fuzz target and its dependencies are compiled with
                      |WARN: -target 8 or higher to identify as many bugs as possible.
            """.trimMargin()
                )
            }
            return true
        }
        return false
    }

    // Pops all arguments of a method call from the operand stack and stores them (boxed) in a new
    // local Object[], at the index of their parameter position.
    // paramDescriptors: The type descriptors for all method arguments
    // Returns the index of the local variable holding the array.
    private fun storeMethodArguments(paramDescriptors: List<String>): Int {
        // Allocate a new Object[] for the methods parameters.
        mv.visitIntInsn(Opcodes.SIPUSH, paramDescriptors.size)
        mv.visitTypeInsn(Opcodes.ANEWARRAY, "java/lang/Object")
        val localObjArr = lvs.newLocal(Type.getType("[Ljava/lang/Object;"))
        mv.visitVarInsn(Opcodes.ASTORE, localObjArr)

        // Loop over all arguments in reverse order (because the last argument is on top).
        for ((argIdx, argDescriptor) in paramDescriptors.withIndex().reversed()) {
            // If the argument is a primitive type, wrap it in its wrapper class
            wrapTypeIfPrimitive(argDescriptor)
            // Store the argument in our object array, for that we need to shape the stack first.
            // Stack layout: ... | method argument (objectref)
            mv.visitVarInsn(Opcodes.ALOAD, localObjArr)
            // Stack layout: ... | method argument (objectref) | object array (arrayref)
            mv.visitInsn(Opcodes.SWAP)
            // Stack layout: ... | object array (arrayref) | method argument (objectref)
            mv.visitIntInsn(Opcodes.SIPUSH, argIdx)
            // Stack layout: ... | object array (arrayref) | method argument (objectref) | argument index (int)
            mv.visitInsn(Opcodes.SWAP)
            // Stack layout: ... | object array (arrayref) | argument index (int) | method argument (objectref)
            mv.visitInsn(Opcodes.AASTORE) // consume all three: arrayref, index, value
            // Stack layout: ...
            // Continue with the remaining method arguments
        }

        // Return a reference to the array with the parameters.
        return localObjArr
    }

    // Loads all arguments for a method call from a local object array back onto the operand stack,
    // in parameter order, cast to their declared type and unboxed if primitive.
    // paramDescriptors: The type descriptors for all method arguments
    // localObjArr: Index of a local variable containing an object array where the arguments will be loaded from
    private fun loadMethodArguments(paramDescriptors: List<String>, localObjArr: Int) {
        // Loop over all arguments
        for ((argIdx, argDescriptor) in paramDescriptors.withIndex()) {
            // Push a reference to the object array on the stack
            mv.visitVarInsn(Opcodes.ALOAD, localObjArr)
            // Stack layout: ... | object array (arrayref)
            // Push the index of the current argument on the stack
            mv.visitIntInsn(Opcodes.SIPUSH, argIdx)
            // Stack layout: ... | object array (arrayref) | argument index (int)
            // Load the argument from the array
            mv.visitInsn(Opcodes.AALOAD)
            // Stack layout: ... | method argument (objectref)
            // Cast object to its original type (or its wrapper type)
            val wrapperTypeDescriptor = getWrapperTypeDescriptor(argDescriptor)
            mv.visitTypeInsn(Opcodes.CHECKCAST, extractInternalClassName(wrapperTypeDescriptor))
            // If the argument is supposed to be a primitive type, unwrap the wrapped type
            unwrapTypeIfPrimitive(argDescriptor)
            // Stack layout: ... | method argument (primitive/objectref)
            // Continue with the remaining method arguments
        }
    }

    // Removes a primitive value from the top of the operand stack
    // and pushes it enclosed in its wrapper type (e.g. removes int, pushes Integer).
    // This is done by calling .valueOf(...) on the wrapper class.
    // Emits nothing for reference types and for "V".
    private fun wrapTypeIfPrimitive(unwrappedTypeDescriptor: String) {
        if (!isPrimitiveType(unwrappedTypeDescriptor) || unwrappedTypeDescriptor == "V") return
        val wrapperTypeDescriptor = getWrapperTypeDescriptor(unwrappedTypeDescriptor)
        val wrapperType = extractInternalClassName(wrapperTypeDescriptor)
        val valueOfDescriptor = "($unwrappedTypeDescriptor)$wrapperTypeDescriptor"
        mv.visitMethodInsn(Opcodes.INVOKESTATIC, wrapperType, "valueOf", valueOfDescriptor, false)
    }

    // Removes a wrapper object around a given primitive type from the top of the operand stack
    // and pushes the primitive value it contains (e.g. removes Integer, pushes int).
    // This is done by calling .intValue() / .charValue() / ... on the wrapper object.
    // Emits nothing for any descriptor that is not one of the eight non-void primitive types.
    private fun unwrapTypeIfPrimitive(primitiveTypeDescriptor: String) {
        val (methodName, wrappedTypeDescriptor) = when (primitiveTypeDescriptor) {
            "B" -> Pair("byteValue", "java/lang/Byte")
            "C" -> Pair("charValue", "java/lang/Character")
            "D" -> Pair("doubleValue", "java/lang/Double")
            "F" -> Pair("floatValue", "java/lang/Float")
            "I" -> Pair("intValue", "java/lang/Integer")
            "J" -> Pair("longValue", "java/lang/Long")
            "S" -> Pair("shortValue", "java/lang/Short")
            "Z" -> Pair("booleanValue", "java/lang/Boolean")
            else -> return
        }
        mv.visitMethodInsn(
            Opcodes.INVOKEVIRTUAL,
            wrappedTypeDescriptor,
            methodName,
            "()$primitiveTypeDescriptor",
            false
        )
    }
}