summaryrefslogtreecommitdiff
path: root/plugins/kotlin/idea/src/org/jetbrains/kotlin/idea/inspections/ReplaceAssertBooleanWithAssertEqualityInspection.kt
blob: dc20fcee2e459c8b8d183dc5eae1ffcaac12156d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
// Copyright 2000-2021 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.

package org.jetbrains.kotlin.idea.inspections

import com.intellij.codeInsight.actions.OptimizeImportsProcessor
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.project.Project
import org.jetbrains.kotlin.descriptors.CallableDescriptor
import org.jetbrains.kotlin.idea.KotlinBundle
import org.jetbrains.kotlin.idea.core.ShortenReferences
import org.jetbrains.kotlin.idea.core.replaced
import org.jetbrains.kotlin.idea.caches.resolve.safeAnalyzeNonSourceRootCode
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.calls.util.getResolvedCall
import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe
import org.jetbrains.kotlin.resolve.lazy.BodyResolveMode
import org.jetbrains.kotlin.types.KotlinType
import org.jetbrains.kotlin.types.typeUtil.isSubtypeOf

/**
 * Reports `kotlin.test.assertTrue(...)` / `kotlin.test.assertFalse(...)` calls whose condition is an `==` or `===`
 * comparison and offers to rewrite them as `assertEquals`, `assertNotEquals`, `assertSame` or `assertNotSame`.
 *
 * The left operand of the comparison becomes the first argument of the replacement and the right operand the second;
 * an optional message argument is carried over as the third argument.
 */
class ReplaceAssertBooleanWithAssertEqualityInspection : AbstractApplicabilityBasedInspection<KtCallExpression>(
    KtCallExpression::class.java
) {
    override fun inspectionText(element: KtCallExpression) =
        KotlinBundle.message("replace.assert.boolean.with.assert.equality")

    override val defaultFixText get() = KotlinBundle.message("replace.assert.boolean.with.assert.equality")

    /**
     * Names the concrete replacement via the `replace.with.0` message, parameterized with the fully-qualified
     * assertion; falls back to [defaultFixText] when no replacement can be determined.
     */
    override fun fixText(element: KtCallExpression): String {
        val assertion = element.replaceableAssertion() ?: return defaultFixText
        return KotlinBundle.message("replace.with.0", assertion)
    }

    override fun isApplicable(element: KtCallExpression): Boolean = element.replaceableAssertion() != null

    /**
     * Replaces `assertTrue(left OP right[, message])` / `assertFalse(...)` with the matching equality assertion
     * `assertion(left, right[, message])`, then shortens the fully-qualified name and optimizes the file's imports.
     */
    override fun applyTo(element: KtCallExpression, project: Project, editor: Editor?) {
        val valueArguments = element.valueArguments
        val condition = valueArguments.firstOrNull()?.getArgumentExpression() as? KtBinaryExpression ?: return
        val left = condition.left ?: return
        val right = condition.right ?: return
        val assertion = element.replaceableAssertion() ?: return

        // For a package-qualified call like `kotlin.test.assertTrue(a == b)`, `element` is only the selector of a
        // dot-qualified expression. The replacement built below is already fully qualified, so the whole qualified
        // expression must be replaced; replacing just the selector would yield `kotlin.test.kotlin.test.assertEquals`.
        val target: KtExpression = (element.parent as? KtDotQualifiedExpression)
            ?.takeIf { it.selectorExpression == element }
            ?: element

        // Captured before the PSI replacement, which invalidates `element`.
        val file = element.containingKtFile
        val factory = KtPsiFactory(project)
        val replacement = if (valueArguments.size == 2) {
            val message = valueArguments[1].getArgumentExpression() ?: return
            factory.createExpressionByPattern("$assertion($0, $1, $2)", left, right, message)
        } else {
            factory.createExpressionByPattern("$assertion($0, $1)", left, right)
        }
        val replaced = target.replaced(replacement)
        ShortenReferences.DEFAULT.process(replaced)
        OptimizeImportsProcessor(project, file).run()
    }

    /**
     * Returns the fully-qualified name of the equality assertion that can replace this call, or `null` if the call
     * is not replaceable.
     *
     * The call qualifies when all of the following hold:
     * - its simple name is one of [ASSERTIONS];
     * - it resolves to a function declared directly in `kotlin.test`;
     * - it has one or two arguments;
     * - its first argument is a binary expression whose operator is present in [ASSERTION_MAP];
     * - the operand types are related by subtyping in either direction.
     */
    private fun KtCallExpression.replaceableAssertion(): String? {
        // Cheap syntactic filter before triggering resolution.
        val referencedName = (calleeExpression as? KtNameReferenceExpression)?.getReferencedName() ?: return null
        if (referencedName !in ASSERTIONS) return null

        val context = safeAnalyzeNonSourceRootCode(BodyResolveMode.PARTIAL)
        if (descriptor(context)?.containingDeclaration?.fqNameSafe != KOTLIN_TEST_FQ_NAME) return null

        // Either `(condition)` or `(condition, message)`.
        if (valueArguments.size !in 1..2) return null
        val binaryExpression = valueArguments.first().getArgumentExpression() as? KtBinaryExpression ?: return null

        // Operand types come from each operand's resolved call. Operands without one (presumably literals such as
        // `1` or `null`) make the call non-applicable.
        val leftType = binaryExpression.left?.type(context) ?: return null
        val rightType = binaryExpression.right?.type(context) ?: return null
        if (!leftType.isSubtypeOf(rightType) && !rightType.isSubtypeOf(leftType)) return null

        return ASSERTION_MAP[referencedName to binaryExpression.operationToken]
    }

    /** The descriptor this expression's resolved call points to, or `null` if it has no resolved call. */
    private fun KtExpression.descriptor(context: BindingContext): CallableDescriptor? =
        getResolvedCall(context)?.resultingDescriptor

    /** The return type of this expression's resolved callable, or `null` if it has no resolved call. */
    private fun KtExpression.type(context: BindingContext): KotlinType? =
        descriptor(context)?.returnType

    companion object {
        private const val KOTLIN_TEST_PACKAGE = "kotlin.test"

        /** Hoisted so applicability checks do not allocate a new [FqName] each time. */
        private val KOTLIN_TEST_FQ_NAME = FqName(KOTLIN_TEST_PACKAGE)

        /** (boolean assertion, comparison operator) -> fully-qualified equality assertion that replaces it. */
        private val ASSERTION_MAP = mapOf(
            ("assertTrue" to KtTokens.EQEQ) to "$KOTLIN_TEST_PACKAGE.assertEquals",
            ("assertTrue" to KtTokens.EQEQEQ) to "$KOTLIN_TEST_PACKAGE.assertSame",
            ("assertFalse" to KtTokens.EQEQ) to "$KOTLIN_TEST_PACKAGE.assertNotEquals",
            ("assertFalse" to KtTokens.EQEQEQ) to "$KOTLIN_TEST_PACKAGE.assertNotSame",
        )

        /** Simple names of the boolean assertions this inspection inspects; derived from [ASSERTION_MAP] keys. */
        private val ASSERTIONS: Set<String> = ASSERTION_MAP.keys.map { it.first }.toSet()
    }
}