aboutsummaryrefslogtreecommitdiff
path: root/agent/src/main/java/com/code_intelligence/jazzer/instrumentor/HookMethodVisitor.kt
blob: 7c23c703b5a5ed3e62d24730413f715690dd8fc6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
// Copyright 2021 Code Intelligence GmbH
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.code_intelligence.jazzer.instrumentor

import com.code_intelligence.jazzer.api.HookType
import com.code_intelligence.jazzer.third_party.objectweb.asm.Handle
import com.code_intelligence.jazzer.third_party.objectweb.asm.MethodVisitor
import com.code_intelligence.jazzer.third_party.objectweb.asm.Opcodes
import com.code_intelligence.jazzer.third_party.objectweb.asm.Type
import com.code_intelligence.jazzer.third_party.objectweb.asm.commons.LocalVariablesSorter

/**
 * Creates a [MethodVisitor] that applies the given [hooks] to the method call instructions it visits and forwards
 * the resulting bytecode to [methodVisitor].
 *
 * The returned visitor is the [LocalVariablesSorter] wrapping the underlying [HookMethodVisitor], so that the
 * locals allocated while emitting hook calls are kept apart from the locals of the visited method.
 *
 * @param access access flags of the visited method.
 * @param descriptor descriptor of the visited method.
 * @param methodVisitor the visitor that receives the instrumented bytecode.
 * @param hooks the hooks to apply to call instructions.
 * @param java6Mode if true, no MethodHandle constants are emitted, as Java 6 class files cannot contain them.
 * @param random source of the ids that identify hooked call sites.
 */
internal fun makeHookMethodVisitor(
    access: Int,
    descriptor: String?,
    methodVisitor: MethodVisitor?,
    hooks: Iterable<Hook>,
    java6Mode: Boolean,
    random: DeterministicRandom,
): MethodVisitor = HookMethodVisitor(
    access = access,
    descriptor = descriptor,
    methodVisitor = methodVisitor,
    hooks = hooks,
    java6Mode = java6Mode,
    random = random,
).lvs

/**
 * A [MethodVisitor] that instruments method call instructions by emitting calls to matching [Hook]s before, after,
 * or instead of the original call.
 *
 * `makeHookMethodVisitor` hands out [lvs] rather than this visitor. [lvs] is a [LocalVariablesSorter] that
 * delegates to this visitor and provides (via `newLocal`) the fresh local variable slots in which the operands of a
 * hooked call are temporarily saved.
 *
 * @param access access flags of the visited method, passed on to the [LocalVariablesSorter].
 * @param descriptor descriptor of the visited method, passed on to the [LocalVariablesSorter].
 * @param methodVisitor the visitor that receives the instrumented bytecode.
 * @param hooks the hooks to apply. A hook with a target method descriptor takes precedence over one without.
 * @param java6Mode if true, `null` is passed to hooks in place of a MethodHandle constant, which Java 6 class files
 *   cannot contain.
 * @param random source of the ids that identify individual hooked call sites.
 */
