diff options
Diffstat (limited to 'codegen/vulkan/scripts/xml_consistency.py')
-rwxr-xr-x | codegen/vulkan/scripts/xml_consistency.py | 390 |
1 files changed, 390 insertions, 0 deletions
diff --git a/codegen/vulkan/scripts/xml_consistency.py b/codegen/vulkan/scripts/xml_consistency.py new file mode 100755 index 00000000..36514531 --- /dev/null +++ b/codegen/vulkan/scripts/xml_consistency.py @@ -0,0 +1,390 @@ +#!/usr/bin/python3 +# +# Copyright (c) 2019 Collabora, Ltd. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Author(s): Ryan Pavlik <ryan.pavlik@collabora.com> +# +# Purpose: This script checks some "business logic" in the XML registry. + +import re +import sys +from pathlib import Path + +from check_spec_links import VulkanEntityDatabase as OrigEntityDatabase +from reg import Registry +from spec_tools.consistency_tools import XMLChecker +from spec_tools.util import findNamedElem, getElemName, getElemType +from vkconventions import VulkanConventions as APIConventions + +# These are extensions which do not follow the usual naming conventions, +# specifying the alternate convention they follow +EXTENSION_ENUM_NAME_SPELLING_CHANGE = { + 'VK_EXT_swapchain_colorspace': 'VK_EXT_SWAPCHAIN_COLOR_SPACE', +} + +# These are extensions whose names *look* like they end in version numbers, +# but don't +EXTENSION_NAME_VERSION_EXCEPTIONS = ( + 'VK_AMD_gpu_shader_int16', + 'VK_EXT_index_type_uint8', + 'VK_EXT_shader_image_atomic_int64', + 'VK_EXT_video_decode_h264', + 'VK_EXT_video_decode_h265', + 'VK_EXT_video_encode_h264', + 'VK_EXT_video_encode_h265', + 'VK_KHR_external_fence_win32', + 'VK_KHR_external_memory_win32', + 'VK_KHR_external_semaphore_win32', + 'VK_KHR_shader_atomic_int64', + 'VK_KHR_shader_float16_int8', + 'VK_KHR_spirv_1_4', + 'VK_NV_external_memory_win32', + 'VK_RESERVED_do_not_use_146', + 'VK_RESERVED_do_not_use_94', +) + +# Exceptions to pointer parameter naming rules +# Keyed by (entity name, type, name). +CHECK_PARAM_POINTER_NAME_EXCEPTIONS = { + ('vkGetDrmDisplayEXT', 'VkDisplayKHR', 'display') : None, +} + +# Exceptions to pNext member requiring an optional attribute +CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS = ( + 'VkVideoEncodeInfoKHR', +) + +def get_extension_commands(reg): + extension_cmds = set() + for ext in reg.extensions: + for cmd in ext.findall("./require/command[@name]"): + extension_cmds.add(cmd.get("name")) + return extension_cmds + + +def get_enum_value_names(reg, enum_type): + names = set() + result_elem = reg.groupdict[enum_type].elem + for val in result_elem.findall("./enum[@name]"): + names.add(val.get("name")) + return names + + +# Regular expression matching an extension name ending in a (possible) version number +EXTNAME_RE = re.compile(r'(?P<base>(\w+[A-Za-z]))(?P<version>\d+)') + +DESTROY_PREFIX = "vkDestroy" +TYPEENUM = "VkStructureType" + +SPECIFICATION_DIR = Path(__file__).parent.parent +REVISION_RE = re.compile(r' *[*] Revision (?P<num>[1-9][0-9]*),.*') + + +def get_extension_source(extname): + fn = '{}.txt'.format(extname) + return str(SPECIFICATION_DIR / 'appendices' / fn) + + +class EntityDatabase(OrigEntityDatabase): + + # Override base class method to not exclude 'disabled' extensions + def getExclusionSet(self): + """Return a set of "support=" attribute strings that should not be included in the database. + + Called only during construction.""" + + return set(()) + + def makeRegistry(self): + try: + import lxml.etree as etree + HAS_LXML = True + except ImportError: + HAS_LXML = False + if not HAS_LXML: + return super().makeRegistry() + + registryFile = str(SPECIFICATION_DIR / 'xml/vk.xml') + registry = Registry() + registry.filename = registryFile + registry.loadElementTree(etree.parse(registryFile)) + return registry + + +class Checker(XMLChecker): + def __init__(self): + manual_types_to_codes = { + # These are hard-coded "manual" return codes: + # the codes of the value (string, list, or tuple) + # are available for a command if-and-only-if + # the key type is passed as an input. + "VkFormat": "VK_ERROR_FORMAT_NOT_SUPPORTED" + } + forward_only = { + # Like the above, but these are only valid in the + # "type implies return code" direction + } + reverse_only = { + # like the above, but these are only valid in the + # "return code implies type or its descendant" direction + # "XrDuration": "XR_TIMEOUT_EXPIRED" + } + # Some return codes are related in that only one of a set + # may be returned by a command + # (eg. XR_ERROR_SESSION_RUNNING and XR_ERROR_SESSION_NOT_RUNNING) + self.exclusive_return_code_sets = tuple( + # set(("XR_ERROR_SESSION_NOT_RUNNING", "XR_ERROR_SESSION_RUNNING")), + ) + # Map of extension number -> [ list of extension names ] + self.extension_number_reservations = { + } + + # This is used to report collisions. + conventions = APIConventions() + db = EntityDatabase() + + self.extension_cmds = get_extension_commands(db.registry) + self.return_codes = get_enum_value_names(db.registry, 'VkResult') + self.structure_types = get_enum_value_names(db.registry, TYPEENUM) + + # Dict of entity name to a list of messages to suppress. (Exclude any context data and "Warning:"/"Error:") + # Keys are entity names, values are tuples or lists of message text to suppress. + suppressions = {} + + # Initialize superclass + super().__init__(entity_db=db, conventions=conventions, + manual_types_to_codes=manual_types_to_codes, + forward_only_types_to_codes=forward_only, + reverse_only_types_to_codes=reverse_only, + suppressions=suppressions) + + def check_command_return_codes_basic(self, name, info, + successcodes, errorcodes): + """Check a command's return codes for consistency. + + Called on every command.""" + # Check that all extension commands can return the code associated + # with trying to use an extension that wasn't enabled. + # if name in self.extension_cmds and UNSUPPORTED not in errorcodes: + # self.record_error("Missing expected return code", + # UNSUPPORTED, + # "implied due to being an extension command") + + codes = successcodes.union(errorcodes) + + # Check that all return codes are recognized. + unrecognized = codes - self.return_codes + if unrecognized: + self.record_error("Unrecognized return code(s):", + unrecognized) + + elem = info.elem + params = [(getElemName(elt), elt) for elt in elem.findall('param')] + + def is_count_output(name, elt): + # Must end with Count or Size, + # not be const, + # and be a pointer (detected by naming convention) + return (name.endswith('Count') or name.endswith('Size')) \ + and (elt.tail is None or 'const' not in elt.tail) \ + and (name.startswith('p')) + + countParams = [elt + for name, elt in params + if is_count_output(name, elt)] + if countParams: + assert(len(countParams) == 1) + if 'VK_INCOMPLETE' not in successcodes: + self.record_error( + "Apparent enumeration of an array without VK_INCOMPLETE in successcodes.") + + elif 'VK_INCOMPLETE' in successcodes: + self.record_error( + "VK_INCOMPLETE in successcodes of command that is apparently not an array enumeration.") + + def check_param(self, param): + """Check a member of a struct or a param of a function. + + Called from check_params.""" + super().check_param(param) + + if not self.is_api_type(param): + return + + param_text = "".join(param.itertext()) + param_name = getElemName(param) + + # Make sure the number of leading "p" matches the pointer count. + pointercount = param.find('type').tail + if pointercount: + pointercount = pointercount.count('*') + if pointercount: + prefix = 'p' * pointercount + if not param_name.startswith(prefix): + param_type = param.find('type').text + message = "Apparently incorrect pointer-related name prefix for {} - expected it to start with '{}'".format( + param_text, prefix) + if (self.entity, param_type, param_name) in CHECK_PARAM_POINTER_NAME_EXCEPTIONS: + self.record_warning('(Allowed exception)', message, elem=param) + else: + self.record_error(message, elem=param) + + # Make sure pNext members have optional="true" attributes + if param_name == self.conventions.nextpointer_member_name: + optional = param.get('optional') + if optional is None or optional != 'true': + message = '{}.pNext member is missing \'optional="true"\' attribute'.format(self.entity) + if self.entity in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS: + self.record_warning('(Allowed exception)', message, elem=param) + else: + self.record_error(message, elem=param) + + def check_type(self, name, info, category): + """Check a type's XML data for consistency. + + Called from check.""" + + elem = info.elem + type_elts = [elt + for elt in elem.findall("member") + if getElemType(elt) == TYPEENUM] + if category == 'struct' and type_elts: + if len(type_elts) > 1: + self.record_error( + "Have more than one member of type", TYPEENUM) + else: + type_elt = type_elts[0] + val = type_elt.get('values') + if val and val not in self.structure_types: + self.record_error("Unknown structure type constant", val) + + # Check the pointer chain member, if present. + next_name = self.conventions.nextpointer_member_name + next_member = findNamedElem(info.elem.findall('member'), next_name) + if next_member is not None: + # Ensure that the 'optional' attribute is set to 'true' + optional = next_member.get('optional') + if optional is None or optional != 'true': + message = '{}.{} member is missing \'optional="true"\' attribute'.format(name, next_name) + if name in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS: + self.record_warning('(Allowed exception)', message) + else: + self.record_error(message) + + elif category == "bitmask": + if 'Flags' in name: + expected_require = name.replace('Flags', 'FlagBits') + require = info.elem.get('require') + if require is not None and expected_require != require: + self.record_error("Unexpected require attribute value:", + "got", require, + "but expected", expected_require) + super().check_type(name, info, category) + + def check_extension(self, name, info): + """Check an extension's XML data for consistency. + + Called from check.""" + elem = info.elem + enums = elem.findall('./require/enum[@name]') + + # Look for other extensions using that number + # Keep track of this extension number reservation + ext_number = elem.get('number') + if ext_number in self.extension_number_reservations: + conflicts = self.extension_number_reservations[ext_number] + self.record_error('Extension number {} has more than one reservation: {}, {}'.format( + ext_number, name, ', '.join(conflicts))) + self.extension_number_reservations[ext_number].append(name) + else: + self.extension_number_reservations[ext_number] = [ name ] + + # If extension name is not on the exception list and matches the + # versioned-extension pattern, map the extension name to the version + # name with the version as a separate word. Otherwise just map it to + # the upper-case version of the extension name. + + matches = EXTNAME_RE.fullmatch(name) + ext_versioned_name = False + if name in EXTENSION_ENUM_NAME_SPELLING_CHANGE: + ext_enum_name = EXTENSION_ENUM_NAME_SPELLING_CHANGE.get(name) + elif matches is None or name in EXTENSION_NAME_VERSION_EXCEPTIONS: + # This is the usual case, either a name that doesn't look + # versioned, or one that does but is on the exception list. + ext_enum_name = name.upper() + else: + # This is a versioned extension name. + # Treat the version number as a separate word. + base = matches.group('base') + version = matches.group('version') + ext_enum_name = base.upper() + '_' + version + # Keep track of this case + ext_versioned_name = True + + # Look for the expected SPEC_VERSION token name + version_name = "{}_SPEC_VERSION".format(ext_enum_name) + version_elem = findNamedElem(enums, version_name) + + if version_elem is None: + # Did not find a SPEC_VERSION enum matching the extension name + if ext_versioned_name: + suffix = '\n\ + Make sure that trailing version numbers in extension names are treated\n\ + as separate words in extension enumerant names. If this is an extension\n\ + whose name ends in a number which is not a version, such as "...h264"\n\ + or "...int16", add it to EXTENSION_NAME_VERSION_EXCEPTIONS in\n\ + scripts/xml_consistency.py.' + else: + suffix = '' + self.record_error('Missing version enum {}{}'.format(version_name, suffix)) + elif info.elem.get('supported') == self.conventions.xml_api_name: + # Skip unsupported / disabled extensions for these checks + + fn = get_extension_source(name) + revisions = [] + with open(fn, 'r', encoding='utf-8') as fp: + for line in fp: + line = line.rstrip() + match = REVISION_RE.match(line) + if match: + revisions.append(int(match.group('num'))) + ver_from_xml = version_elem.get('value') + if revisions: + ver_from_text = str(max(revisions)) + if ver_from_xml != ver_from_text: + self.record_error("Version enum mismatch: spec text indicates", ver_from_text, + "but XML says", ver_from_xml) + else: + if ver_from_xml == '1': + self.record_warning( + "Cannot find version history in spec text - make sure it has lines starting exactly like '* Revision 1, ....'", + filename=fn) + else: + self.record_warning("Cannot find version history in spec text, but XML reports a non-1 version number", ver_from_xml, + " - make sure the spec text has lines starting exactly like '* Revision 1, ....'", + filename=fn) + + name_define = "{}_EXTENSION_NAME".format(ext_enum_name) + name_elem = findNamedElem(enums, name_define) + if name_elem is None: + self.record_error("Missing name enum", name_define) + else: + # Note: etree handles the XML entities here and turns " back into " + expected_name = '"{}"'.format(name) + name_val = name_elem.get('value') + if name_val != expected_name: + self.record_error("Incorrect name enum: expected", expected_name, + "got", name_val) + + super().check_extension(name, elem) + + +if __name__ == "__main__": + + ckr = Checker() + ckr.check() + + if ckr.fail: + sys.exit(1) |