diff options
Diffstat (limited to 'codegen/vulkan/scripts/spec_tools/consistency_tools.py')
-rw-r--r-- | codegen/vulkan/scripts/spec_tools/consistency_tools.py | 697 |
1 file changed, 697 insertions, 0 deletions
#!/usr/bin/python3 -i
#
# Copyright (c) 2019 Collabora, Ltd.
#
# SPDX-License-Identifier: Apache-2.0
#
# Author(s):    Ryan Pavlik <ryan.pavlik@collabora.com>
"""Provides utilities to write a script to verify XML registry consistency."""

import re

import networkx as nx

from .algo import RecursiveMemoize
from .attributes import ExternSyncEntry, LengthEntry
from .data_structures import DictOfStringSets
from .util import findNamedElem, getElemName


class XMLChecker:
    """Base class for XML registry consistency checkers.

    Subclass per-API: override/extend the `check_*` hooks, then call
    `check()` to walk the registry and print accumulated errors/warnings.

    NOTE(review): `self.return_codes` is read in `__init__` but never
    assigned here — subclasses are presumably expected to set it before
    calling `super().__init__()`. Confirm against callers.
    """

    def __init__(self, entity_db, conventions, manual_types_to_codes=None,
                 forward_only_types_to_codes=None,
                 reverse_only_types_to_codes=None,
                 suppressions=None):
        """Set up data structures.

        May extend - call:
        `super().__init__(db, conventions, manual_types_to_codes)`
        as the last statement in your function.

        manual_types_to_codes is a dictionary of hard-coded
        "manual" return codes:
        the codes of the value are available for a command if-and-only-if
        the key type is passed as an input.

        forward_only_types_to_codes is additional entries to the above
        that should only be used in the "forward" direction
        (arg type implies return code)

        reverse_only_types_to_codes is additional entries to
        manual_types_to_codes that should only be used in the
        "reverse" direction
        (return code implies arg type)
        """
        self.fail = False
        self.entity = None
        self.errors = DictOfStringSets()
        self.warnings = DictOfStringSets()
        self.db = entity_db
        self.reg = entity_db.registry
        self.handle_data = HandleData(self.reg)
        self.conventions = conventions

        self.CONST_RE = re.compile(r"\bconst\b")
        self.ARRAY_RE = re.compile(r"\[[^]]+\]")

        # Init memoized properties
        self._handle_data = None

        if not manual_types_to_codes:
            manual_types_to_codes = {}
        if not reverse_only_types_to_codes:
            reverse_only_types_to_codes = {}
        if not forward_only_types_to_codes:
            forward_only_types_to_codes = {}

        # "manual" codes apply in both directions; the directional extras
        # are merged into separate forward/reverse dictionaries.
        reverse_codes = DictOfStringSets(reverse_only_types_to_codes)
        forward_codes = DictOfStringSets(forward_only_types_to_codes)
        for k, v in manual_types_to_codes.items():
            forward_codes.add(k, v)
            reverse_codes.add(k, v)

        self.forward_only_manual_types_to_codes = forward_codes.get_dict()
        self.reverse_only_manual_types_to_codes = reverse_codes.get_dict()

        # The presence of some types as input to a function imply the
        # availability of some return codes.
        self.input_type_to_codes = compute_type_to_codes(
            self.handle_data,
            forward_codes,
            extra_op=self.add_extra_codes)

        # Some return codes require a type (or its child) in the input.
        self.codes_requiring_input_type = compute_codes_requiring_type(
            self.handle_data,
            reverse_codes
        )

        # Sanity-check: every code referenced by the script must exist in
        # the registry (self.return_codes, supplied by the subclass).
        specified_codes = set(self.codes_requiring_input_type.keys())
        for codes in self.forward_only_manual_types_to_codes.values():
            specified_codes.update(codes)
        for codes in self.reverse_only_manual_types_to_codes.values():
            specified_codes.update(codes)
        for codes in self.input_type_to_codes.values():
            specified_codes.update(codes)

        unrecognized = specified_codes - self.return_codes
        if unrecognized:
            raise RuntimeError("Return code mentioned in script that isn't in the registry: " +
                               ', '.join(unrecognized))

        self.referenced_input_types = ReferencedTypes(self.db, self.is_input)
        self.referenced_api_types = ReferencedTypes(self.db, self.is_api_type)
        if not suppressions:
            suppressions = {}
        self.suppressions = DictOfStringSets(suppressions)

    def is_api_type(self, member_elem):
        """Return true if the member/parameter ElementTree passed is from this API.

        May override or extend."""
        membertext = "".join(member_elem.itertext())

        return self.conventions.type_prefix in membertext

    def is_input(self, member_elem):
        """Return true if the member/parameter ElementTree passed is
        considered "input".

        May override or extend."""
        membertext = "".join(member_elem.itertext())

        if self.conventions.type_prefix not in membertext:
            return False

        ret = True
        # Const is always input.
        if self.CONST_RE.search(membertext):
            ret = True

        # Arrays and pointers that aren't const are always output.
        elif "*" in membertext:
            ret = False
        elif self.ARRAY_RE.search(membertext):
            ret = False

        return ret

    def add_extra_codes(self, types_to_codes):
        """Add any desired entries to the types-to-codes DictOfStringSets
        before performing "ancestor propagation".

        Passed to compute_type_to_codes as the extra_op.

        May override."""
        pass

    def should_skip_checking_codes(self, name):
        """Return True if more than the basic validation of return codes should
        be skipped for a command.

        May override."""

        return self.conventions.should_skip_checking_codes

    def get_codes_for_command_and_type(self, cmd_name, type_name):
        """Return a set of error codes expected due to having
        an input argument of type type_name.

        The cmd_name is passed for use by extending methods.

        May extend."""
        return self.input_type_to_codes.get(type_name, set())

    def check(self):
        """Iterate through the registry, looking for consistency problems.

        Outputs error messages at the end."""
        # Iterate through commands, looking for consistency problems.
        for name, info in self.reg.cmddict.items():
            self.set_error_context(entity=name, elem=info.elem)

            self.check_command(name, info)

        for name, info in self.reg.typedict.items():
            cat = info.elem.get('category')
            if not cat:
                # This is an external thing, skip it.
                continue
            self.set_error_context(entity=name, elem=info.elem)

            self.check_type(name, info, cat)

        # check_extension is called for all extensions, even 'disabled'
        # ones, but some checks may be skipped depending on extension
        # status.
        for name, info in self.reg.extdict.items():
            self.set_error_context(entity=name, elem=info.elem)
            self.check_extension(name, info)

        entities_with_messages = set(
            self.errors.keys()).union(self.warnings.keys())
        if entities_with_messages:
            print('xml_consistency/consistency_tools error and warning messages follow.')

        # Sorted for deterministic, diff-friendly output (set iteration
        # order is otherwise arbitrary).
        for entity in sorted(entities_with_messages):
            print()
            print('-------------------')
            print('Messages for', entity)
            print()
            messages = self.errors.get(entity)
            if messages:
                for m in messages:
                    print('Error:', m)

            messages = self.warnings.get(entity)
            if messages:
                for m in messages:
                    print('Warning:', m)

    def check_param(self, param):
        """Check a member of a struct or a param of a function.

        Called from check_params.

        May extend."""
        param_name = getElemName(param)
        externsyncs = ExternSyncEntry.parse_externsync_from_param(param)
        if externsyncs:
            for entry in externsyncs:
                if entry.entirely_extern_sync:
                    if len(externsyncs) > 1:
                        self.record_error("Comma-separated list in externsync attribute includes 'true' for",
                                          param_name)
                else:
                    # member name
                    # TODO only looking at the superficial feature here,
                    # not entry.param_ref_parts
                    if entry.member != param_name:
                        self.record_error("externsync attribute for", param_name,
                                          "refers to some other member/parameter:", entry.member)

    def check_params(self, params):
        """Check the members of a struct or params of a function.

        Called from check_type and check_command.

        May extend."""
        for param in params:
            self.check_param(param)

            # Check for parameters referenced by len= attribute
            lengths = LengthEntry.parse_len_from_param(param)
            if lengths:
                for entry in lengths:
                    if not entry.other_param_name:
                        continue
                    # TODO only looking at the superficial feature here,
                    # not entry.param_ref_parts
                    other_param = findNamedElem(params, entry.other_param_name)
                    if other_param is None:
                        self.record_error("References a non-existent parameter/member in the length of",
                                          getElemName(param), ":", entry.other_param_name)

    def check_type(self, name, info, category):
        """Check a type's XML data for consistency.

        Called from check.

        May extend."""
        if category == 'struct':
            if not name.startswith(self.conventions.type_prefix):
                self.record_error("Name does not start with",
                                  self.conventions.type_prefix)
            members = info.elem.findall('member')
            self.check_params(members)

            # Check the structure type member, if present.
            type_member = findNamedElem(
                members, self.conventions.structtype_member_name)
            if type_member is not None:
                val = type_member.get('values')
                if val:
                    expected = self.conventions.generate_structure_type_from_name(
                        name)
                    if val != expected:
                        self.record_error("Type has incorrect type-member value: expected",
                                          expected, "got", val)

        elif category == "bitmask":
            if 'Flags' not in name:
                self.record_error("Name of bitmask doesn't include 'Flags'")

    def check_extension(self, name, info):
        """Check an extension's XML data for consistency.

        Called from check.

        May extend."""
        pass

    def check_command(self, name, info):
        """Check a command's XML data for consistency.

        Called from check.

        May extend."""
        elem = info.elem

        self.check_params(elem.findall('param'))

        # Some minimal return code checking
        errorcodes = elem.get("errorcodes")
        if errorcodes:
            errorcodes = errorcodes.split(",")
        else:
            errorcodes = []

        successcodes = elem.get("successcodes")
        if successcodes:
            successcodes = successcodes.split(",")
        else:
            successcodes = []

        if not successcodes and not errorcodes:
            # Early out if no return codes.
            return

        # Create a set for each group of codes, and check that
        # they aren't duplicated within or between groups.
        errorcodes_set = set(errorcodes)
        if len(errorcodes) != len(errorcodes_set):
            self.record_error("Contains a duplicate in errorcodes")

        successcodes_set = set(successcodes)
        if len(successcodes) != len(successcodes_set):
            self.record_error("Contains a duplicate in successcodes")

        if not successcodes_set.isdisjoint(errorcodes_set):
            self.record_error("Has errorcodes and successcodes that overlap")

        self.check_command_return_codes_basic(
            name, info, successcodes_set, errorcodes_set)

        # Continue to further return code checking if not "complicated"
        if not self.should_skip_checking_codes(name):
            codes_set = successcodes_set.union(errorcodes_set)
            self.check_command_return_codes(
                name, info, successcodes_set, errorcodes_set, codes_set)

    def check_command_return_codes_basic(self, name, info,
                                         successcodes, errorcodes):
        """Check a command's return codes for consistency.

        Called from check_command on every command.

        May extend."""

        # Check that all error codes include _ERROR_,
        # and that no success codes do.
        for code in errorcodes:
            if "_ERROR_" not in code:
                self.record_error(
                    code, "in errorcodes but doesn't contain _ERROR_")

        for code in successcodes:
            if "_ERROR_" in code:
                # Fixed grammar in diagnostic ("contain" -> "contains").
                self.record_error(code, "in successcodes but contains _ERROR_")

    def check_command_return_codes(self, name, type_info,
                                   successcodes, errorcodes,
                                   codes):
        """Check a command's return codes in-depth for consistency.

        Called from check_command, only if
        `self.should_skip_checking_codes(name)` is False.

        May extend."""
        referenced_input = self.referenced_input_types[name]
        referenced_types = self.referenced_api_types[name]

        # Check that we have all the codes we expect, based on input types.
        for referenced_type in referenced_input:
            required_codes = self.get_codes_for_command_and_type(
                name, referenced_type)
            missing_codes = required_codes - codes
            if missing_codes:
                path = self.referenced_input_types.shortest_path(
                    name, referenced_type)
                path_str = " -> ".join(path)
                self.record_error("Missing expected return code(s)",
                                  ",".join(missing_codes),
                                  "implied because of input of type",
                                  referenced_type,
                                  "found via path",
                                  path_str)

        # Check that, for each code returned by this command that we can
        # associate with a type, we have some type that can provide it.
        # e.g. can't have INSTANCE_LOST without an Instance
        # (or child of Instance).
        for code in codes:

            required_types = self.codes_requiring_input_type.get(code)
            if not required_types:
                # This code doesn't have a known requirement
                continue

            # TODO: do we look at referenced_types or referenced_input here?
            # the latter is stricter
            if not referenced_types.intersection(required_types):
                self.record_error("Unexpected return code", code,
                                  "- none of these types:",
                                  required_types,
                                  "found in the set of referenced types",
                                  referenced_types)

    ###
    # Utility properties/methods
    ###

    def set_error_context(self, entity=None, elem=None):
        """Set the entity and/or element for future record_error calls."""
        self.entity = entity
        self.elem = elem
        self.name = getElemName(elem)
        self.entity_suppressions = self.suppressions.get(getElemName(elem))

    def record_error(self, *args, **kwargs):
        """Record failure and an error message for the current context."""
        message = " ".join((str(x) for x in args))

        if self._is_message_suppressed(message):
            return

        message = self._prepend_sourceline_to_message(message, **kwargs)
        self.fail = True
        self.errors.add(self.entity, message)

    def record_warning(self, *args, **kwargs):
        """Record a warning message for the current context."""
        message = " ".join((str(x) for x in args))

        if self._is_message_suppressed(message):
            return

        message = self._prepend_sourceline_to_message(message, **kwargs)
        self.warnings.add(self.entity, message)

    def _is_message_suppressed(self, message):
        """Return True if the given message, for this entity, should be suppressed."""
        if not self.entity_suppressions:
            return False
        for suppress in self.entity_suppressions:
            if suppress in message:
                return True

        return False

    def _prepend_sourceline_to_message(self, message, **kwargs):
        """Prepend a file and/or line reference to the message, if possible.

        If filename is given as a keyword argument, it is used on its own.

        If filename is not given, this will attempt to retrieve the filename and line from an XML element.
        If 'elem' is given as a keyword argument and is not None, it is used to find the line.
        If 'elem' is given as None, no XML elements are looked at.
        If 'elem' is not supplied, the error context element is used.

        If using XML, the filename, if available, is retrieved from the Registry class.
        If using XML and python-lxml is installed, the source line is retrieved from whatever element is chosen."""
        fn = kwargs.get('filename')
        sourceline = None

        if fn is None:
            elem = kwargs.get('elem', self.elem)
            if elem is not None:
                # 'sourceline' only exists when the registry was parsed
                # with lxml; fall back gracefully otherwise.
                sourceline = getattr(elem, 'sourceline', None)
                if self.reg.filename:
                    fn = self.reg.filename

        if fn is None and sourceline is None:
            return message

        if fn is None:
            return "Line {}: {}".format(sourceline, message)

        if sourceline is None:
            return "{}: {}".format(fn, message)

        return "{}:{}: {}".format(fn, sourceline, message)


