summaryrefslogtreecommitdiff
path: root/plugins/kotlin/idea/src/org/jetbrains/kotlin/idea/quickfix/createFromUsage/callableBuilder/typeUtils.kt
blob: 1075b875ec6cc2838888ceb0684a3a3fe7327b38 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
// Copyright 2000-2021 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.

package org.jetbrains.kotlin.idea.quickfix.createFromUsage.callableBuilder

import com.intellij.refactoring.psi.SearchUtils
import org.jetbrains.kotlin.builtins.isFunctionType
import org.jetbrains.kotlin.cfg.pseudocode.*
import org.jetbrains.kotlin.descriptors.*
import org.jetbrains.kotlin.idea.project.builtIns
import org.jetbrains.kotlin.idea.references.KtSimpleNameReference
import org.jetbrains.kotlin.idea.util.IdeDescriptorRenderers
import org.jetbrains.kotlin.idea.util.getDataFlowAwareTypes
import org.jetbrains.kotlin.idea.util.withoutRedundantAnnotations
import org.jetbrains.kotlin.incremental.components.NoLookupLocation
import org.jetbrains.kotlin.load.java.NULLABILITY_ANNOTATIONS
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.psi.psiUtil.getAssignmentByLHS
import org.jetbrains.kotlin.psi.psiUtil.getNonStrictParentOfType
import org.jetbrains.kotlin.psi.psiUtil.getStrictParentOfType
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.bindingContextUtil.isUsedAsStatement
import org.jetbrains.kotlin.resolve.constants.IntegerLiteralTypeConstructor
import org.jetbrains.kotlin.resolve.descriptorUtil.resolveTopLevelClass
import org.jetbrains.kotlin.resolve.scopes.HierarchicalScope
import org.jetbrains.kotlin.resolve.scopes.utils.findClassifier
import org.jetbrains.kotlin.types.*
import org.jetbrains.kotlin.types.checker.KotlinTypeChecker
import org.jetbrains.kotlin.types.typeUtil.makeNotNullable
import org.jetbrains.kotlin.types.typeUtil.supertypes
import java.util.*

/**
 * Returns `true` if [inner] is equal to this type or occurs, at any nesting depth, among its type arguments.
 */
internal operator fun KotlinType.contains(inner: KotlinType): Boolean =
    KotlinTypeChecker.DEFAULT.equalTypes(this, inner) || arguments.any { projection -> inner in projection.type }

/**
 * Returns `true` if this type's constructor is declared by [descriptor], or if [descriptor] occurs
 * (recursively) in any of this type's arguments.
 */
internal operator fun KotlinType.contains(descriptor: ClassifierDescriptor): Boolean {
    if (constructor.declarationDescriptor == descriptor) return true
    return arguments.any { projection -> descriptor in projection.type }
}

/**
 * Flattens an intersection type into its component types, expanding nested intersections recursively.
 * Any other type is returned as a single-element list containing itself.
 */
internal fun KotlinType.decomposeIntersection(): List<KotlinType> {
    val intersection = constructor as? IntersectionTypeConstructor ?: return listOf(this)
    return intersection.supertypes.flatMap { component -> component.decomposeIntersection() }
}

/**
 * Renders a single (non-intersection) type as source text, renaming type parameters on the way.
 *
 * Every type parameter that is a key of [typeParameterNameMap] is substituted by a look-alike type whose
 * constructor and declaration descriptor delegate to the original ones, except that the descriptor reports
 * the mapped name. The substituted type is then rendered.
 *
 * @param typeParameterNameMap new names for type parameters that may occur in this type
 * @param fq `true` to render with [IdeDescriptorRenderers.SOURCE_CODE] (nullability annotations excluded);
 *   `false` to render with [IdeDescriptorRenderers.SOURCE_CODE_SHORT_NAMES_NO_ANNOTATIONS]
 * @return the rendered type text
 */
