summaryrefslogtreecommitdiff
path: root/codegen/vulkan/scripts/xml_consistency.py
diff options
context:
space:
mode:
Diffstat (limited to 'codegen/vulkan/scripts/xml_consistency.py')
-rwxr-xr-xcodegen/vulkan/scripts/xml_consistency.py390
1 files changed, 390 insertions, 0 deletions
diff --git a/codegen/vulkan/scripts/xml_consistency.py b/codegen/vulkan/scripts/xml_consistency.py
new file mode 100755
index 00000000..36514531
--- /dev/null
+++ b/codegen/vulkan/scripts/xml_consistency.py
@@ -0,0 +1,390 @@
+#!/usr/bin/python3
+#
+# Copyright (c) 2019 Collabora, Ltd.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Author(s): Ryan Pavlik <ryan.pavlik@collabora.com>
+#
+# Purpose: This script checks some "business logic" in the XML registry.
+
+import re
+import sys
+from pathlib import Path
+
+from check_spec_links import VulkanEntityDatabase as OrigEntityDatabase
+from reg import Registry
+from spec_tools.consistency_tools import XMLChecker
+from spec_tools.util import findNamedElem, getElemName, getElemType
+from vkconventions import VulkanConventions as APIConventions
+
+# These are extensions which do not follow the usual naming conventions,
+# specifying the alternate convention they follow
+EXTENSION_ENUM_NAME_SPELLING_CHANGE = {
+ 'VK_EXT_swapchain_colorspace': 'VK_EXT_SWAPCHAIN_COLOR_SPACE',
+}
+
+# These are extensions whose names *look* like they end in version numbers,
+# but don't
+EXTENSION_NAME_VERSION_EXCEPTIONS = (
+ 'VK_AMD_gpu_shader_int16',
+ 'VK_EXT_index_type_uint8',
+ 'VK_EXT_shader_image_atomic_int64',
+ 'VK_EXT_video_decode_h264',
+ 'VK_EXT_video_decode_h265',
+ 'VK_EXT_video_encode_h264',
+ 'VK_EXT_video_encode_h265',
+ 'VK_KHR_external_fence_win32',
+ 'VK_KHR_external_memory_win32',
+ 'VK_KHR_external_semaphore_win32',
+ 'VK_KHR_shader_atomic_int64',
+ 'VK_KHR_shader_float16_int8',
+ 'VK_KHR_spirv_1_4',
+ 'VK_NV_external_memory_win32',
+ 'VK_RESERVED_do_not_use_146',
+ 'VK_RESERVED_do_not_use_94',
+)
+
+# Exceptions to pointer parameter naming rules
+# Keyed by (entity name, type, name).
+CHECK_PARAM_POINTER_NAME_EXCEPTIONS = {
+ ('vkGetDrmDisplayEXT', 'VkDisplayKHR', 'display') : None,
+}
+
+# Exceptions to pNext member requiring an optional attribute
+CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS = (
+ 'VkVideoEncodeInfoKHR',
+)
+
+def get_extension_commands(reg):
+ extension_cmds = set()
+ for ext in reg.extensions:
+ for cmd in ext.findall("./require/command[@name]"):
+ extension_cmds.add(cmd.get("name"))
+ return extension_cmds
+
+
+def get_enum_value_names(reg, enum_type):
+ names = set()
+ result_elem = reg.groupdict[enum_type].elem
+ for val in result_elem.findall("./enum[@name]"):
+ names.add(val.get("name"))
+ return names
+
+
+# Regular expression matching an extension name ending in a (possible) version number
+EXTNAME_RE = re.compile(r'(?P<base>(\w+[A-Za-z]))(?P<version>\d+)')
+
+DESTROY_PREFIX = "vkDestroy"
+TYPEENUM = "VkStructureType"
+
+SPECIFICATION_DIR = Path(__file__).parent.parent
+REVISION_RE = re.compile(r' *[*] Revision (?P<num>[1-9][0-9]*),.*')
+
+
+def get_extension_source(extname):
+ fn = '{}.txt'.format(extname)
+ return str(SPECIFICATION_DIR / 'appendices' / fn)
+
+
+class EntityDatabase(OrigEntityDatabase):
+
+ # Override base class method to not exclude 'disabled' extensions
+ def getExclusionSet(self):
+ """Return a set of "support=" attribute strings that should not be included in the database.
+
+ Called only during construction."""
+
+ return set(())
+
+ def makeRegistry(self):
+ try:
+ import lxml.etree as etree
+ HAS_LXML = True
+ except ImportError:
+ HAS_LXML = False
+ if not HAS_LXML:
+ return super().makeRegistry()
+
+ registryFile = str(SPECIFICATION_DIR / 'xml/vk.xml')
+ registry = Registry()
+ registry.filename = registryFile
+ registry.loadElementTree(etree.parse(registryFile))
+ return registry
+
+
+class Checker(XMLChecker):
+ def __init__(self):
+ manual_types_to_codes = {
+ # These are hard-coded "manual" return codes:
+ # the codes of the value (string, list, or tuple)
+ # are available for a command if-and-only-if
+ # the key type is passed as an input.
+ "VkFormat": "VK_ERROR_FORMAT_NOT_SUPPORTED"
+ }
+ forward_only = {
+ # Like the above, but these are only valid in the
+ # "type implies return code" direction
+ }
+ reverse_only = {
+ # like the above, but these are only valid in the
+ # "return code implies type or its descendant" direction
+ # "XrDuration": "XR_TIMEOUT_EXPIRED"
+ }
+ # Some return codes are related in that only one of a set
+ # may be returned by a command
+ # (eg. XR_ERROR_SESSION_RUNNING and XR_ERROR_SESSION_NOT_RUNNING)
+ self.exclusive_return_code_sets = tuple(
+ # set(("XR_ERROR_SESSION_NOT_RUNNING", "XR_ERROR_SESSION_RUNNING")),
+ )
+ # Map of extension number -> [ list of extension names ]
+ self.extension_number_reservations = {
+ }
+
+ # This is used to report collisions.
+ conventions = APIConventions()
+ db = EntityDatabase()
+
+ self.extension_cmds = get_extension_commands(db.registry)
+ self.return_codes = get_enum_value_names(db.registry, 'VkResult')
+ self.structure_types = get_enum_value_names(db.registry, TYPEENUM)
+
+ # Dict of entity name to a list of messages to suppress. (Exclude any context data and "Warning:"/"Error:")
+ # Keys are entity names, values are tuples or lists of message text to suppress.
+ suppressions = {}
+
+ # Initialize superclass
+ super().__init__(entity_db=db, conventions=conventions,
+ manual_types_to_codes=manual_types_to_codes,
+ forward_only_types_to_codes=forward_only,
+ reverse_only_types_to_codes=reverse_only,
+ suppressions=suppressions)
+
+ def check_command_return_codes_basic(self, name, info,
+ successcodes, errorcodes):
+ """Check a command's return codes for consistency.
+
+ Called on every command."""
+ # Check that all extension commands can return the code associated
+ # with trying to use an extension that wasn't enabled.
+ # if name in self.extension_cmds and UNSUPPORTED not in errorcodes:
+ # self.record_error("Missing expected return code",
+ # UNSUPPORTED,
+ # "implied due to being an extension command")
+
+ codes = successcodes.union(errorcodes)
+
+ # Check that all return codes are recognized.
+ unrecognized = codes - self.return_codes
+ if unrecognized:
+ self.record_error("Unrecognized return code(s):",
+ unrecognized)
+
+ elem = info.elem
+ params = [(getElemName(elt), elt) for elt in elem.findall('param')]
+
+ def is_count_output(name, elt):
+ # Must end with Count or Size,
+ # not be const,
+ # and be a pointer (detected by naming convention)
+ return (name.endswith('Count') or name.endswith('Size')) \
+ and (elt.tail is None or 'const' not in elt.tail) \
+ and (name.startswith('p'))
+
+ countParams = [elt
+ for name, elt in params
+ if is_count_output(name, elt)]
+ if countParams:
+ assert(len(countParams) == 1)
+ if 'VK_INCOMPLETE' not in successcodes:
+ self.record_error(
+ "Apparent enumeration of an array without VK_INCOMPLETE in successcodes.")
+
+ elif 'VK_INCOMPLETE' in successcodes:
+ self.record_error(
+ "VK_INCOMPLETE in successcodes of command that is apparently not an array enumeration.")
+
+ def check_param(self, param):
+ """Check a member of a struct or a param of a function.
+
+ Called from check_params."""
+ super().check_param(param)
+
+ if not self.is_api_type(param):
+ return
+
+ param_text = "".join(param.itertext())
+ param_name = getElemName(param)
+
+ # Make sure the number of leading "p" matches the pointer count.
+ pointercount = param.find('type').tail
+ if pointercount:
+ pointercount = pointercount.count('*')
+ if pointercount:
+ prefix = 'p' * pointercount
+ if not param_name.startswith(prefix):
+ param_type = param.find('type').text
+ message = "Apparently incorrect pointer-related name prefix for {} - expected it to start with '{}'".format(
+ param_text, prefix)
+ if (self.entity, param_type, param_name) in CHECK_PARAM_POINTER_NAME_EXCEPTIONS:
+ self.record_warning('(Allowed exception)', message, elem=param)
+ else:
+ self.record_error(message, elem=param)
+
+ # Make sure pNext members have optional="true" attributes
+ if param_name == self.conventions.nextpointer_member_name:
+ optional = param.get('optional')
+ if optional is None or optional != 'true':
+ message = '{}.pNext member is missing \'optional="true"\' attribute'.format(self.entity)
+ if self.entity in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
+ self.record_warning('(Allowed exception)', message, elem=param)
+ else:
+ self.record_error(message, elem=param)
+
+ def check_type(self, name, info, category):
+ """Check a type's XML data for consistency.
+
+ Called from check."""
+
+ elem = info.elem
+ type_elts = [elt
+ for elt in elem.findall("member")
+ if getElemType(elt) == TYPEENUM]
+ if category == 'struct' and type_elts:
+ if len(type_elts) > 1:
+ self.record_error(
+ "Have more than one member of type", TYPEENUM)
+ else:
+ type_elt = type_elts[0]
+ val = type_elt.get('values')
+ if val and val not in self.structure_types:
+ self.record_error("Unknown structure type constant", val)
+
+ # Check the pointer chain member, if present.
+ next_name = self.conventions.nextpointer_member_name
+ next_member = findNamedElem(info.elem.findall('member'), next_name)
+ if next_member is not None:
+ # Ensure that the 'optional' attribute is set to 'true'
+ optional = next_member.get('optional')
+ if optional is None or optional != 'true':
+ message = '{}.{} member is missing \'optional="true"\' attribute'.format(name, next_name)
+ if name in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
+ self.record_warning('(Allowed exception)', message)
+ else:
+ self.record_error(message)
+
+ elif category == "bitmask":
+ if 'Flags' in name:
+ expected_require = name.replace('Flags', 'FlagBits')
+ require = info.elem.get('require')
+ if require is not None and expected_require != require:
+ self.record_error("Unexpected require attribute value:",
+ "got", require,
+ "but expected", expected_require)
+ super().check_type(name, info, category)
+
+ def check_extension(self, name, info):
+ """Check an extension's XML data for consistency.
+
+ Called from check."""
+ elem = info.elem
+ enums = elem.findall('./require/enum[@name]')
+
+ # Look for other extensions using that number
+ # Keep track of this extension number reservation
+ ext_number = elem.get('number')
+ if ext_number in self.extension_number_reservations:
+ conflicts = self.extension_number_reservations[ext_number]
+ self.record_error('Extension number {} has more than one reservation: {}, {}'.format(
+ ext_number, name, ', '.join(conflicts)))
+ self.extension_number_reservations[ext_number].append(name)
+ else:
+ self.extension_number_reservations[ext_number] = [ name ]
+
+ # If extension name is not on the exception list and matches the
+ # versioned-extension pattern, map the extension name to the version
+ # name with the version as a separate word. Otherwise just map it to
+ # the upper-case version of the extension name.
+
+ matches = EXTNAME_RE.fullmatch(name)
+ ext_versioned_name = False
+ if name in EXTENSION_ENUM_NAME_SPELLING_CHANGE:
+ ext_enum_name = EXTENSION_ENUM_NAME_SPELLING_CHANGE.get(name)
+ elif matches is None or name in EXTENSION_NAME_VERSION_EXCEPTIONS:
+ # This is the usual case, either a name that doesn't look
+ # versioned, or one that does but is on the exception list.
+ ext_enum_name = name.upper()
+ else:
+ # This is a versioned extension name.
+ # Treat the version number as a separate word.
+ base = matches.group('base')
+ version = matches.group('version')
+ ext_enum_name = base.upper() + '_' + version
+ # Keep track of this case
+ ext_versioned_name = True
+
+ # Look for the expected SPEC_VERSION token name
+ version_name = "{}_SPEC_VERSION".format(ext_enum_name)
+ version_elem = findNamedElem(enums, version_name)
+
+ if version_elem is None:
+ # Did not find a SPEC_VERSION enum matching the extension name
+ if ext_versioned_name:
+ suffix = '\n\
+ Make sure that trailing version numbers in extension names are treated\n\
+ as separate words in extension enumerant names. If this is an extension\n\
+ whose name ends in a number which is not a version, such as "...h264"\n\
+ or "...int16", add it to EXTENSION_NAME_VERSION_EXCEPTIONS in\n\
+ scripts/xml_consistency.py.'
+ else:
+ suffix = ''
+ self.record_error('Missing version enum {}{}'.format(version_name, suffix))
+ elif info.elem.get('supported') == self.conventions.xml_api_name:
+ # Skip unsupported / disabled extensions for these checks
+
+ fn = get_extension_source(name)
+ revisions = []
+ with open(fn, 'r', encoding='utf-8') as fp:
+ for line in fp:
+ line = line.rstrip()
+ match = REVISION_RE.match(line)
+ if match:
+ revisions.append(int(match.group('num')))
+ ver_from_xml = version_elem.get('value')
+ if revisions:
+ ver_from_text = str(max(revisions))
+ if ver_from_xml != ver_from_text:
+ self.record_error("Version enum mismatch: spec text indicates", ver_from_text,
+ "but XML says", ver_from_xml)
+ else:
+ if ver_from_xml == '1':
+ self.record_warning(
+ "Cannot find version history in spec text - make sure it has lines starting exactly like '* Revision 1, ....'",
+ filename=fn)
+ else:
+ self.record_warning("Cannot find version history in spec text, but XML reports a non-1 version number", ver_from_xml,
+ " - make sure the spec text has lines starting exactly like '* Revision 1, ....'",
+ filename=fn)
+
+ name_define = "{}_EXTENSION_NAME".format(ext_enum_name)
+ name_elem = findNamedElem(enums, name_define)
+ if name_elem is None:
+ self.record_error("Missing name enum", name_define)
+ else:
+ # Note: etree handles the XML entities here and turns &quot; back into "
+ expected_name = '"{}"'.format(name)
+ name_val = name_elem.get('value')
+ if name_val != expected_name:
+ self.record_error("Incorrect name enum: expected", expected_name,
+ "got", name_val)
+
+ super().check_extension(name, elem)
+
+
+if __name__ == "__main__":
+
+ ckr = Checker()
+ ckr.check()
+
+ if ckr.fail:
+ sys.exit(1)