diff options
Diffstat (limited to 'codegen/vulkan/scripts/cereal/testing.py')
-rw-r--r-- | codegen/vulkan/scripts/cereal/testing.py | 399 |
1 files changed, 399 insertions, 0 deletions
diff --git a/codegen/vulkan/scripts/cereal/testing.py b/codegen/vulkan/scripts/cereal/testing.py new file mode 100644 index 00000000..898532fc --- /dev/null +++ b/codegen/vulkan/scripts/cereal/testing.py @@ -0,0 +1,399 @@ +# Copyright (c) 2018 The Android Open Source Project +# Copyright (c) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import copy + +from .common.codegen import CodeGen +from .common.vulkantypes import \ + VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, VulkanTypeIterator + +from .wrapperdefs import VulkanWrapperGenerator +from .wrapperdefs import EQUALITY_VAR_NAMES +from .wrapperdefs import EQUALITY_ON_FAIL_VAR +from .wrapperdefs import EQUALITY_ON_FAIL_VAR_TYPE +from .wrapperdefs import EQUALITY_RET_TYPE +from .wrapperdefs import API_PREFIX_EQUALITY +from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM2 + +class VulkanEqualityCodegen(VulkanTypeIterator): + + def __init__(self, cgen, inputVars, onFailCompareVar, prefix): + self.cgen = cgen + self.inputVars = inputVars + self.onFailCompareVar = onFailCompareVar + self.prefix = prefix + + def makeAccess(varName, asPtr = True): + return lambda t: self.cgen.generalAccess(t, parentVarName = varName, asPtr = asPtr) + + def makeLengthAccess(varName): + return lambda t: self.cgen.generalLengthAccess(t, parentVarName = varName) + + def makeLengthAccessGuard(varName): + return lambda t: self.cgen.generalLengthAccessGuard(t, parentVarName=varName) + + self.exprAccessorLhs = makeAccess(self.inputVars[0]) + self.exprAccessorRhs = makeAccess(self.inputVars[1]) + + self.exprAccessorValueLhs = makeAccess(self.inputVars[0], asPtr = False) + self.exprAccessorValueRhs = makeAccess(self.inputVars[1], asPtr = False) + + self.lenAccessorLhs = makeLengthAccess(self.inputVars[0]) + self.lenAccessorRhs = makeLengthAccess(self.inputVars[1]) + + self.lenAccessGuardLhs = makeLengthAccessGuard(self.inputVars[0]) + self.lenAccessGuardRhs = makeLengthAccessGuard(self.inputVars[1]) + + self.checked = False + + def getTypeForCompare(self, vulkanType): + res = copy(vulkanType) + + if not vulkanType.accessibleAsPointer(): + res = res.getForAddressAccess() + + if vulkanType.staticArrExpr: + res = res.getForAddressAccess() + + return res + + def makeCastExpr(self, vulkanType): + return "(%s)" % ( + self.cgen.makeCTypeDecl(vulkanType, useParamName=False)) + + def makeEqualExpr(self, lhs, rhs): + return "(%s) == (%s)" % (lhs, rhs) + + def makeEqualBufExpr(self, lhs, rhs, size): + return "(memcmp(%s, %s, %s) == 0)" % (lhs, rhs, size) + + def makeEqualStringExpr(self, lhs, rhs): + return "(strcmp(%s, %s) == 0)" % (lhs, rhs) + + def makeBothNotNullExpr(self, lhs, rhs): + return "(%s) && (%s)" % (lhs, rhs) + + def makeBothNullExpr(self, lhs, rhs): + return "!(%s) && !(%s)" % (lhs, rhs) + + def compareWithConsequence(self, compareExpr, vulkanType, errMsg=""): + self.cgen.stmt("if (!(%s)) { %s(\"%s (Error: %s)\"); }" % + (compareExpr, self.onFailCompareVar, + self.exprAccessorValueLhs(vulkanType), errMsg)) + + def onCheck(self, vulkanType): + + self.checked = True + + accessLhs = self.exprAccessorLhs(vulkanType) + accessRhs = self.exprAccessorRhs(vulkanType) + + bothNull = self.makeBothNullExpr(accessLhs, accessRhs) + bothNotNull = self.makeBothNotNullExpr(accessLhs, accessRhs) + nullMatchExpr = "(%s) || (%s)" % (bothNull, bothNotNull) + + self.compareWithConsequence( \ + nullMatchExpr, + vulkanType, + "Mismatch in optional field") + + skipStreamInternal = vulkanType.typeName == "void" + + if skipStreamInternal: + return + + self.cgen.beginIf("%s && %s" % (accessLhs, accessRhs)) + + def endCheck(self, vulkanType): + + skipStreamInternal = vulkanType.typeName == "void" + if skipStreamInternal: + return + + if self.checked: + self.cgen.endIf() + self.checked = False + + def onCompoundType(self, vulkanType): + accessLhs = self.exprAccessorLhs(vulkanType) + accessRhs = self.exprAccessorRhs(vulkanType) + + lenAccessLhs = self.lenAccessorLhs(vulkanType) + lenAccessRhs = self.lenAccessorRhs(vulkanType) + + lenAccessGuardLhs = self.lenAccessGuardLhs(vulkanType) + lenAccessGuardRhs = self.lenAccessGuardRhs(vulkanType) + + needNullCheck = vulkanType.pointerIndirectionLevels > 0 + + if needNullCheck: + bothNotNullExpr = self.makeBothNotNullExpr(accessLhs, accessRhs) + self.cgen.beginIf(bothNotNullExpr) + + if lenAccessLhs is not None: + equalLenExpr = self.makeEqualExpr(lenAccessLhs, lenAccessRhs) + + self.compareWithConsequence( \ + equalLenExpr, + vulkanType, "Lengths not equal") + + loopVar = "i" + accessLhs = "%s + %s" % (accessLhs, loopVar) + accessRhs = "%s + %s" % (accessRhs, loopVar) + forInit = "uint32_t %s = 0" % loopVar + forCond = "%s < (uint32_t)%s" % (loopVar, lenAccessLhs) + forIncr = "++%s" % loopVar + + if needNullCheck: + self.cgen.beginIf(equalLenExpr) + + if lenAccessGuardLhs is not None: + self.cgen.beginIf(lenAccessGuardLhs) + + self.cgen.beginFor(forInit, forCond, forIncr) + + self.cgen.funcCall(None, self.prefix + vulkanType.typeName, + [accessLhs, accessRhs, self.onFailCompareVar]) + + if lenAccessLhs is not None: + self.cgen.endFor() + if lenAccessGuardLhs is not None: + self.cgen.endIf() + if needNullCheck: + self.cgen.endIf() + + if needNullCheck: + self.cgen.endIf() + + def onString(self, vulkanType): + accessLhs = self.exprAccessorLhs(vulkanType) + accessRhs = self.exprAccessorRhs(vulkanType) + + bothNullExpr = self.makeBothNullExpr(accessLhs, accessRhs) + bothNotNullExpr = self.makeBothNotNullExpr(accessLhs, accessRhs) + nullMatchExpr = "(%s) || (%s)" % (bothNullExpr, bothNotNullExpr) + + self.compareWithConsequence( \ + nullMatchExpr, + vulkanType, + "Mismatch in string pointer nullness") + + self.cgen.beginIf(bothNotNullExpr) + + self.compareWithConsequence( + self.makeEqualStringExpr(accessLhs, accessRhs), + vulkanType, "Unequal strings") + + self.cgen.endIf() + + def onStringArray(self, vulkanType): + accessLhs = self.exprAccessorLhs(vulkanType) + accessRhs = self.exprAccessorRhs(vulkanType) + + lenAccessLhs = self.lenAccessorLhs(vulkanType) + lenAccessRhs = self.lenAccessorRhs(vulkanType) + + lenAccessGuardLhs = self.lenAccessGuardLhs(vulkanType) + lenAccessGuardRhs = self.lenAccessGuardRhs(vulkanType) + + bothNullExpr = self.makeBothNullExpr(accessLhs, accessRhs) + bothNotNullExpr = self.makeBothNotNullExpr(accessLhs, accessRhs) + nullMatchExpr = "(%s) || (%s)" % (bothNullExpr, bothNotNullExpr) + + self.compareWithConsequence( \ + nullMatchExpr, + vulkanType, + "Mismatch in string array pointer nullness") + + equalLenExpr = self.makeEqualExpr(lenAccessLhs, lenAccessRhs) + + self.compareWithConsequence( \ + equalLenExpr, + vulkanType, "Lengths not equal in string array") + + self.compareWithConsequence( \ + equalLenExpr, + vulkanType, "Lengths not equal in string array") + + self.cgen.beginIf("%s && %s" % (equalLenExpr, bothNotNullExpr)) + + loopVar = "i" + accessLhs = "*(%s + %s)" % (accessLhs, loopVar) + accessRhs = "*(%s + %s)" % (accessRhs, loopVar) + forInit = "uint32_t %s = 0" % loopVar + forCond = "%s < (uint32_t)%s" % (loopVar, lenAccessLhs) + forIncr = "++%s" % loopVar + + if lenAccessGuardLhs is not None: + self.cgen.beginIf(lenAccessGuardLhs) + + self.cgen.beginFor(forInit, forCond, forIncr) + + self.compareWithConsequence( + self.makeEqualStringExpr(accessLhs, accessRhs), + vulkanType, "Unequal string in string array") + + self.cgen.endFor() + + if lenAccessGuardLhs is not None: + self.cgen.endIf() + + self.cgen.endIf() + + def onStaticArr(self, vulkanType): + accessLhs = self.exprAccessorLhs(vulkanType) + accessRhs = self.exprAccessorRhs(vulkanType) + + lenAccessLhs = self.lenAccessorLhs(vulkanType) + + finalLenExpr = "%s * %s" % (lenAccessLhs, + self.cgen.sizeofExpr(vulkanType)) + + self.compareWithConsequence( + self.makeEqualBufExpr(accessLhs, accessRhs, finalLenExpr), + vulkanType, "Unequal static array") + + def onStructExtension(self, vulkanType): + lhs = self.exprAccessorLhs(vulkanType) + rhs = self.exprAccessorRhs(vulkanType) + + self.cgen.beginIf(lhs) + self.cgen.funcCall(None, self.prefix + "extension_struct", + [lhs, rhs, self.onFailCompareVar]) + self.cgen.endIf() + + def onPointer(self, vulkanType): + accessLhs = self.exprAccessorLhs(vulkanType) + accessRhs = self.exprAccessorRhs(vulkanType) + + skipStreamInternal = vulkanType.typeName == "void" + if skipStreamInternal: + return + + lenAccessLhs = self.lenAccessorLhs(vulkanType) + lenAccessRhs = self.lenAccessorRhs(vulkanType) + + if lenAccessLhs is not None: + self.compareWithConsequence( \ + self.makeEqualExpr(lenAccessLhs, lenAccessRhs), + vulkanType, "Lengths not equal") + + finalLenExpr = "%s * %s" % (lenAccessLhs, + self.cgen.sizeofExpr( + vulkanType.getForValueAccess())) + else: + finalLenExpr = self.cgen.sizeofExpr(vulkanType.getForValueAccess()) + + self.compareWithConsequence( + self.makeEqualBufExpr(accessLhs, accessRhs, finalLenExpr), + vulkanType, "Unequal dyn array") + + def onValue(self, vulkanType): + accessLhs = self.exprAccessorValueLhs(vulkanType) + accessRhs = self.exprAccessorValueRhs(vulkanType) + self.compareWithConsequence( + self.makeEqualExpr(accessLhs, accessRhs), vulkanType, + "Value not equal") + + +class VulkanTesting(VulkanWrapperGenerator): + + def __init__(self, module, typeInfo): + VulkanWrapperGenerator.__init__(self, module, typeInfo) + + self.codegen = CodeGen() + + self.equalityCodegen = \ + VulkanEqualityCodegen( + None, + EQUALITY_VAR_NAMES, + EQUALITY_ON_FAIL_VAR, + API_PREFIX_EQUALITY) + + self.knownDefs = {} + + self.extensionTestingPrototype = \ + VulkanAPI(API_PREFIX_EQUALITY + "extension_struct", + EQUALITY_RET_TYPE, + [STRUCT_EXTENSION_PARAM, + STRUCT_EXTENSION_PARAM2, + EQUALITY_ON_FAIL_VAR_TYPE]) + + def onBegin(self,): + VulkanWrapperGenerator.onBegin(self) + self.module.appendImpl(self.codegen.makeFuncDecl( + self.extensionTestingPrototype)) + + def onGenType(self, typeXml, name, alias): + VulkanWrapperGenerator.onGenType(self, typeXml, name, alias) + + if name in self.knownDefs: + return + + category = self.typeInfo.categoryOf(name) + + if category in ["struct", "union"] and alias: + self.module.appendHeader( + self.codegen.makeFuncAlias(API_PREFIX_EQUALITY + name, + API_PREFIX_EQUALITY + alias)) + + if category in ["struct", "union"] and not alias: + + structInfo = self.typeInfo.structs[name] + + typeFromName = \ + lambda varname: makeVulkanTypeSimple(True, name, 1, varname) + + compareParams = \ + list(map(typeFromName, EQUALITY_VAR_NAMES)) + \ + [EQUALITY_ON_FAIL_VAR_TYPE] + + comparePrototype = \ + VulkanAPI(API_PREFIX_EQUALITY + name, + EQUALITY_RET_TYPE, + compareParams) + + def structCompareDef(cgen): + self.equalityCodegen.cgen = cgen + for member in structInfo.members: + iterateVulkanType(self.typeInfo, member, + self.equalityCodegen) + + self.module.appendHeader( + self.codegen.makeFuncDecl(comparePrototype)) + self.module.appendImpl( + self.codegen.makeFuncImpl(comparePrototype, structCompareDef)) + + def onGenCmd(self, cmdinfo, name, alias): + VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias) + + def onEnd(self,): + VulkanWrapperGenerator.onEnd(self) + + def forEachExtensionCompare(ext, castedAccess, cgen): + cgen.funcCall(None, API_PREFIX_EQUALITY + ext.name, + [castedAccess, + cgen.makeReinterpretCast( + STRUCT_EXTENSION_PARAM2.paramName, ext.name), + EQUALITY_ON_FAIL_VAR]) + + self.module.appendImpl( + self.codegen.makeFuncImpl( + self.extensionTestingPrototype, + lambda cgen: self.emitForEachStructExtension( + cgen, + EQUALITY_RET_TYPE, + STRUCT_EXTENSION_PARAM, + forEachExtensionCompare))) |