private fun KotlinType.renderSingle(typeParameterNameMap: Map<TypeParameterDescriptor, String>, fq: Boolean): String {
    // Build a substitution from each renamed parameter's type constructor (keys are switched from descriptors
    // to constructors by `mapKeys` below) to a projection of the renamed look-alike type.
    val substitution = typeParameterNameMap.mapValues {
        val name = Name.identifier(it.value)

        val typeParameter = it.key

        // The wrapping constructor and the wrapping descriptor refer to each other. The constructor reads
        // `wrappingTypeParameter` only when getDeclarationDescriptor() is called, so it can be created first
        // and the descriptor assigned right after.
        var wrappingTypeParameter: TypeParameterDescriptor? = null
        val wrappingTypeConstructor = object : TypeConstructor by typeParameter.typeConstructor {
            override fun getDeclarationDescriptor() = wrappingTypeParameter
        }

        // Everything is delegated to the original parameter except the name and the type constructor.
        wrappingTypeParameter = object : TypeParameterDescriptor by typeParameter {
            override fun getName() = name
            override fun getTypeConstructor() = wrappingTypeConstructor
        }

        // Copy annotations, arguments, nullability and member scope of the parameter's default type;
        // only the constructor is replaced by the wrapping one.
        val defaultType = typeParameter.defaultType
        val wrappingType = KotlinTypeFactory.simpleTypeWithNonTrivialMemberScope(
            defaultType.annotations,
            wrappingTypeConstructor,
            defaultType.arguments,
            defaultType.isMarkedNullable,
            defaultType.memberScope
        )
        TypeProjectionImpl(wrappingType)
    }
        .mapKeys { it.key.typeConstructor }

    // NOTE(review): `!!` assumes the substitutor never returns null for these rename-only substitutions — confirm.
    val typeToRender = TypeSubstitutor.create(substitution).substitute(this, Variance.INVARIANT)!!
    // Fully-qualified rendering excludes only nullability annotations; the short-name renderer is the
    // "no annotations" variant.
    val renderer =
        if (fq) IdeDescriptorRenderers.SOURCE_CODE.withOptions {
            excludedTypeAnnotationClasses = NULLABILITY_ANNOTATIONS
        }
        else IdeDescriptorRenderers.SOURCE_CODE_SHORT_NAMES_NO_ANNOTATIONS
    return renderer.renderType(typeToRender)
}

/**
 * Renders this type with type parameters renamed per [typeParameterNameMap].
 * An intersection type yields one string per component; any other type yields a single string.
 */
private fun KotlinType.render(typeParameterNameMap: Map<TypeParameterDescriptor, String>, fq: Boolean): List<String> =
    decomposeIntersection().map { component -> component.renderSingle(typeParameterNameMap, fq) }

/** Renders this type (one string per intersection component) using short class names and no annotations. */
internal fun KotlinType.renderShort(typeParameterNameMap: Map<TypeParameterDescriptor, String>) =
    render(typeParameterNameMap, fq = false)

/** Renders this type (one string per intersection component) using fully-qualified class names. */
internal fun KotlinType.renderLong(typeParameterNameMap: Map<TypeParameterDescriptor, String>) =
    render(typeParameterNameMap, fq = true)

/**
 * Returns the elements of [typeParameters] that are not visible under their own name in [scope]: the name
 * either does not resolve to a classifier there, or resolves to a different one.
 */
internal fun getTypeParameterNamesNotInScope(
    typeParameters: Collection<TypeParameterDescriptor>,
    scope: HierarchicalScope
): List<TypeParameterDescriptor> = typeParameters.filterNot { candidate ->
    // a null lookup result never equals the (non-null) candidate, so unresolved names are kept as well
    scope.findClassifier(candidate.name, NoLookupLocation.FROM_IDE) == candidate
}

/** Returns `true` if a star projection occurs among this type's arguments at any nesting depth. */
fun KotlinType.containsStarProjections(): Boolean {
    for (projection in arguments) {
        if (projection.isStarProjection || projection.type.containsStarProjections()) return true
    }
    return false
}

/**
 * Collects the type parameters referenced by this type, in first-encounter order.
 *
 * Type arguments are walked depth-first and each distinct type is visited only once. A type parameter is
 * recorded when it appears as a type without arguments.
 */
fun KotlinType.getTypeParameters(): Set<TypeParameterDescriptor> {
    val visited = HashSet<KotlinType>()
    val collected = LinkedHashSet<TypeParameterDescriptor>()

    fun visit(current: KotlinType) {
        if (!visited.add(current)) return

        val typeArguments = current.arguments
        if (typeArguments.isNotEmpty()) {
            for (argument in typeArguments) visit(argument.type)
        } else {
            (current.constructor.declarationDescriptor as? TypeParameterDescriptor)?.let { collected.add(it) }
        }
    }

    visit(this)
    return collected
}