class HandleParents(RecursiveMemoize):
    """Memoized mapping of handle type name to a list of all its ancestors.

    NOTE(review): values are lists (possibly with duplicates when a type
    has alternate parents), not sets.
    """

    def __init__(self, handle_types):
        self.handle_types = handle_types

        def compute(handle_type):
            immediate_parent = self.handle_types[handle_type].elem.get(
                'parent')

            if immediate_parent is None:
                # No parents, no need to recurse
                return []

            # Support multiple (alternate) parents
            immediate_parents = immediate_parent.split(',')

            # Recurse, combine, and return
            all_parents = immediate_parents[:]
            for parent in immediate_parents:
                all_parents.extend(self[parent])
            return all_parents

        super().__init__(compute, handle_types.keys())


def _always_true(x):
    """Default predicate: accept every member/parameter element."""
    return True


class ReferencedTypes(RecursiveMemoize):
    """Find all types (optionally matching a predicate) that are referenced
    by a struct or function, recursively."""

    def __init__(self, db, predicate=None):
        """Initialize.

        Provide an EntityDB object and a predicate function."""
        self.db = db

        self.predicate = predicate
        if not self.predicate:
            # Default predicate is "anything goes"
            self.predicate = _always_true

        self._directly_referenced = {}
        self.graph = nx.DiGraph()

        def compute(type_name):
            """Compute and return all types referenced by type_name, recursively, that satisfy the predicate.

            Called by the [] operator in the base class."""
            types = self.directly_referenced(type_name)
            if not types:
                return types

            all_types = set()
            all_types.update(types)
            for t in types:
                referenced = self[t]
                if referenced is not None:
                    # If not leading to a cycle
                    all_types.update(referenced)
            return all_types

        # Initialize base class
        super().__init__(compute, permit_cycles=True)

    def shortest_path(self, source, target):
        """Get the shortest path between one type/function name and another."""
        # Trigger computation so the graph edges exist.
        _ = self[source]

        return nx.algorithms.shortest_path(self.graph, source=source, target=target)

    def directly_referenced(self, type_name):
        """Get all types referenced directly by type_name that satisfy the predicate.

        Memoizes its results."""
        if type_name not in self._directly_referenced:
            members = self.db.getMemberElems(type_name)
            if members:
                types = ((member, member.find("type")) for member in members)
                self._directly_referenced[type_name] = set(type_elem.text for (member, type_elem) in types
                                                           if type_elem is not None and self.predicate(member))

            else:
                self._directly_referenced[type_name] = set()

        # Update graph
        self.graph.add_node(type_name)
        self.graph.add_edges_from((type_name, t)
                                  for t in self._directly_referenced[type_name])

        return self._directly_referenced[type_name]


