1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
|
// Copyright 2000-2021 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package org.jetbrains.kotlin.idea.inspections
import com.intellij.codeInsight.actions.OptimizeImportsProcessor
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.project.Project
import org.jetbrains.kotlin.descriptors.CallableDescriptor
import org.jetbrains.kotlin.idea.KotlinBundle
import org.jetbrains.kotlin.idea.core.ShortenReferences
import org.jetbrains.kotlin.idea.core.replaced
import org.jetbrains.kotlin.idea.caches.resolve.safeAnalyzeNonSourceRootCode
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.calls.util.getResolvedCall
import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe
import org.jetbrains.kotlin.resolve.lazy.BodyResolveMode
import org.jetbrains.kotlin.types.KotlinType
import org.jetbrains.kotlin.types.typeUtil.isSubtypeOf
/**
 * Reports `kotlin.test.assertTrue(a == b)` / `assertFalse(a == b)` calls (and the `===` variants) and offers to
 * replace them with the matching equality assertion: `assertEquals`, `assertNotEquals`, `assertSame` or
 * `assertNotSame`.
 *
 * An optional second argument (the assertion message) is carried over as the third argument of the replacement.
 */
class ReplaceAssertBooleanWithAssertEqualityInspection : AbstractApplicabilityBasedInspection<KtCallExpression>(
    KtCallExpression::class.java
) {
    override fun inspectionText(element: KtCallExpression) = KotlinBundle.message("replace.assert.boolean.with.assert.equality")

    override val defaultFixText get() = KotlinBundle.message("replace.assert.boolean.with.assert.equality")

    /** Names the concrete replacement (e.g. `kotlin.test.assertEquals`) when it can be determined. */
    override fun fixText(element: KtCallExpression): String {
        val assertion = element.replaceableAssertion() ?: return defaultFixText
        return KotlinBundle.message("replace.with.0", assertion)
    }

    override fun isApplicable(element: KtCallExpression): Boolean = element.replaceableAssertion() != null

    /**
     * Rewrites `assertX(left OP right[, message])` into `replacement(left, right[, message])`.
     *
     * Afterwards it shortens the fully-qualified replacement name and runs import optimization on the file. The
     * original `assertTrue`/`assertFalse` import may have become unused.
     */
    override fun applyTo(element: KtCallExpression, project: Project, editor: Editor?) {
        val valueArguments = element.valueArguments
        val condition = valueArguments.firstOrNull()?.getArgumentExpression() as? KtBinaryExpression ?: return
        val left = condition.left ?: return
        val right = condition.right ?: return
        val assertion = element.replaceableAssertion() ?: return

        val file = element.containingKtFile
        val factory = KtPsiFactory(project)
        val replaced = if (valueArguments.size == 2) {
            val message = valueArguments[1].getArgumentExpression() ?: return
            element.replaced(factory.createExpressionByPattern("$assertion($0, $1, $2)", left, right, message))
        } else {
            element.replaced(factory.createExpressionByPattern("$assertion($0, $1)", left, right))
        }
        ShortenReferences.DEFAULT.process(replaced)
        OptimizeImportsProcessor(project, file).run()
    }

    /**
     * Returns the fully-qualified name of the equality assertion that can replace this call.
     *
     * Returns `null` when the call is not a replaceable `kotlin.test.assertTrue`/`assertFalse` invocation.
     *
     * Pure PSI checks run first, so the resolve below only happens for syntactically plausible candidates.
     */
    private fun KtCallExpression.replaceableAssertion(): String? {
        val referencedName = (calleeExpression as? KtNameReferenceExpression)?.getReferencedName() ?: return null
        if (referencedName !in assertions) return null

        // Only `assertX(condition)` or `assertX(condition, message)` shapes are handled.
        if (valueArguments.size != 1 && valueArguments.size != 2) return null
        val binaryExpression = valueArguments.first().getArgumentExpression() as? KtBinaryExpression ?: return null
        val assertion = assertionMap[Pair(referencedName, binaryExpression.operationToken)] ?: return null
        val left = binaryExpression.left ?: return null
        val right = binaryExpression.right ?: return null

        val context = safeAnalyzeNonSourceRootCode(BodyResolveMode.PARTIAL)
        // The callee must be declared directly in the kotlin.test package.
        if (descriptor(context)?.containingDeclaration?.fqNameSafe != kotlinTestFqName) return null

        val leftType = left.type(context) ?: return null
        val rightType = right.type(context) ?: return null
        // Skip comparisons where neither operand type is a subtype of the other.
        if (!leftType.isSubtypeOf(rightType) && !rightType.isSubtypeOf(leftType)) return null

        return assertion
    }

    private fun KtExpression.descriptor(context: BindingContext): CallableDescriptor? =
        getResolvedCall(context)?.resultingDescriptor

    /**
     * Returns the type recorded for this expression by the analysis.
     *
     * Deriving the type from a resolved call's return type yields `null` for operands that are not calls, such as
     * literals and string templates (`assertTrue(x == 1)`). Reading the recorded expression type covers them too.
     */
    private fun KtExpression.type(context: BindingContext): KotlinType? = context.getType(this)

    companion object {
        private const val kotlinTestPackage = "kotlin.test"

        /** Hoisted so the package comparison does not allocate a new [FqName] per inspected call. */
        private val kotlinTestFqName = FqName(kotlinTestPackage)

        /** Maps (original assertion name, comparison operator) to the fully-qualified replacement assertion. */
        private val assertionMap = mapOf(
            Pair("assertTrue", KtTokens.EQEQ) to "$kotlinTestPackage.assertEquals",
            Pair("assertTrue", KtTokens.EQEQEQ) to "$kotlinTestPackage.assertSame",
            Pair("assertFalse", KtTokens.EQEQ) to "$kotlinTestPackage.assertNotEquals",
            Pair("assertFalse", KtTokens.EQEQEQ) to "$kotlinTestPackage.assertNotSame"
        )

        /**
         * Callee names the inspection looks at, derived from [assertionMap] so the two cannot drift apart.
         * Declared after [assertionMap] because companion properties are initialized in declaration order.
         */
        private val assertions: Set<String> = assertionMap.keys.map { it.first }.toSet()
    }
}
|