/**
 * Guesses the possible types of this expression.
 *
 * Sources are tried in this order:
 * 1. `Unit`, if [coerceUnusedToUnit] is set, this is not a declaration, it is used as a statement and it is not
 *    inside an annotation entry;
 * 2. the guess for the enclosing `when` expression, if this is the body of a `when` entry;
 * 3. the type recorded in [context] (integer literal types approximated, then passed through
 *    `getDataFlowAwareTypes`), unless this is a `when` expression;
 * 4. the expected type recorded in [context];
 * 5. rules based on the syntactic position: type constraints, destructuring entries, parameters, local variable
 *    initializers, property delegates, string template entries, lambda/function bodies;
 * 6. the expected-type predicate computed from [pseudocode], if one is given.
 *
 * @param context binding context used to look up types and descriptors
 * @param module module used for builtins and to resolve `kotlin.properties` delegate interfaces
 * @param pseudocode pseudocode of the enclosing declaration, if available
 * @param coerceUnusedToUnit whether an expression used as a statement is typed as `Unit`
 * @param allowErrorTypes whether candidate types containing error types are acceptable; propagated to every
 *   recursive guess
 * @return the candidate types; empty when nothing could be inferred
 */
fun KtExpression.guessTypes(
    context: BindingContext,
    module: ModuleDescriptor,
    pseudocode: Pseudocode? = null,
    coerceUnusedToUnit: Boolean = true,
    allowErrorTypes: Boolean = false
): Array<KotlinType> {
    fun isAcceptable(type: KotlinType) = allowErrorTypes || !ErrorUtils.containsErrorType(type)

    // Type recorded for an explicitly written type reference. Empty instead of a NullPointerException when the
    // reference is absent or no type was recorded for it.
    fun explicitTypeOf(typeReference: KtTypeReference?): Array<KotlinType> =
        typeReference?.let { context[BindingContext.TYPE, it] }?.let { arrayOf(it) } ?: arrayOf()

    if (coerceUnusedToUnit
        && this !is KtDeclaration
        && isUsedAsStatement(context)
        && getNonStrictParentOfType<KtAnnotationEntry>() == null
    ) return arrayOf(module.builtIns.unitType)

    val parent = parent

    // Type/Expected type may be wrong for the expression of KtWhenEntry when some branches have unresolved expressions
    if (parent is KtWhenEntry && parent.expression == this) {
        return parent
            .getStrictParentOfType<KtWhenExpression>()
            ?.guessTypes(context, module, pseudocode, coerceUnusedToUnit, allowErrorTypes) ?: arrayOf()
    }

    if (this !is KtWhenExpression) {
        // if we know the actual type of the expression
        val theType1 = context.getType(this)?.let {
            val constructor = it.constructor
            if (constructor is IntegerLiteralTypeConstructor) {
                // integer literal types are replaced by their approximated type
                constructor.getApproximatedType()
            } else {
                it
            }
        }
        if (theType1 != null && isAcceptable(theType1)) {
            return getDataFlowAwareTypes(this, context, theType1).toTypedArray()
        }
    }

    // expression has an expected type
    val theType2 = context[BindingContext.EXPECTED_EXPRESSION_TYPE, this]
    if (theType2 != null && isAcceptable(theType2)) return arrayOf(theType2.withoutRedundantAnnotations())

    return when {
        this is KtTypeConstraint -> {
            // expression itself is a type assertion
            val constraint = this
            explicitTypeOf(constraint.boundTypeReference)
        }
        parent is KtTypeConstraint -> {
            // expression is on the left side of a type assertion
            explicitTypeOf(parent.boundTypeReference)
        }
        this is KtDestructuringDeclarationEntry -> {
            // expression is on the lhs of a multi-declaration: use the specified type, otherwise guess
            val typeRef = typeReference
            if (typeRef != null) explicitTypeOf(typeRef) else guessType(context)
        }
        this is KtParameter -> {
            // expression is a parameter (e.g. declared in a for-loop): use the specified type, otherwise guess
            val typeRef = typeReference
            if (typeRef != null) explicitTypeOf(typeRef) else guessType(context)
        }
        parent is KtProperty && parent.isLocal -> {
            // the expression is the RHS of a local variable: use its specified type, otherwise guess from the LHS
            val typeRef = parent.typeReference
            if (typeRef != null) explicitTypeOf(typeRef) else parent.guessType(context)
        }
        parent is KtPropertyDelegate -> {
            val property = parent.parent as? KtProperty ?: return arrayOf()
            val variableDescriptor = context[BindingContext.DECLARATION_TO_DESCRIPTOR, property] as? VariableDescriptor
                ?: return arrayOf()
            val delegateClassName = if (variableDescriptor.isVar) "ReadWriteProperty" else "ReadOnlyProperty"
            val delegateClass = module.resolveTopLevelClass(FqName("kotlin.properties.$delegateClassName"), NoLookupLocation.FROM_IDE)
                ?: return arrayOf(module.builtIns.anyType)
            // the delegate's first type argument is the receiver (Nothing? when there is none), the second the property type
            val receiverType = (variableDescriptor.extensionReceiverParameter ?: variableDescriptor.dispatchReceiverParameter)?.type
                ?: module.builtIns.nullableNothingType
            val typeArguments = listOf(TypeProjectionImpl(receiverType), TypeProjectionImpl(variableDescriptor.type))
            arrayOf(TypeUtils.substituteProjectionsForParameters(delegateClass, typeArguments))
        }
        parent is KtStringTemplateEntryWithExpression && parent.expression == this -> {
            arrayOf(module.builtIns.stringType)
        }
        parent is KtBlockExpression && parent.statements.lastOrNull() == this && parent.parent is KtFunctionLiteral -> {
            // last statement of a lambda body: its type is the type guessed for the whole body
            parent.guessTypes(context, module, pseudocode, coerceUnusedToUnit, allowErrorTypes)
        }
        parent is KtFunction -> {
            val functionDescriptor = context[BindingContext.DECLARATION_TO_DESCRIPTOR, parent] as? FunctionDescriptor ?: return arrayOf()
            val returnType = functionDescriptor.returnType
            if (returnType != null && isAcceptable(returnType)) return arrayOf(returnType)
            val functionalExpression: KtExpression? = when {
                parent is KtFunctionLiteral -> parent.parent as? KtLambdaExpression
                parent is KtNamedFunction && parent.name == null -> parent
                else -> null
            }
            if (functionalExpression == null) {
                // fall back to the first acceptable return type among the overridden declarations
                functionDescriptor.overriddenDescriptors
                    .mapNotNull { it.returnType }
                    .firstOrNull { isAcceptable(it) }
                    ?.let { return arrayOf(it) }
                return arrayOf()
            }
            // lambdas / anonymous functions: take the last type argument (the return type) of the guessed function types
            val lambdaTypes = functionalExpression.guessTypes(context, module, pseudocode?.parent, coerceUnusedToUnit, allowErrorTypes)
            lambdaTypes.mapNotNull { it.getFunctionType()?.arguments?.lastOrNull()?.type }.toTypedArray()
        }
        else -> {
            pseudocode?.getElementValue(this)?.let {
                getExpectedTypePredicate(it, context, module.builtIns).getRepresentativeTypes().toTypedArray()
            } ?: arrayOf() // can't infer anything
        }
    }
}