class HandleData:
    """Data about all the handle types available in an API specification."""

    def __init__(self, registry):
        self.reg = registry
        self._handle_types = None
        self._ancestors = None
        self._descendants = None

    @property
    def handle_types(self):
        """Return a dictionary of handle type names to type info."""
        if not self._handle_types:
            # First time requested - compute it.
            self._handle_types = {
                type_name: type_info
                for type_name, type_info in self.reg.typedict.items()
                if type_info.elem.get('category') == 'handle'
            }
        return self._handle_types

    @property
    def ancestors_dict(self):
        """Return a dictionary of handle type names to sets of ancestors."""
        if not self._ancestors:
            # First time requested - compute it.
            self._ancestors = HandleParents(self.handle_types).get_dict()
        return self._ancestors

    @property
    def descendants_dict(self):
        """Return a dictionary of handle type names to sets of descendants."""
        if not self._descendants:
            # First time requested - compute it.

            handle_parents = self.ancestors_dict

            def get_descendants(handle):
                return set(h for h in handle_parents.keys()
                           if handle in handle_parents[h])

            self._descendants = {
                h: get_descendants(h)
                for h in handle_parents.keys()
            }
        return self._descendants


def compute_type_to_codes(handle_data, types_to_codes, extra_op=None):
    """Compute a DictOfStringSets of input type to required return codes.

    - handle_data is a HandleData instance.
    - types_to_codes is a dictionary of type names to strings or string
      collections of return codes.
    - extra_op, if any, is called after populating the output from the input
      dictionary, but before propagation of parent codes to child types.
      extra_op is called with the in-progress DictOfStringSets.

    Returns a DictOfStringSets of input type name to set of required return
    code names.
    """
    # Initialize with the supplied "manual" codes
    types_to_codes = DictOfStringSets(types_to_codes)

    # Dynamically generate more codes, if desired
    if extra_op:
        extra_op(types_to_codes)

    # Final post-processing

    # Any handle can result in its parent handle's codes too.

    handle_ancestors = handle_data.ancestors_dict

    extra_handle_codes = {}
    for handle_type, ancestors in handle_ancestors.items():
        # The sets of return codes corresponding to each ancestor type.
        ancestors_codes = (types_to_codes.get(ancestor, set())
                           for ancestor in ancestors)
        # BUG FIX: the original called `codes.union(*ancestors_codes)` and
        # discarded the result -- set.union() returns a NEW set and does not
        # mutate in place -- so ancestor codes were never propagated.
        codes = set().union(*ancestors_codes)
        extra_handle_codes[handle_type] = codes

    for handle_type, extras in extra_handle_codes.items():
        types_to_codes.add(handle_type, extras)

    return types_to_codes


def compute_codes_requiring_type(handle_data, types_to_codes, registry=None):
    """Compute a DictOfStringSets of return codes to a set of input types able
    to provide the ability to generate that code.

    handle_data is a HandleData instance.
    types_to_codes is a dictionary of input types to associated return codes
    (same format as for input to compute_type_to_codes, may use same dict).
    This will invert that relationship, and also permit any "child handles"
    to satisfy a requirement for a parent in producing a code.

    Returns a DictOfStringSets of return code name to the set of parameter
    types that would allow that return code.
    """
    # Use DictOfStringSets to normalize the input into a dict with values
    # that are sets of strings
    in_dict = DictOfStringSets(types_to_codes)

    handle_descendants = handle_data.descendants_dict

    out = DictOfStringSets()
    for in_type, code_set in in_dict.items():
        descendants = handle_descendants.get(in_type)
        for code in code_set:
            out.add(code, in_type)
            if descendants:
                out.add(code, descendants)

    return out