aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java')
-rw-r--r--src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java598
1 files changed, 598 insertions, 0 deletions
diff --git a/src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java b/src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java
new file mode 100644
index 00000000..7fb15866
--- /dev/null
+++ b/src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java
@@ -0,0 +1,598 @@
+// Copyright 2021 Code Intelligence GmbH
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.code_intelligence.jazzer.runtime;
+
+import com.code_intelligence.jazzer.api.HookType;
+import com.code_intelligence.jazzer.api.MethodHook;
+import java.lang.invoke.MethodHandle;
+import java.util.*;
+
+@SuppressWarnings("unused")
+final public class TraceCmpHooks {
+ @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Byte", targetMethod = "compare",
+ targetMethodDescriptor = "(BB)I")
+ @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Byte",
+ targetMethod = "compareUnsigned", targetMethodDescriptor = "(BB)I")
+ @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Short", targetMethod = "compare",
+ targetMethodDescriptor = "(SS)I")
+ @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Short",
+ targetMethod = "compareUnsigned", targetMethodDescriptor = "(SS)I")
+ @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Integer",
+ targetMethod = "compare", targetMethodDescriptor = "(II)I")
+ @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Integer",
+ targetMethod = "compareUnsigned", targetMethodDescriptor = "(II)I")
+ @MethodHook(type = HookType.BEFORE, targetClassName = "kotlin.jvm.internal.Intrinsics ",
+ targetMethod = "compare", targetMethodDescriptor = "(II)I")
+ public static void
+ integerCompare(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId) {
+ TraceDataFlowNativeCallbacks.traceCmpInt((int) arguments[0], (int) arguments[1], hookId);
+ }
+
+ @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Byte",
+ targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Byte;)I")
+ @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Short",
+ targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Short;)I")
+ @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Integer",
+ targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Integer;)I")
+ public static void
+ integerCompareTo(MethodHandle method, Object thisObject, Object[] arguments, int hookId) {
+ TraceDataFlowNativeCallbacks.traceCmpInt((int) thisObject, (int) arguments[0], hookId);
+ }
+
+ @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Long", targetMethod = "compare",
+ targetMethodDescriptor = "(JJ)I")
+ @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Long",
+ targetMethod = "compareUnsigned", targetMethodDescriptor = "(JJ)I")
+ public static void
+ longCompare(MethodHandle method, Object thisObject, Object[] arguments, int hookId) {
+ TraceDataFlowNativeCallbacks.traceCmpLong((long) arguments[0], (long) arguments[1], hookId);
+ }
+
+ @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Long",
+ targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Long;)I")
+ public static void
+ longCompareTo(MethodHandle method, Long thisObject, Object[] arguments, int hookId) {
+ TraceDataFlowNativeCallbacks.traceCmpLong(thisObject, (long) arguments[0], hookId);
+ }
+
+ @MethodHook(type = HookType.BEFORE, targetClassName = "kotlin.jvm.internal.Intrinsics ",
+ targetMethod = "compare", targetMethodDescriptor = "(JJ)I")
+ public static void
+ longCompareKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId) {
+ TraceDataFlowNativeCallbacks.traceCmpLong((long) arguments[0], (long) arguments[1], hookId);
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "equals")
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String",
+ targetMethod = "equalsIgnoreCase")
+ public static void
+ equals(MethodHandle method, String thisObject, Object[] arguments, int hookId, Boolean areEqual) {
+ if (!areEqual && arguments.length == 1 && arguments[0] instanceof String) {
+ // The precise value of the result of the comparison is not used by libFuzzer as long as it is
+ // non-zero.
+ TraceDataFlowNativeCallbacks.traceStrcmp(thisObject, (String) arguments[0], 1, hookId);
+ }
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.Object", targetMethod = "equals")
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "java.lang.CharSequence", targetMethod = "equals")
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.Number", targetMethod = "equals")
+ public static void
+ genericEquals(
+ MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean areEqual) {
+ if (!areEqual && arguments.length == 1 && arguments[0] != null
+ && thisObject.getClass() == arguments[0].getClass()) {
+ TraceDataFlowNativeCallbacks.traceGenericCmp(thisObject, arguments[0], hookId);
+ }
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "clojure.lang.Util", targetMethod = "equiv")
+ public static void genericStaticEquals(
+ MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean areEqual) {
+ if (!areEqual && arguments.length == 2 && arguments[0] != null && arguments[1] != null
+ && arguments[1].getClass() == arguments[0].getClass()) {
+ TraceDataFlowNativeCallbacks.traceGenericCmp(arguments[0], arguments[1], hookId);
+ }
+ }
+
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "compareTo")
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String",
+ targetMethod = "compareToIgnoreCase")
+ public static void
+ compareTo(
+ MethodHandle method, String thisObject, Object[] arguments, int hookId, Integer returnValue) {
+ if (returnValue != 0 && arguments.length == 1 && arguments[0] instanceof String) {
+ TraceDataFlowNativeCallbacks.traceStrcmp(
+ thisObject, (String) arguments[0], returnValue, hookId);
+ }
+ }
+
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "contentEquals")
+ public static void
+ contentEquals(MethodHandle method, String thisObject, Object[] arguments, int hookId,
+ Boolean areEqualContents) {
+ if (!areEqualContents && arguments.length == 1 && arguments[0] instanceof CharSequence) {
+ TraceDataFlowNativeCallbacks.traceStrcmp(
+ thisObject, ((CharSequence) arguments[0]).toString(), 1, hookId);
+ }
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String",
+ targetMethod = "regionMatches", targetMethodDescriptor = "(ZILjava/lang/String;II)Z")
+ public static void
+ regionsMatches5(
+ MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) {
+ if (!returnValue) {
+ int toffset = (int) arguments[1];
+ String other = (String) arguments[2];
+ int ooffset = (int) arguments[3];
+ int len = (int) arguments[4];
+ regionMatchesInternal((String) thisObject, toffset, other, ooffset, len, hookId);
+ }
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String",
+ targetMethod = "regionMatches", targetMethodDescriptor = "(ILjava/lang/String;II)Z")
+ public static void
+ regionMatches4(
+ MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) {
+ if (!returnValue) {
+ int toffset = (int) arguments[0];
+ String other = (String) arguments[1];
+ int ooffset = (int) arguments[2];
+ int len = (int) arguments[3];
+ regionMatchesInternal((String) thisObject, toffset, other, ooffset, len, hookId);
+ }
+ }
+
+ private static void regionMatchesInternal(
+ String thisString, int toffset, String other, int ooffset, int len, int hookId) {
+ if (toffset < 0 || ooffset < 0)
+ return;
+ int cappedThisStringEnd = Math.min(toffset + len, thisString.length());
+ int cappedOtherStringEnd = Math.min(ooffset + len, other.length());
+ String thisPart = thisString.substring(toffset, cappedThisStringEnd);
+ String otherPart = other.substring(ooffset, cappedOtherStringEnd);
+ TraceDataFlowNativeCallbacks.traceStrcmp(thisPart, otherPart, 1, hookId);
+ }
+
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "contains")
+ public static void
+ contains(
+ MethodHandle method, String thisObject, Object[] arguments, int hookId, Boolean doesContain) {
+ if (!doesContain && arguments.length == 1 && arguments[0] instanceof CharSequence) {
+ TraceDataFlowNativeCallbacks.traceStrstr(
+ thisObject, ((CharSequence) arguments[0]).toString(), hookId);
+ }
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "indexOf")
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "lastIndexOf")
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "java.lang.StringBuffer", targetMethod = "indexOf")
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.StringBuffer",
+ targetMethod = "lastIndexOf")
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "java.lang.StringBuilder", targetMethod = "indexOf")
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.StringBuilder",
+ targetMethod = "lastIndexOf")
+ public static void
+ indexOf(
+ MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue) {
+ if (returnValue == -1 && arguments.length >= 1 && arguments[0] instanceof String) {
+ TraceDataFlowNativeCallbacks.traceStrstr(
+ thisObject.toString(), (String) arguments[0], hookId);
+ }
+ }
+
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "startsWith")
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "endsWith")
+ public static void
+ startsWith(MethodHandle method, String thisObject, Object[] arguments, int hookId,
+ Boolean doesStartOrEndsWith) {
+ if (!doesStartOrEndsWith && arguments.length >= 1 && arguments[0] instanceof String) {
+ TraceDataFlowNativeCallbacks.traceStrstr(thisObject, (String) arguments[0], hookId);
+ }
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "replace",
+ targetMethodDescriptor =
+ "(Ljava/lang/CharSequence;Ljava/lang/CharSequence;)Ljava/lang/String;")
+ public static void
+ replace(
+ MethodHandle method, Object thisObject, Object[] arguments, int hookId, String returnValue) {
+ String original = (String) thisObject;
+ // Report only if the replacement was not successful.
+ if (original.equals(returnValue)) {
+ String target = arguments[0].toString();
+ TraceDataFlowNativeCallbacks.traceStrstr(original, target, hookId);
+ }
+ }
+
+ // For standard Kotlin packages, which are named according to the pattern kotlin.*, we append a
+ // whitespace to the package name of the target class so that they are not mangled due to shading.
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.jvm.internal.Intrinsics ",
+ targetMethod = "areEqual")
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "equals")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "equals$default")
+ public static void
+ equalsKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId,
+ Boolean equalStrings) {
+ if (!equalStrings && arguments.length >= 2 && arguments[0] instanceof String
+ && arguments[1] instanceof String) {
+ TraceDataFlowNativeCallbacks.traceStrcmp(
+ (String) arguments[0], (String) arguments[1], 1, hookId);
+ }
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "contentEquals")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "contentEquals$default")
+ public static void
+ contentEqualKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId,
+ Boolean equalStrings) {
+ if (!equalStrings && arguments.length >= 2 && arguments[0] instanceof CharSequence
+ && arguments[1] instanceof CharSequence) {
+ TraceDataFlowNativeCallbacks.traceStrcmp(
+ arguments[0].toString(), arguments[1].toString(), 1, hookId);
+ }
+ }
+
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "compareTo")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "compareTo$default")
+ public static void
+ compareToKt(
+ MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue) {
+ if (returnValue != 0 && arguments.length >= 2 && arguments[0] instanceof String
+ && arguments[1] instanceof String) {
+ TraceDataFlowNativeCallbacks.traceStrcmp(
+ (String) arguments[0], (String) arguments[1], 1, hookId);
+ }
+ }
+
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "endsWith")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "endsWith$default")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "startsWith")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "startsWith$default")
+ public static void
+ startsWithKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId,
+ Boolean doesStartOrEndsWith) {
+ if (!doesStartOrEndsWith && arguments.length >= 2 && arguments[0] instanceof CharSequence
+ && arguments[1] instanceof CharSequence) {
+ TraceDataFlowNativeCallbacks.traceStrstr(
+ arguments[0].toString(), arguments[1].toString(), hookId);
+ }
+ }
+
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "contains")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "contains$default")
+ public static void
+ containsKt(
+ MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Boolean doesContain) {
+ if (!doesContain && arguments.length >= 2 && arguments[0] instanceof CharSequence
+ && arguments[1] instanceof CharSequence) {
+ TraceDataFlowNativeCallbacks.traceStrstr(
+ arguments[0].toString(), arguments[1].toString(), hookId);
+ }
+ }
+
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "indexOf")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "indexOf$default")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "lastIndexOf")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "lastIndexOf$default")
+ public static void
+ indexOfKt(
+ MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue) {
+ if (returnValue != -1 || arguments.length < 2 || !(arguments[0] instanceof CharSequence)) {
+ return;
+ }
+ if (arguments[1] instanceof String) {
+ TraceDataFlowNativeCallbacks.traceStrstr(
+ arguments[0].toString(), (String) arguments[1], hookId);
+ } else if (arguments[1] instanceof Character) {
+ TraceDataFlowNativeCallbacks.traceStrstr(
+ arguments[0].toString(), ((Character) arguments[1]).toString(), hookId);
+ }
+ }
+
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "replace")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "replace$default")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "replaceAfter")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "replaceAfter$default")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "replaceAfterLast")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "replaceAfterLast$default")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "replaceBefore")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "replaceBefore$default")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "replaceBeforeLast")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "replaceBeforeLast$default")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "replaceFirst")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "replaceFirst$default")
+ public static void
+ replaceKt(
+ MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, String returnValue) {
+ if (arguments.length < 2 || !(arguments[0] instanceof String)) {
+ return;
+ }
+ String original = (String) arguments[0];
+ if (!original.equals(returnValue)) {
+ return;
+ }
+
+ // We currently don't handle the overloads that take a regex as a second argument.
+ if (arguments[1] instanceof String || arguments[1] instanceof Character) {
+ TraceDataFlowNativeCallbacks.traceStrstr(original, arguments[1].toString(), hookId);
+ }
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "regionMatches",
+ targetMethodDescriptor = "(Ljava/lang/String;ILjava/lang/String;IIZ)Z")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "regionMatches$default",
+ targetMethodDescriptor = "(Ljava/lang/String;ILjava/lang/String;IIZILjava/lang/Object;)Z")
+ public static void
+ regionMatchesKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId,
+ Boolean doesRegionMatch) {
+ if (!doesRegionMatch) {
+ String thisString = arguments[0].toString();
+ int thisOffset = (int) arguments[1];
+ String other = arguments[2].toString();
+ int otherOffset = (int) arguments[3];
+ int length = (int) arguments[4];
+ regionMatchesInternal(thisString, thisOffset, other, otherOffset, length, hookId);
+ }
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "indexOfAny")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "indexOfAny$default")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "lastIndexOfAny")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "lastIndexOfAny$default")
+ public static void
+ indexOfAnyKt(
+ MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue) {
+ if (returnValue == -1 && arguments.length >= 2 && arguments[0] instanceof CharSequence) {
+ guideTowardContainmentOfFirstElement(arguments[0].toString(), arguments[1], hookId);
+ }
+ }
+
+ @MethodHook(
+ type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "findAnyOf")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "findAnyOf$default")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "findLastAnyOf")
+ @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
+ targetMethod = "findLastAnyOf$default")
+ public static void
+ findAnyKt(
+ MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Object returnValue) {
+ if (returnValue == null && arguments.length >= 2 && arguments[0] instanceof CharSequence) {
+ guideTowardContainmentOfFirstElement(arguments[0].toString(), arguments[1], hookId);
+ }
+ }
+
+ private static void guideTowardContainmentOfFirstElement(
+ String containingString, Object candidateCollectionObj, int hookId) {
+ if (candidateCollectionObj instanceof Collection<?>) {
+ Collection<?> strings = (Collection<?>) candidateCollectionObj;
+ if (strings.isEmpty()) {
+ return;
+ }
+ Object firstElementObj = strings.iterator().next();
+ if (firstElementObj instanceof CharSequence) {
+ TraceDataFlowNativeCallbacks.traceStrstr(
+ containingString, firstElementObj.toString(), hookId);
+ }
+ } else if (candidateCollectionObj.getClass().isArray()) {
+ if (candidateCollectionObj.getClass().getComponentType() == char.class) {
+ char[] chars = (char[]) candidateCollectionObj;
+ if (chars.length > 0) {
+ TraceDataFlowNativeCallbacks.traceStrstr(
+ containingString, Character.toString(chars[0]), hookId);
+ }
+ }
+ }
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "equals",
+ targetMethodDescriptor = "([B[B)Z")
+ public static void
+ arraysEquals(
+ MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) {
+ if (returnValue)
+ return;
+ byte[] first = (byte[]) arguments[0];
+ byte[] second = (byte[]) arguments[1];
+ TraceDataFlowNativeCallbacks.traceMemcmp(first, second, 1, hookId);
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "equals",
+ targetMethodDescriptor = "([BII[BII)Z")
+ public static void
+ arraysEqualsRange(
+ MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) {
+ if (returnValue)
+ return;
+ byte[] first =
+ Arrays.copyOfRange((byte[]) arguments[0], (int) arguments[1], (int) arguments[2]);
+ byte[] second =
+ Arrays.copyOfRange((byte[]) arguments[3], (int) arguments[4], (int) arguments[5]);
+ TraceDataFlowNativeCallbacks.traceMemcmp(first, second, 1, hookId);
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "compare",
+ targetMethodDescriptor = "([B[B)I")
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays",
+ targetMethod = "compareUnsigned", targetMethodDescriptor = "([B[B)I")
+ public static void
+ arraysCompare(
+ MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue) {
+ if (returnValue == 0)
+ return;
+ byte[] first = (byte[]) arguments[0];
+ byte[] second = (byte[]) arguments[1];
+ TraceDataFlowNativeCallbacks.traceMemcmp(first, second, returnValue, hookId);
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "compare",
+ targetMethodDescriptor = "([BII[BII)I")
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays",
+ targetMethod = "compareUnsigned", targetMethodDescriptor = "([BII[BII)I")
+ public static void
+ arraysCompareRange(
+ MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue) {
+ if (returnValue == 0)
+ return;
+ byte[] first =
+ Arrays.copyOfRange((byte[]) arguments[0], (int) arguments[1], (int) arguments[2]);
+ byte[] second =
+ Arrays.copyOfRange((byte[]) arguments[3], (int) arguments[4], (int) arguments[5]);
+ TraceDataFlowNativeCallbacks.traceMemcmp(first, second, returnValue, hookId);
+ }
+
+ // The maximal number of elements of a non-TreeMap Map that will be sorted and searched for the
+ // key closest to the current lookup key in the mapGet hook.
+ private static final int MAX_NUM_KEYS_TO_ENUMERATE = 100;
+
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Map", targetMethod = "get")
+ public static void mapGet(
+ MethodHandle method, Object thisObject, Object[] arguments, int hookId, Object returnValue) {
+ if (returnValue != null)
+ return;
+ if (arguments.length != 1) {
+ return;
+ }
+ if (thisObject == null)
+ return;
+ final Map map = (Map) thisObject;
+ if (map.size() == 0)
+ return;
+ final Object currentKey = arguments[0];
+ if (currentKey == null)
+ return;
+ // Find two valid map keys that bracket currentKey.
+ // This is a generalization of libFuzzer's __sanitizer_cov_trace_switch:
+ // https://github.com/llvm/llvm-project/blob/318942de229beb3b2587df09e776a50327b5cef0/compiler-rt/lib/fuzzer/FuzzerTracePC.cpp#L564
+ Object lowerBoundKey = null;
+ Object upperBoundKey = null;
+ try {
+ if (map instanceof TreeMap) {
+ final TreeMap treeMap = (TreeMap) map;
+ try {
+ lowerBoundKey = treeMap.floorKey(currentKey);
+ upperBoundKey = treeMap.ceilingKey(currentKey);
+ } catch (ClassCastException ignored) {
+ // Can be thrown by floorKey and ceilingKey if currentKey is of a type that can't be
+ // compared to the maps keys.
+ }
+ } else if (currentKey instanceof Comparable) {
+ final Comparable comparableCurrentKey = (Comparable) currentKey;
+ // Find two keys that bracket currentKey.
+ // Note: This is not deterministic if map.size() > MAX_NUM_KEYS_TO_ENUMERATE.
+ int enumeratedKeys = 0;
+ for (Object validKey : map.keySet()) {
+ if (!(validKey instanceof Comparable))
+ continue;
+ final Comparable comparableValidKey = (Comparable) validKey;
+ // If the key sorts lower than the non-existing key, but higher than the current lower
+ // bound, update the lower bound and vice versa for the upper bound.
+ try {
+ if (comparableValidKey.compareTo(comparableCurrentKey) < 0
+ && (lowerBoundKey == null || comparableValidKey.compareTo(lowerBoundKey) > 0)) {
+ lowerBoundKey = validKey;
+ }
+ if (comparableValidKey.compareTo(comparableCurrentKey) > 0
+ && (upperBoundKey == null || comparableValidKey.compareTo(upperBoundKey) < 0)) {
+ upperBoundKey = validKey;
+ }
+ } catch (ClassCastException ignored) {
+ // Can be thrown by floorKey and ceilingKey if currentKey is of a type that can't be
+ // compared to the maps keys.
+ }
+ if (enumeratedKeys++ > MAX_NUM_KEYS_TO_ENUMERATE)
+ break;
+ }
+ }
+ } catch (ConcurrentModificationException ignored) {
+ // map was modified by another thread, skip this invocation
+ return;
+ }
+ // Modify the hook ID so that compares against distinct valid keys are traced separately.
+ if (lowerBoundKey != null) {
+ TraceDataFlowNativeCallbacks.traceGenericCmp(
+ currentKey, lowerBoundKey, hookId + lowerBoundKey.hashCode());
+ }
+ if (upperBoundKey != null) {
+ TraceDataFlowNativeCallbacks.traceGenericCmp(
+ currentKey, upperBoundKey, hookId + upperBoundKey.hashCode());
+ }
+ }
+
+ @MethodHook(type = HookType.AFTER, targetClassName = "org.junit.jupiter.api.Assertions",
+ targetMethod = "assertNotEquals",
+ targetMethodDescriptor = "(Ljava/lang/Object;Ljava/lang/Object;)V")
+ @MethodHook(type = HookType.AFTER, targetClassName = "org.junit.jupiter.api.Assertions",
+ targetMethod = "assertNotEquals",
+ targetMethodDescriptor = "(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/String;)V")
+ @MethodHook(type = HookType.AFTER, targetClassName = "org.junit.jupiter.api.Assertions",
+ targetMethod = "assertNotEquals",
+ targetMethodDescriptor =
+ "(Ljava/lang/Object;Ljava/lang/Object;Ljava/util/function/Supplier;)V")
+ public static void
+ assertEquals(MethodHandle method, Object node, Object[] args, int hookId, Object alwaysNull) {
+ if (args[0] != null && args[1] != null && args[0].getClass() == args[1].getClass()) {
+ TraceDataFlowNativeCallbacks.traceGenericCmp(args[0], args[1], hookId);
+ }
+ }
+}