/** Returns this type if it is a function type, otherwise its first function-typed supertype, or `null` if none. */
private fun KotlinType.getFunctionType(): KotlinType? {
    if (isFunctionType) return this
    return supertypes().firstOrNull { supertype -> supertype.isFunctionType }
}

/**
 * Guesses the type of this declaration from the expected types recorded at its usages.
 *
 * All references in the declaration's use scope are searched. For each Kotlin simple-name reference, the
 * expected type of the referencing expression is collected from [context]. Nothing is guessed if no expected
 * type was found, or if any of them contains an error type. Otherwise the intersection of the collected types
 * is returned. If no intersection exists, every collected type is returned so the user can pick one.
 *
 * @return the guessed types; empty when nothing could be inferred
 */
private fun KtNamedDeclaration.guessType(context: BindingContext): Array<KotlinType> {
    val references = SearchUtils.findAllReferences(this, useScope) ?: return arrayOf()
    // LinkedHashSet keeps candidates in the order their references were found. The choice list offered below
    // is then deterministic instead of depending on the hash codes of the types.
    val expectedTypes = references.mapNotNullTo(LinkedHashSet<KotlinType>()) { ref ->
        if (ref is KtSimpleNameReference) context[BindingContext.EXPECTED_EXPRESSION_TYPE, ref.expression] else null
    }

    if (expectedTypes.isEmpty() || expectedTypes.any { expectedType -> ErrorUtils.containsErrorType(expectedType) }) {
        return arrayOf()
    }

    val theType = TypeIntersector.intersectTypes(expectedTypes)
    return if (theType != null) {
        arrayOf(theType.withoutRedundantAnnotations())
    } else {
        // intersection doesn't exist; let user make an imperfect choice
        expectedTypes.map { it.withoutRedundantAnnotations() }.toTypedArray()
    }
}

/**
 * A single substitution rule for [substitute]: a type matching [forType] is replaced by [byType].
 *
 * @property forType the type to look for (how it is matched depends on the variance passed to [substitute])
 * @property byType the replacement type
 */
