diff options
Diffstat (limited to 'src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java')
-rw-r--r-- | src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java | 598 |
1 files changed, 598 insertions, 0 deletions
diff --git a/src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java b/src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java new file mode 100644 index 00000000..7fb15866 --- /dev/null +++ b/src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java @@ -0,0 +1,598 @@ +// Copyright 2021 Code Intelligence GmbH +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.code_intelligence.jazzer.runtime; + +import com.code_intelligence.jazzer.api.HookType; +import com.code_intelligence.jazzer.api.MethodHook; +import java.lang.invoke.MethodHandle; +import java.util.*; + +@SuppressWarnings("unused") +final public class TraceCmpHooks { + @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Byte", targetMethod = "compare", + targetMethodDescriptor = "(BB)I") + @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Byte", + targetMethod = "compareUnsigned", targetMethodDescriptor = "(BB)I") + @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Short", targetMethod = "compare", + targetMethodDescriptor = "(SS)I") + @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Short", + targetMethod = "compareUnsigned", targetMethodDescriptor = "(SS)I") + @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Integer", + targetMethod = "compare", targetMethodDescriptor = "(II)I") + @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Integer", + targetMethod = "compareUnsigned", targetMethodDescriptor = "(II)I") + @MethodHook(type = HookType.BEFORE, targetClassName = "kotlin.jvm.internal.Intrinsics ", + targetMethod = "compare", targetMethodDescriptor = "(II)I") + public static void + integerCompare(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId) { + TraceDataFlowNativeCallbacks.traceCmpInt((int) arguments[0], (int) arguments[1], hookId); + } + + @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Byte", + targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Byte;)I") + @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Short", + targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Short;)I") + @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Integer", + targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Integer;)I") + public static void + integerCompareTo(MethodHandle method, Object thisObject, Object[] arguments, int hookId) { + TraceDataFlowNativeCallbacks.traceCmpInt((int) thisObject, (int) arguments[0], hookId); + } + + @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Long", targetMethod = "compare", + targetMethodDescriptor = "(JJ)I") + @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Long", + targetMethod = "compareUnsigned", targetMethodDescriptor = "(JJ)I") + public static void + longCompare(MethodHandle method, Object thisObject, Object[] arguments, int hookId) { + TraceDataFlowNativeCallbacks.traceCmpLong((long) arguments[0], (long) arguments[1], hookId); + } + + @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Long", + targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Long;)I") + public static void + longCompareTo(MethodHandle method, Long thisObject, Object[] arguments, int hookId) { + TraceDataFlowNativeCallbacks.traceCmpLong(thisObject, (long) arguments[0], hookId); + } + + @MethodHook(type = HookType.BEFORE, targetClassName = "kotlin.jvm.internal.Intrinsics ", + targetMethod = "compare", targetMethodDescriptor = "(JJ)I") + public static void + longCompareKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId) { + TraceDataFlowNativeCallbacks.traceCmpLong((long) arguments[0], (long) arguments[1], hookId); + } + + @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "equals") + @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", + targetMethod = "equalsIgnoreCase") + public static void + equals(MethodHandle method, String thisObject, Object[] arguments, int hookId, Boolean areEqual) { + if (!areEqual && arguments.length == 1 && arguments[0] instanceof String) { + // The precise value of the result of the comparison is not used by libFuzzer as long as it is + // non-zero. + TraceDataFlowNativeCallbacks.traceStrcmp(thisObject, (String) arguments[0], 1, hookId); + } + } + + @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.Object", targetMethod = "equals") + @MethodHook( + type = HookType.AFTER, targetClassName = "java.lang.CharSequence", targetMethod = "equals") + @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.Number", targetMethod = "equals") + public static void + genericEquals( + MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean areEqual) { + if (!areEqual && arguments.length == 1 && arguments[0] != null + && thisObject.getClass() == arguments[0].getClass()) { + TraceDataFlowNativeCallbacks.traceGenericCmp(thisObject, arguments[0], hookId); + } + } + + @MethodHook(type = HookType.AFTER, targetClassName = "clojure.lang.Util", targetMethod = "equiv") + public static void genericStaticEquals( + MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean areEqual) { + if (!areEqual && arguments.length == 2 && arguments[0] != null && arguments[1] != null + && arguments[1].getClass() == arguments[0].getClass()) { + TraceDataFlowNativeCallbacks.traceGenericCmp(arguments[0], arguments[1], hookId); + } + } + + @MethodHook( + type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "compareTo") + @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", + targetMethod = "compareToIgnoreCase") + public static void + compareTo( + MethodHandle method, String thisObject, Object[] arguments, int hookId, Integer returnValue) { + if (returnValue != 0 && arguments.length == 1 && arguments[0] instanceof String) { + TraceDataFlowNativeCallbacks.traceStrcmp( + thisObject, (String) arguments[0], returnValue, hookId); + } + } + + @MethodHook( + type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "contentEquals") + public static void + contentEquals(MethodHandle method, String thisObject, Object[] arguments, int hookId, + Boolean areEqualContents) { + if (!areEqualContents && arguments.length == 1 && arguments[0] instanceof CharSequence) { + TraceDataFlowNativeCallbacks.traceStrcmp( + thisObject, ((CharSequence) arguments[0]).toString(), 1, hookId); + } + } + + @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", + targetMethod = "regionMatches", targetMethodDescriptor = "(ZILjava/lang/String;II)Z") + public static void + regionsMatches5( + MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) { + if (!returnValue) { + int toffset = (int) arguments[1]; + String other = (String) arguments[2]; + int ooffset = (int) arguments[3]; + int len = (int) arguments[4]; + regionMatchesInternal((String) thisObject, toffset, other, ooffset, len, hookId); + } + } + + @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", + targetMethod = "regionMatches", targetMethodDescriptor = "(ILjava/lang/String;II)Z") + public static void + regionMatches4( + MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) { + if (!returnValue) { + int toffset = (int) arguments[0]; + String other = (String) arguments[1]; + int ooffset = (int) arguments[2]; + int len = (int) arguments[3]; + regionMatchesInternal((String) thisObject, toffset, other, ooffset, len, hookId); + } + } + + private static void regionMatchesInternal( + String thisString, int toffset, String other, int ooffset, int len, int hookId) { + if (toffset < 0 || ooffset < 0) + return; + int cappedThisStringEnd = Math.min(toffset + len, thisString.length()); + int cappedOtherStringEnd = Math.min(ooffset + len, other.length()); + String thisPart = thisString.substring(toffset, cappedThisStringEnd); + String otherPart = other.substring(ooffset, cappedOtherStringEnd); + TraceDataFlowNativeCallbacks.traceStrcmp(thisPart, otherPart, 1, hookId); + } + + @MethodHook( + type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "contains") + public static void + contains( + MethodHandle method, String thisObject, Object[] arguments, int hookId, Boolean doesContain) { + if (!doesContain && arguments.length == 1 && arguments[0] instanceof CharSequence) { + TraceDataFlowNativeCallbacks.traceStrstr( + thisObject, ((CharSequence) arguments[0]).toString(), hookId); + } + } + + @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "indexOf") + @MethodHook( + type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "lastIndexOf") + @MethodHook( + type = HookType.AFTER, targetClassName = "java.lang.StringBuffer", targetMethod = "indexOf") + @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.StringBuffer", + targetMethod = "lastIndexOf") + @MethodHook( + type = HookType.AFTER, targetClassName = "java.lang.StringBuilder", targetMethod = "indexOf") + @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.StringBuilder", + targetMethod = "lastIndexOf") + public static void + indexOf( + MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue) { + if (returnValue == -1 && arguments.length >= 1 && arguments[0] instanceof String) { + TraceDataFlowNativeCallbacks.traceStrstr( + thisObject.toString(), (String) arguments[0], hookId); + } + } + + @MethodHook( + type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "startsWith") + @MethodHook( + type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "endsWith") + public static void + startsWith(MethodHandle method, String thisObject, Object[] arguments, int hookId, + Boolean doesStartOrEndsWith) { + if (!doesStartOrEndsWith && arguments.length >= 1 && arguments[0] instanceof String) { + TraceDataFlowNativeCallbacks.traceStrstr(thisObject, (String) arguments[0], hookId); + } + } + + @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "replace", + targetMethodDescriptor = + "(Ljava/lang/CharSequence;Ljava/lang/CharSequence;)Ljava/lang/String;") + public static void + replace( + MethodHandle method, Object thisObject, Object[] arguments, int hookId, String returnValue) { + String original = (String) thisObject; + // Report only if the replacement was not successful. + if (original.equals(returnValue)) { + String target = arguments[0].toString(); + TraceDataFlowNativeCallbacks.traceStrstr(original, target, hookId); + } + } + + // For standard Kotlin packages, which are named according to the pattern kotlin.*, we append a + // whitespace to the package name of the target class so that they are not mangled due to shading. + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.jvm.internal.Intrinsics ", + targetMethod = "areEqual") + @MethodHook( + type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "equals") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "equals$default") + public static void + equalsKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, + Boolean equalStrings) { + if (!equalStrings && arguments.length >= 2 && arguments[0] instanceof String + && arguments[1] instanceof String) { + TraceDataFlowNativeCallbacks.traceStrcmp( + (String) arguments[0], (String) arguments[1], 1, hookId); + } + } + + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "contentEquals") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "contentEquals$default") + public static void + contentEqualKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, + Boolean equalStrings) { + if (!equalStrings && arguments.length >= 2 && arguments[0] instanceof CharSequence + && arguments[1] instanceof CharSequence) { + TraceDataFlowNativeCallbacks.traceStrcmp( + arguments[0].toString(), arguments[1].toString(), 1, hookId); + } + } + + @MethodHook( + type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "compareTo") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "compareTo$default") + public static void + compareToKt( + MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue) { + if (returnValue != 0 && arguments.length >= 2 && arguments[0] instanceof String + && arguments[1] instanceof String) { + TraceDataFlowNativeCallbacks.traceStrcmp( + (String) arguments[0], (String) arguments[1], 1, hookId); + } + } + + @MethodHook( + type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "endsWith") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "endsWith$default") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "startsWith") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "startsWith$default") + public static void + startsWithKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, + Boolean doesStartOrEndsWith) { + if (!doesStartOrEndsWith && arguments.length >= 2 && arguments[0] instanceof CharSequence + && arguments[1] instanceof CharSequence) { + TraceDataFlowNativeCallbacks.traceStrstr( + arguments[0].toString(), arguments[1].toString(), hookId); + } + } + + @MethodHook( + type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "contains") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "contains$default") + public static void + containsKt( + MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Boolean doesContain) { + if (!doesContain && arguments.length >= 2 && arguments[0] instanceof CharSequence + && arguments[1] instanceof CharSequence) { + TraceDataFlowNativeCallbacks.traceStrstr( + arguments[0].toString(), arguments[1].toString(), hookId); + } + } + + @MethodHook( + type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "indexOf") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "indexOf$default") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "lastIndexOf") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "lastIndexOf$default") + public static void + indexOfKt( + MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue) { + if (returnValue != -1 || arguments.length < 2 || !(arguments[0] instanceof CharSequence)) { + return; + } + if (arguments[1] instanceof String) { + TraceDataFlowNativeCallbacks.traceStrstr( + arguments[0].toString(), (String) arguments[1], hookId); + } else if (arguments[1] instanceof Character) { + TraceDataFlowNativeCallbacks.traceStrstr( + arguments[0].toString(), ((Character) arguments[1]).toString(), hookId); + } + } + + @MethodHook( + type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "replace") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "replace$default") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "replaceAfter") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "replaceAfter$default") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "replaceAfterLast") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "replaceAfterLast$default") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "replaceBefore") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "replaceBefore$default") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "replaceBeforeLast") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "replaceBeforeLast$default") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "replaceFirst") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "replaceFirst$default") + public static void + replaceKt( + MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, String returnValue) { + if (arguments.length < 2 || !(arguments[0] instanceof String)) { + return; + } + String original = (String) arguments[0]; + if (!original.equals(returnValue)) { + return; + } + + // We currently don't handle the overloads that take a regex as a second argument. + if (arguments[1] instanceof String || arguments[1] instanceof Character) { + TraceDataFlowNativeCallbacks.traceStrstr(original, arguments[1].toString(), hookId); + } + } + + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "regionMatches", + targetMethodDescriptor = "(Ljava/lang/String;ILjava/lang/String;IIZ)Z") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "regionMatches$default", + targetMethodDescriptor = "(Ljava/lang/String;ILjava/lang/String;IIZILjava/lang/Object;)Z") + public static void + regionMatchesKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, + Boolean doesRegionMatch) { + if (!doesRegionMatch) { + String thisString = arguments[0].toString(); + int thisOffset = (int) arguments[1]; + String other = arguments[2].toString(); + int otherOffset = (int) arguments[3]; + int length = (int) arguments[4]; + regionMatchesInternal(thisString, thisOffset, other, otherOffset, length, hookId); + } + } + + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "indexOfAny") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "indexOfAny$default") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "lastIndexOfAny") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "lastIndexOfAny$default") + public static void + indexOfAnyKt( + MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue) { + if (returnValue == -1 && arguments.length >= 2 && arguments[0] instanceof CharSequence) { + guideTowardContainmentOfFirstElement(arguments[0].toString(), arguments[1], hookId); + } + } + + @MethodHook( + type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "findAnyOf") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "findAnyOf$default") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "findLastAnyOf") + @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", + targetMethod = "findLastAnyOf$default") + public static void + findAnyKt( + MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Object returnValue) { + if (returnValue == null && arguments.length >= 2 && arguments[0] instanceof CharSequence) { + guideTowardContainmentOfFirstElement(arguments[0].toString(), arguments[1], hookId); + } + } + + private static void guideTowardContainmentOfFirstElement( + String containingString, Object candidateCollectionObj, int hookId) { + if (candidateCollectionObj instanceof Collection<?>) { + Collection<?> strings = (Collection<?>) candidateCollectionObj; + if (strings.isEmpty()) { + return; + } + Object firstElementObj = strings.iterator().next(); + if (firstElementObj instanceof CharSequence) { + TraceDataFlowNativeCallbacks.traceStrstr( + containingString, firstElementObj.toString(), hookId); + } + } else if (candidateCollectionObj.getClass().isArray()) { + if (candidateCollectionObj.getClass().getComponentType() == char.class) { + char[] chars = (char[]) candidateCollectionObj; + if (chars.length > 0) { + TraceDataFlowNativeCallbacks.traceStrstr( + containingString, Character.toString(chars[0]), hookId); + } + } + } + } + + @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "equals", + targetMethodDescriptor = "([B[B)Z") + public static void + arraysEquals( + MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) { + if (returnValue) + return; + byte[] first = (byte[]) arguments[0]; + byte[] second = (byte[]) arguments[1]; + TraceDataFlowNativeCallbacks.traceMemcmp(first, second, 1, hookId); + } + + @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "equals", + targetMethodDescriptor = "([BII[BII)Z") + public static void + arraysEqualsRange( + MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) { + if (returnValue) + return; + byte[] first = + Arrays.copyOfRange((byte[]) arguments[0], (int) arguments[1], (int) arguments[2]); + byte[] second = + Arrays.copyOfRange((byte[]) arguments[3], (int) arguments[4], (int) arguments[5]); + TraceDataFlowNativeCallbacks.traceMemcmp(first, second, 1, hookId); + } + + @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "compare", + targetMethodDescriptor = "([B[B)I") + @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", + targetMethod = "compareUnsigned", targetMethodDescriptor = "([B[B)I") + public static void + arraysCompare( + MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue) { + if (returnValue == 0) + return; + byte[] first = (byte[]) arguments[0]; + byte[] second = (byte[]) arguments[1]; + TraceDataFlowNativeCallbacks.traceMemcmp(first, second, returnValue, hookId); + } + + @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "compare", + targetMethodDescriptor = "([BII[BII)I") + @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", + targetMethod = "compareUnsigned", targetMethodDescriptor = "([BII[BII)I") + public static void + arraysCompareRange( + MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue) { + if (returnValue == 0) + return; + byte[] first = + Arrays.copyOfRange((byte[]) arguments[0], (int) arguments[1], (int) arguments[2]); + byte[] second = + Arrays.copyOfRange((byte[]) arguments[3], (int) arguments[4], (int) arguments[5]); + TraceDataFlowNativeCallbacks.traceMemcmp(first, second, returnValue, hookId); + } + + // The maximal number of elements of a non-TreeMap Map that will be sorted and searched for the + // key closest to the current lookup key in the mapGet hook. + private static final int MAX_NUM_KEYS_TO_ENUMERATE = 100; + + @SuppressWarnings({"rawtypes", "unchecked"}) + @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Map", targetMethod = "get") + public static void mapGet( + MethodHandle method, Object thisObject, Object[] arguments, int hookId, Object returnValue) { + if (returnValue != null) + return; + if (arguments.length != 1) { + return; + } + if (thisObject == null) + return; + final Map map = (Map) thisObject; + if (map.size() == 0) + return; + final Object currentKey = arguments[0]; + if (currentKey == null) + return; + // Find two valid map keys that bracket currentKey. + // This is a generalization of libFuzzer's __sanitizer_cov_trace_switch: + // https://github.com/llvm/llvm-project/blob/318942de229beb3b2587df09e776a50327b5cef0/compiler-rt/lib/fuzzer/FuzzerTracePC.cpp#L564 + Object lowerBoundKey = null; + Object upperBoundKey = null; + try { + if (map instanceof TreeMap) { + final TreeMap treeMap = (TreeMap) map; + try { + lowerBoundKey = treeMap.floorKey(currentKey); + upperBoundKey = treeMap.ceilingKey(currentKey); + } catch (ClassCastException ignored) { + // Can be thrown by floorKey and ceilingKey if currentKey is of a type that can't be + // compared to the maps keys. + } + } else if (currentKey instanceof Comparable) { + final Comparable comparableCurrentKey = (Comparable) currentKey; + // Find two keys that bracket currentKey. + // Note: This is not deterministic if map.size() > MAX_NUM_KEYS_TO_ENUMERATE. + int enumeratedKeys = 0; + for (Object validKey : map.keySet()) { + if (!(validKey instanceof Comparable)) + continue; + final Comparable comparableValidKey = (Comparable) validKey; + // If the key sorts lower than the non-existing key, but higher than the current lower + // bound, update the lower bound and vice versa for the upper bound. + try { + if (comparableValidKey.compareTo(comparableCurrentKey) < 0 + && (lowerBoundKey == null || comparableValidKey.compareTo(lowerBoundKey) > 0)) { + lowerBoundKey = validKey; + } + if (comparableValidKey.compareTo(comparableCurrentKey) > 0 + && (upperBoundKey == null || comparableValidKey.compareTo(upperBoundKey) < 0)) { + upperBoundKey = validKey; + } + } catch (ClassCastException ignored) { + // Can be thrown by floorKey and ceilingKey if currentKey is of a type that can't be + // compared to the maps keys. + } + if (enumeratedKeys++ > MAX_NUM_KEYS_TO_ENUMERATE) + break; + } + } + } catch (ConcurrentModificationException ignored) { + // map was modified by another thread, skip this invocation + return; + } + // Modify the hook ID so that compares against distinct valid keys are traced separately. + if (lowerBoundKey != null) { + TraceDataFlowNativeCallbacks.traceGenericCmp( + currentKey, lowerBoundKey, hookId + lowerBoundKey.hashCode()); + } + if (upperBoundKey != null) { + TraceDataFlowNativeCallbacks.traceGenericCmp( + currentKey, upperBoundKey, hookId + upperBoundKey.hashCode()); + } + } + + @MethodHook(type = HookType.AFTER, targetClassName = "org.junit.jupiter.api.Assertions", + targetMethod = "assertNotEquals", + targetMethodDescriptor = "(Ljava/lang/Object;Ljava/lang/Object;)V") + @MethodHook(type = HookType.AFTER, targetClassName = "org.junit.jupiter.api.Assertions", + targetMethod = "assertNotEquals", + targetMethodDescriptor = "(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/String;)V") + @MethodHook(type = HookType.AFTER, targetClassName = "org.junit.jupiter.api.Assertions", + targetMethod = "assertNotEquals", + targetMethodDescriptor = + "(Ljava/lang/Object;Ljava/lang/Object;Ljava/util/function/Supplier;)V") + public static void + assertEquals(MethodHandle method, Object node, Object[] args, int hookId, Object alwaysNull) { + if (args[0] != null && args[1] != null && args[0].getClass() == args[1].getClass()) { + TraceDataFlowNativeCallbacks.traceGenericCmp(args[0], args[1], hookId); + } + } +} |