private class HookMethodVisitor(
    access: Int,
    descriptor: String?,
    methodVisitor: MethodVisitor?,
    hooks: Iterable<Hook>,
    private val java6Mode: Boolean,
    private val random: DeterministicRandom,
) : MethodVisitor(Instrumentor.ASM_API_VERSION, methodVisitor) {

    /**
     * The [LocalVariablesSorter] wrapping this visitor. Locals allocated through it while emitting hook calls get
     * slots that do not clash with the locals of the visited method.
     */
    val lvs = object : LocalVariablesSorter(Instrumentor.ASM_API_VERSION, access, descriptor, this) {
        override fun updateNewLocals(newLocals: Array<Any>) {
            // The local variables involved in calling hooks do not need to outlive the current
            // basic block and should thus not appear in stack map frames. By requesting the
            // LocalVariablesSorter to fill their entries in stack map frames with TOP, they will
            // be treated like an unused local variable slot.
            newLocals.fill(Opcodes.TOP)
        }
    }

    // Hooks keyed by "<hook type>#<target class>#<target method>", with "#<target descriptor>" appended if the hook
    // specifies a target method descriptor. The key format must agree with findMatchingHook. If several hooks share
    // a key, the last one wins (associateBy semantics).
    private val hooks = hooks.associateBy { hook ->
        val baseKey = "${hook.hookType}#${hook.targetInternalClassName}#${hook.targetMethodName}"
        if (hook.targetMethodDescriptor != null) "$baseKey#${hook.targetMethodDescriptor}" else baseKey
    }

    override fun visitMethodInsn(
        opcode: Int,
        owner: String,
        methodName: String,
        methodDescriptor: String,
        isInterface: Boolean,
    ) {
        if (isMethodInvocationOp(opcode)) {
            // Hook types are tried in the order BEFORE, REPLACE, AFTER; see visitNextHookTypeOrCall.
            handleMethodInsn(HookType.BEFORE, opcode, owner, methodName, methodDescriptor, isInterface)
        } else {
            mv.visitMethodInsn(opcode, owner, methodName, methodDescriptor, isInterface)
        }
    }

    /**
     * Emits the bytecode for a method call instruction for the next applicable hook type in order (BEFORE, REPLACE,
     * AFTER). Since the instrumented code is indistinguishable from an uninstrumented call instruction, it can be
     * safely nested. Combining REPLACE hooks with other hooks is however not supported as these hooks already subsume
     * the functionality of BEFORE and AFTER hooks.
     *
     * @param hookType the hook type that has just been processed.
     * @param appliedHook whether a hook of type [hookType] was actually emitted.
     */
    private fun visitNextHookTypeOrCall(
        hookType: HookType,
        appliedHook: Boolean,
        opcode: Int,
        owner: String,
        methodName: String,
        methodDescriptor: String,
        isInterface: Boolean,
    ) = when (hookType) {
        HookType.BEFORE -> {
            val nextHookType = if (appliedHook) {
                // After a BEFORE hook has been applied, we can safely apply an AFTER hook by replacing the actual
                // call instruction with the full bytecode injected for the AFTER hook.
                HookType.AFTER
            } else {
                // If no BEFORE hook is registered, look for a REPLACE hook next.
                HookType.REPLACE
            }
            handleMethodInsn(nextHookType, opcode, owner, methodName, methodDescriptor, isInterface)
        }
        HookType.REPLACE -> {
            // REPLACE hooks can't (and don't need to) be mixed with other hooks. We only cycle through them if we
            // couldn't find a matching REPLACE hook, in which case we try an AFTER hook next.
            require(!appliedHook)
            handleMethodInsn(HookType.AFTER, opcode, owner, methodName, methodDescriptor, isInterface)
        }
        // An AFTER hook is always the last in the chain. Whether a hook has been applied or not, always emit the
        // actual call instruction.
        HookType.AFTER -> mv.visitMethodInsn(opcode, owner, methodName, methodDescriptor, isInterface)
    }

    /**
     * Emits the instrumentation for a [hookType] hook matching the given call. If no such hook exists, it moves on
     * to the next hook type via [visitNextHookTypeOrCall].
     *
     * On entry, the operand stack holds the receiver (unless [opcode] is INVOKESTATIC) followed by the call's
     * arguments, exactly as the original call instruction expects. The emitted code leaves the stack as the original
     * call would. For REPLACE hooks, the return value is supplied by the hook.
     */
    private fun handleMethodInsn(
        hookType: HookType,
        opcode: Int,
        owner: String,
        methodName: String,
        methodDescriptor: String,
        isInterface: Boolean,
    ) {
        val hook = findMatchingHook(hookType, owner, methodName, methodDescriptor)
        if (hook == null) {
            visitNextHookTypeOrCall(hookType, false, opcode, owner, methodName, methodDescriptor, isInterface)
            return
        }

        // The hookId is used to identify a call site.
        val hookId = random.nextInt()

        val paramDescriptors = extractParameterTypeDescriptors(methodDescriptor)
        val localObjArr = storeMethodArguments(paramDescriptors)
        // If the method we're hooking is not static, there is now a reference to the object the method was invoked
        // on at the top of the stack. If the method is static, that object is missing, so we push a null ref instead.
        if (opcode == Opcodes.INVOKESTATIC) {
            mv.visitInsn(Opcodes.ACONST_NULL)
        }

        // Save the owner object to a new local variable. Type.getObjectType takes an internal name directly and
        // also handles the descriptor-style internal names of array classes (e.g. "[I").
        val localOwnerObj = lvs.newLocal(Type.getObjectType(owner))
        mv.visitVarInsn(Opcodes.ASTORE, localOwnerObj) // consume objectref
        // We now removed all values for the original method call from the operand stack
        // and saved them to local variables.

        // Start to build the arguments for the hook method.
        if (methodName == "<init>") {
            // Special case for constructors:
            // We cannot create a MethodHandle for a constructor, so we push null instead.
            mv.visitInsn(Opcodes.ACONST_NULL) // push nullref
            // Only pass the this object if it has been initialized by the time the hook is invoked.
            if (hook.hookType == HookType.AFTER) {
                mv.visitVarInsn(Opcodes.ALOAD, localOwnerObj)
            } else {
                mv.visitInsn(Opcodes.ACONST_NULL) // push nullref
            }
        } else {
            pushMethodHandle(opcode, owner, methodName, methodDescriptor, isInterface)
            // Stack layout: ... | MethodHandle (objectref)
            // Push the owner object again
            mv.visitVarInsn(Opcodes.ALOAD, localOwnerObj)
        }
        // Stack layout: ... | MethodHandle (objectref) | owner (objectref)
        // Push a reference to our object array with the saved arguments
        mv.visitVarInsn(Opcodes.ALOAD, localObjArr)
        // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref)
        // Push the hook id
        mv.visitLdcInsn(hookId)
        // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref) | hookId (int)
        // How we proceed depends on the type of hook we want to implement.
        when (hook.hookType) {
            HookType.BEFORE -> {
                invokeHook(hook)
                // Stack layout: ...
                pushOriginalOperands(opcode, paramDescriptors, localOwnerObj, localObjArr)
                // Stack layout: ... | [owner (objectref)]
                //                   | arg1 (primitive/objectref) | arg2 (primitive/objectref) | ...
                // Call the original method or the next hook in order.
                visitNextHookTypeOrCall(hookType, true, opcode, owner, methodName, methodDescriptor, isInterface)
            }
            HookType.REPLACE -> {
                invokeHook(hook)
                // Stack layout: ... | [return value (primitive/objectref)]
                // Check if we need to process the return value
                val returnTypeDescriptor = extractReturnTypeDescriptor(methodDescriptor)
                if (returnTypeDescriptor != "V") {
                    val hookMethodReturnType = extractReturnTypeDescriptor(hook.hookMethodDescriptor)
                    // If the hook method's return type is primitive, we don't need to unwrap or cast it.
                    if (!isPrimitiveType(hookMethodReturnType)) {
                        // Check if the returned object type differs from the one that should be returned.
                        // If a primitive should be returned, we check its wrapper type.
                        val expectedType = getWrapperTypeDescriptor(returnTypeDescriptor)
                        if (expectedType != hookMethodReturnType) {
                            // Cast object
                            mv.visitTypeInsn(Opcodes.CHECKCAST, extractInternalClassName(expectedType))
                        }
                        // Check if we need to unwrap the returned object
                        unwrapTypeIfPrimitive(returnTypeDescriptor)
                    }
                }
            }
            HookType.AFTER -> {
                pushOriginalOperands(opcode, paramDescriptors, localOwnerObj, localObjArr)
                // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref)
                //                   | hookId (int) | [owner (objectref)]
                //                   | arg1 (primitive/objectref) | arg2 (primitive/objectref) | ...
                // Call the original method or the next hook in order
                visitNextHookTypeOrCall(hookType, true, opcode, owner, methodName, methodDescriptor, isInterface)
                val returnTypeDescriptor = extractReturnTypeDescriptor(methodDescriptor)
                if (returnTypeDescriptor == "V") {
                    // If the method didn't return anything, we push a nullref as placeholder
                    mv.visitInsn(Opcodes.ACONST_NULL) // push nullref
                }
                // Wrap return value if it is a primitive type
                wrapTypeIfPrimitive(returnTypeDescriptor)
                // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref)
                //                   | hookId (int) | return value (objectref)
                // Store the result value in a local variable (but keep it on the stack)
                val localReturnObj = lvs.newLocal(Type.getType(getWrapperTypeDescriptor(returnTypeDescriptor)))
                mv.visitVarInsn(Opcodes.ASTORE, localReturnObj) // consume objectref
                mv.visitVarInsn(Opcodes.ALOAD, localReturnObj) // push objectref
                invokeHook(hook)
                // Stack layout: ...
                if (returnTypeDescriptor != "V") {
                    // Push the return value again
                    mv.visitVarInsn(Opcodes.ALOAD, localReturnObj) // push objectref
                    // Unwrap it, if it was a primitive value
                    unwrapTypeIfPrimitive(returnTypeDescriptor)
                    // Stack layout: ... | return value (primitive/objectref)
                }
            }
        }
    }

    /**
     * Pushes a MethodHandle constant for the given call target. In [java6Mode], it pushes `null` instead, because
     * MethodHandle constants (constant pool tag 15) are not supported in Java 6 class files (major version 50).
     */
    private fun pushMethodHandle(
        opcode: Int,
        owner: String,
        methodName: String,
        methodDescriptor: String,
        isInterface: Boolean,
    ) {
        if (java6Mode) {
            mv.visitInsn(Opcodes.ACONST_NULL) // push nullref
            return
        }
        val handleTag = when (opcode) {
            Opcodes.INVOKEVIRTUAL -> Opcodes.H_INVOKEVIRTUAL
            Opcodes.INVOKEINTERFACE -> Opcodes.H_INVOKEINTERFACE
            Opcodes.INVOKESTATIC -> Opcodes.H_INVOKESTATIC
            Opcodes.INVOKESPECIAL -> Opcodes.H_INVOKESPECIAL
            // Unreachable: visitMethodInsn only lets invocation opcodes through.
            else -> error("Not a method invocation opcode: $opcode")
        }
        mv.visitLdcInsn(Handle(handleTag, owner, methodName, methodDescriptor, isInterface)) // push MethodHandle
    }

    /** Emits a static call to the hook method of [hook]. The hook's arguments must already be on the stack. */
    private fun invokeHook(hook: Hook) {
        mv.visitMethodInsn(
            Opcodes.INVOKESTATIC,
            hook.hookInternalClassName,
            hook.hookMethodName,
            hook.hookMethodDescriptor,
            false,
        )
    }

    /**
     * Pushes the operands of the original call again: the owner object from [localOwnerObj] (unless [opcode] is
     * INVOKESTATIC), followed by all arguments restored from the `Object[]` in [localObjArr].
     */
    private fun pushOriginalOperands(
        opcode: Int,
        paramDescriptors: List<String>,
        localOwnerObj: Int,
        localObjArr: Int,
    ) {
        if (opcode != Opcodes.INVOKESTATIC) {
            mv.visitVarInsn(Opcodes.ALOAD, localOwnerObj) // push owner object
        }
        loadMethodArguments(paramDescriptors, localObjArr) // push all method arguments
    }

    // Allocation-free check: this runs for every method instruction of every instrumented method.
    private fun isMethodInvocationOp(opcode: Int) = when (opcode) {
        Opcodes.INVOKEVIRTUAL, Opcodes.INVOKEINTERFACE, Opcodes.INVOKESTATIC, Opcodes.INVOKESPECIAL -> true
        else -> false
    }

    /**
     * Looks up the hook of type [hookType] for the given call target. A hook registered with the exact [descriptor]
     * takes precedence over one registered without a descriptor.
     */
    private fun findMatchingHook(hookType: HookType, owner: String, name: String, descriptor: String): Hook? {
        val withoutDescriptorKey = "$hookType#$owner#$name"
        val withDescriptorKey = "$withoutDescriptorKey#$descriptor"
        return hooks[withDescriptorKey] ?: hooks[withoutDescriptorKey]
    }

    /**
     * Pops the arguments of a method call off the operand stack and stores them, boxed if primitive, in a new
     * `Object[]` held in a fresh local variable.
     *
     * @param paramDescriptors the type descriptors of all method arguments, in declaration order.
     * @return the index of the local variable holding the array.
     */
    private fun storeMethodArguments(paramDescriptors: List<String>): Int {
        // Allocate a new Object[] for the method's parameters.
        mv.visitIntInsn(Opcodes.SIPUSH, paramDescriptors.size)
        mv.visitTypeInsn(Opcodes.ANEWARRAY, "java/lang/Object")
        val localObjArr = lvs.newLocal(Type.getType("[Ljava/lang/Object;"))
        mv.visitVarInsn(Opcodes.ASTORE, localObjArr)

        // Loop over all arguments in reverse order (because the last argument is on top).
        for ((argIdx, argDescriptor) in paramDescriptors.withIndex().reversed()) {
            // If the argument is a primitive type, wrap it in its wrapper class
            wrapTypeIfPrimitive(argDescriptor)
            // Store the argument in our object array, for that we need to shape the stack first.
            // Stack layout: ... | method argument (objectref)
            mv.visitVarInsn(Opcodes.ALOAD, localObjArr)
            // Stack layout: ... | method argument (objectref) | object array (arrayref)
            mv.visitInsn(Opcodes.SWAP)
            // Stack layout: ... | object array (arrayref) | method argument (objectref)
            mv.visitIntInsn(Opcodes.SIPUSH, argIdx)
            // Stack layout: ... | object array (arrayref) | method argument (objectref) | argument index (int)
            mv.visitInsn(Opcodes.SWAP)
            // Stack layout: ... | object array (arrayref) | argument index (int) | method argument (objectref)
            mv.visitInsn(Opcodes.AASTORE) // consume all three: arrayref, index, value
            // Stack layout: ...
        }

        return localObjArr
    }

    /**
     * Pushes the arguments previously saved by [storeMethodArguments] back onto the operand stack. Each one is cast to
     * its original type (or its wrapper type), and primitives are unwrapped.
     *
     * @param paramDescriptors the type descriptors of all method arguments, in declaration order.
     * @param localObjArr index of the local variable holding the `Object[]` with the saved arguments.
     */
    private fun loadMethodArguments(paramDescriptors: List<String>, localObjArr: Int) {
        for ((argIdx, argDescriptor) in paramDescriptors.withIndex()) {
            // Push a reference to the object array on the stack
            mv.visitVarInsn(Opcodes.ALOAD, localObjArr)
            // Stack layout: ... | object array (arrayref)
            // Push the index of the current argument on the stack
            mv.visitIntInsn(Opcodes.SIPUSH, argIdx)
            // Stack layout: ... | object array (arrayref) | argument index (int)
            // Load the argument from the array
            mv.visitInsn(Opcodes.AALOAD)
            // Stack layout: ... | method argument (objectref)
            // Cast the object to its original type (or its wrapper type)
            val wrapperTypeDescriptor = getWrapperTypeDescriptor(argDescriptor)
            mv.visitTypeInsn(Opcodes.CHECKCAST, extractInternalClassName(wrapperTypeDescriptor))
            // If the argument is supposed to be a primitive type, unwrap the wrapped value
            unwrapTypeIfPrimitive(argDescriptor)
            // Stack layout: ... | method argument (primitive/objectref)
        }
    }

    /**
     * Removes a primitive value from the top of the operand stack and pushes it enclosed in its wrapper type (e.g.
     * removes int, pushes Integer) by calling `valueOf(...)` on the wrapper class. Emits nothing for non-primitive
     * types and for "V".
     */
    private fun wrapTypeIfPrimitive(unwrappedTypeDescriptor: String) {
        if (!isPrimitiveType(unwrappedTypeDescriptor) || unwrappedTypeDescriptor == "V") return
        val wrapperTypeDescriptor = getWrapperTypeDescriptor(unwrappedTypeDescriptor)
        val wrapperType = extractInternalClassName(wrapperTypeDescriptor)
        val valueOfDescriptor = "($unwrappedTypeDescriptor)$wrapperTypeDescriptor"
        mv.visitMethodInsn(Opcodes.INVOKESTATIC, wrapperType, "valueOf", valueOfDescriptor, false)
    }

    /**
     * Removes a wrapper object around the given primitive type from the top of the operand stack and pushes the
     * primitive value it contains (e.g. removes Integer, pushes int) by calling `intValue()` / `charValue()` / ... on
     * the wrapper object. Emits nothing for descriptors other than B, C, D, F, I, J, S and Z.
     */
    private fun unwrapTypeIfPrimitive(primitiveTypeDescriptor: String) {
        val (methodName, wrappedTypeDescriptor) = when (primitiveTypeDescriptor) {
            "B" -> Pair("byteValue", "java/lang/Byte")
            "C" -> Pair("charValue", "java/lang/Character")
            "D" -> Pair("doubleValue", "java/lang/Double")
            "F" -> Pair("floatValue", "java/lang/Float")
            "I" -> Pair("intValue", "java/lang/Integer")
            "J" -> Pair("longValue", "java/lang/Long")
            "S" -> Pair("shortValue", "java/lang/Short")
            "Z" -> Pair("booleanValue", "java/lang/Boolean")
            else -> return
        }
        mv.visitMethodInsn(
            Opcodes.INVOKEVIRTUAL,
            wrappedTypeDescriptor,
            methodName,
            "()$primitiveTypeDescriptor",
            false
        )
    }
}