internal class KotlinTypeSubstitution(
    val forType: KotlinType,
    val byType: KotlinType
)

/**
 * Replaces types matching [substitution]'s `forType` by its `byType`.
 *
 * Matching ignores this type's nullability and depends on [variance]:
 * - `INVARIANT`: equal types;
 * - `IN_VARIANCE`: this type is a subtype of `forType`;
 * - `OUT_VARIANCE`: `forType` is a subtype of this type.
 *
 * On a match, the replacement receives this type's nullability. Otherwise every type argument is processed
 * recursively, using the declaration-site variance of its type parameter. The rebuilt arguments are always
 * invariant projections.
 */
internal fun KotlinType.substitute(substitution: KotlinTypeSubstitution, variance: Variance): KotlinType {
    val wasNullable = isMarkedNullable
    val notNullType = makeNotNullable()
    val target = substitution.forType

    val matches = when (variance) {
        Variance.INVARIANT -> KotlinTypeChecker.DEFAULT.equalTypes(notNullType, target)
        Variance.IN_VARIANCE -> KotlinTypeChecker.DEFAULT.isSubtypeOf(notNullType, target)
        Variance.OUT_VARIANCE -> KotlinTypeChecker.DEFAULT.isSubtypeOf(target, notNullType)
    }
    if (matches) return TypeUtils.makeNullableAsSpecified(substitution.byType, wasNullable)

    val substitutedArguments = arguments.zip(constructor.parameters).map { (projection, typeParameter) ->
        TypeProjectionImpl(Variance.INVARIANT, projection.type.substitute(substitution, typeParameter.variance))
    }
    return KotlinTypeFactory.simpleTypeWithNonTrivialMemberScope(
        annotations,
        constructor,
        substitutedArguments,
        isMarkedNullable,
        memberScope
    )
}

/**
 * Returns the expression whose type should be guessed for this one. That is the right-hand side of the
 * assignment whose left-hand side is this expression, if such an assignment with a right-hand side exists;
 * otherwise it is this expression itself.
 */
fun KtExpression.getExpressionForTypeGuess(): KtExpression {
    val assignment = getAssignmentByLHS() ?: return this
    return assignment.right ?: this
}

/**
 * Creates an invariant [TypeInfo] for each type argument of this call that has a type reference.
 * Type arguments without one are skipped.
 */
fun KtCallElement.getTypeInfoForTypeArguments(): List<TypeInfo> =
    typeArguments.mapNotNull { projection ->
        projection.typeReference?.let { typeReference -> TypeInfo(typeReference, Variance.INVARIANT) }
    }

/**
 * Builds a [ParameterInfo] for each value argument of this call.
 *
 * The type info is taken from the argument expression with `IN_VARIANCE`, or is `Any?` when the argument has
 * no expression. The argument's name is passed along when the argument is named, and `null` otherwise.
 */
fun KtCallExpression.getParameterInfos(): List<ParameterInfo> {
    val fallbackType = builtIns.nullableAnyType
    return valueArguments.map { argument ->
        val expression = argument.getArgumentExpression()
        val typeInfo = if (expression != null) {
            TypeInfo(expression, Variance.IN_VARIANCE)
        } else {
            TypeInfo(fallbackType, Variance.IN_VARIANCE)
        }
        val argumentName = argument.getArgumentName()?.referenceExpression?.getReferencedName()
        ParameterInfo(typeInfo, argumentName)
    }
}

/**
 * Returns a representative set of types for this predicate:
 * - [SingleType]: its target type;
 * - [AllSubtypes]: its upper bound;
 * - [ForAllTypes]: the intersection of its members' representatives (treated like [AllTypes] when it has none);
 * - [ForSomeType]: the union of its members' representatives, in member order;
 * - [AllTypes]: no types.
 *
 * @throws AssertionError for any other predicate implementation
 */
private fun TypePredicate.getRepresentativeTypes(): Set<KotlinType> = when (this) {
    is SingleType -> setOf(targetType)
    is AllSubtypes -> setOf(upperBound)
    is ForAllTypes ->
        if (typeSets.isEmpty()) {
            AllTypes.getRepresentativeTypes()
        } else {
            typeSets
                .map { member -> member.getRepresentativeTypes() }
                .reduce { common, next -> common.intersect(next) }
        }
    is ForSomeType -> {
        val union = LinkedHashSet<KotlinType>()
        for (member in typeSets) union.addAll(member.getRepresentativeTypes())
        union
    }
    is AllTypes -> emptySet()
    else -> throw AssertionError("Invalid type predicate: $this")
}