summaryrefslogtreecommitdiff
path: root/registry/vulkan/scripts/spec_tools/validity.py
blob: 745ba01354f8c2160e621dd794c441ef5a579ab7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
#!/usr/bin/python3 -i
#
# Copyright 2013-2021 The Khronos Group Inc.
#
# SPDX-License-Identifier: Apache-2.0

import re


# Matches the article " a " (with a leading space) followed by a word that
# starts with a, e, i, o or x (either case) and has at least one more word
# character.  Group 1 is an optional lowercase macro prefix ending in ':'
# (e.g. "flink:"); group 2 is the word itself.  The (?!:) lookahead rejects
# the match when the word is immediately followed by ':', so a bare macro
# name is not mistaken for the word after the article.
_A_VS_AN_RE = re.compile(r' a ([a-z]+:)?([aAeEiIoOxX]\w+\b)(?!:)')

# Matches text that begins with a lowercase macro name followed by ':'
# (for example "flink:..."); ValidityEntry.append uses it to decide whether
# the first character of an entry should be capitalized.
_STARTS_WITH_MACRO_RE = re.compile(r'^[a-z]+:.*')


def _checkAnchorComponents(anchor):
    """Validate the components of a VUID anchor name.

    anchor may be None/empty (nothing to check) or an iterable of
    components.  Raises RuntimeError if any component contains a space.
    Further anchor-name validation belongs here as well.
    """
    if not anchor:
        return
    for component in anchor:
        if ' ' in component:
            raise RuntimeError("Illegal component of a VUID anchor name!")


def _fix_a_vs_an(s):
    """Rewrite the article 'a' as 'an' in front of vowel-initial words.

    Validity text is frequently assembled from generated fragments, so an
    occurrence of " a " can end up in front of a word such as "image".  Each
    such occurrence becomes " an ", keeping any lowercase macro prefix that
    sits in front of the word.  Markup macro names themselves are not
    treated as the following word.
    """
    replacement = r' an \1\2'
    return _A_VS_AN_RE.sub(replacement, s)


class ValidityCollection:
    """Combines validity for a single entity.

    Formatted validity lines are kept, in insertion order, in self.lines;
    adding a line that is already present raises RuntimeError.  When strict
    is true, pre-formatted text cannot be added through addText or
    ``+= <str>``.
    """

    def __init__(self, entity_name=None, conventions=None, strict=True):
        """Create an empty collection.

        entity_name -- name used when building VUID anchors, and compared
            when another ValidityCollection is merged in with +=.
        conventions -- object providing api_version_prefix and
            formatExtension(); only used by possiblyAddExtensionRequirement.
        strict -- when true, addText and += with a string raise RuntimeError.
        """
        self.entity_name = entity_name
        self.conventions = conventions
        self.lines = []
        self.strict = strict

    def possiblyAddExtensionRequirement(self, extension_name, entity_preface):
        """Add an extension-related validity statement if required.

        entity_preface is a string that goes between "must be enabled prior to "
        and the name of the entity, and normally ends in a macro.
        For instance, might be "calling flink:" for a function.

        Nothing is added when extension_name is empty or starts with
        conventions.api_version_prefix (presumably naming a core API version
        rather than an extension).  The entry is anchored as
        VUID-<entity_name>-extension-notenabled.
        """
        if extension_name and not extension_name.startswith(self.conventions.api_version_prefix):
            msg = 'The {} extension must: be enabled prior to {}{}'.format(
                self.conventions.formatExtension(extension_name), entity_preface, self.entity_name)
            self.addValidityEntry(msg, anchor=('extension', 'notenabled'))

    def addValidityEntry(self, msg, anchor=None):
        """Add a validity entry, optionally with a VUID anchor.

        The stored line is "* [[VUID-<entity_name>-<anchor...>]] <msg>" (the
        anchor part only when an anchor is given), after applying the
        'a' -> 'an' fix-up.

        anchor -- None, a single string, or a sequence of strings.  The
            components are joined with dashes after "VUID-<entity_name>".
            A non-empty bare string is treated as one component.

        Raises RuntimeError if msg is empty, if an anchor component contains
        a space, if an anchor is requested but the collection has no
        entity_name, or if the resulting line is already in the collection.
        """
        if not msg:
            raise RuntimeError("Tried to add a blank validity line!")
        if isinstance(anchor, str) and anchor:
            # A bare string is a single component; without this, list(anchor)
            # below would split it into individual characters.
            anchor = (anchor,)
        _checkAnchorComponents(anchor)
        parts = ['*']
        if anchor:
            if not self.entity_name:
                raise RuntimeError('Cannot add a validity entry with an anchor to a collection that does not know its entity name.')
            parts.append('[[{}]]'.format(
                '-'.join(['VUID', self.entity_name] + list(anchor))))
        parts.append(msg)
        combined = _fix_a_vs_an(' '.join(parts))
        if combined in self.lines:
            raise RuntimeError("Duplicate validity added!")
        self.lines.append(combined)

    def addText(self, msg):
        """Add already formatted validity text.

        Trailing whitespace is stripped; empty or whitespace-only text is
        ignored.  Raises RuntimeError in strict mode.
        """
        if self.strict:
            raise RuntimeError('addText called when collection in strict mode')
        if not msg:
            return
        msg = msg.rstrip()
        if not msg:
            return
        self.lines.append(msg)

    def _extend(self, lines):
        """Append already formatted lines.

        Raises RuntimeError (and adds nothing) if any of them is already
        present in the collection.
        """
        lines = list(lines)
        dupes = set(lines).intersection(self.lines)
        if dupes:
            raise RuntimeError("The two sets contain some shared entries! " + str(dupes))
        self.lines.extend(lines)

    def __iadd__(self, other):
        """Perform += with a string, iterable, or ValidityCollection.

        - None: ignored.
        - str: rejected in strict mode; otherwise an empty string is ignored,
          text starting with '*' is added verbatim via addText, and anything
          else is formatted via addValidityEntry.
        - ValidityEntry: added via addValidityEntry with its anchor, if the
          entry is non-empty.
        - ValidityCollection: its lines are appended; both collections must
          have the same entity_name.
        - any other iterable: treated as already formatted lines.
        """
        if other is None:
            pass
        elif isinstance(other, str):
            if self.strict:
                raise RuntimeError(
                    'Collection += a string when collection in strict mode')
            if not other:
                # empty string
                pass
            elif other.startswith('*'):
                # Handle already-formatted
                self.addText(other)
            else:
                # Do the formatting ourselves.
                self.addValidityEntry(other)
        elif isinstance(other, ValidityEntry):
            if other:
                if other.verbose:
                    print(self.entity_name, 'Appending', str(other))
                self.addValidityEntry(str(other), anchor=other.anchor)
        elif isinstance(other, ValidityCollection):
            if self.entity_name != other.entity_name:
                raise RuntimeError(
                    "Trying to combine two ValidityCollections for different entities!")
            self._extend(other.lines)
        else:
            # Deal with other iterables.
            self._extend(other)

        return self

    def __bool__(self):
        """Is the collection non-empty?"""
        return bool(self.lines)

    @property
    def text(self):
        """Access validity statements as a single string or None.

        Lines are joined with newlines and a trailing newline is added.
        """
        if not self.lines:
            return None
        return '\n'.join(self.lines) + '\n'

    def __str__(self):
        """Access validity statements as a single string or empty string."""
        if not self:
            return ''
        return self.text

    def __repr__(self):
        return '<ValidityCollection: {}>'.format(self.lines)


class ValidityEntry:
    """A single validity line in progress.

    Text is accumulated in self.parts via append() or +=, and an optional
    anchor (a list of components) is kept in self.anchor for use when the
    entry is added to a ValidityCollection.
    """

    def __init__(self, text=None, anchor=None):
        """Prepare to add a validity entry, optionally with a VUID anchor.

        anchor -- None, a single string (treated as one component), or a
            sequence of strings.  The components are later joined with dashes
            at the end of the VUID anchor name.  In each component '->' and
            '.' are rewritten to '::', since VUIDs do not allow special
            characters other than ':'.
        text -- if non-empty, appended as the first part.

        Raises RuntimeError if an anchor component contains a space.
        """
        if isinstance(anchor, str):
            # anchor needs to be a tuple
            anchor = (anchor,)
        _checkAnchorComponents(anchor)

        # VUID does not allow special chars except ":"
        if anchor is not None:
            anchor = [anchor_value.replace('->', '::').replace('.', '::')
                      for anchor_value in anchor]

        self.anchor = anchor
        self.parts = []
        self.verbose = False
        if text:
            self.append(text)

    def append(self, part):
        """Append a part of a string.

        If this is the first part and it does not start with a markup macro,
        its first character is capitalized."""
        if not self.parts and not _STARTS_WITH_MACRO_RE.match(part):
            self.parts.append(part[:1].upper())
            self.parts.append(part[1:])
        else:
            self.parts.append(part)
        if self.verbose:
            print('ValidityEntry', id(self), 'after append:', str(self))

    def drop_end(self, n):
        """Remove up to n trailing characters from the (stripped) string.

        n <= 0 removes nothing; n larger than the text removes all of it.
        The parts are collapsed into a single part.  Raises RuntimeError if
        nothing has been appended yet.
        """
        temp = str(self)
        # Clamp n to [0, len(temp)]; slicing with temp[:-0] would otherwise
        # discard the whole string when n == 0.
        n = max(0, min(len(temp), n))
        self.parts = [temp[:len(temp) - n]]

    def __iadd__(self, other):
        """Perform += with a string (same as append)."""
        self.append(other)
        return self

    def __bool__(self):
        """Return true if we have something more than just an anchor."""
        return bool(self.parts)

    def __str__(self):
        """Access the validity statement as a single stripped string.

        Raises RuntimeError if no parts have been added.
        """
        if not self:
            raise RuntimeError("No parts added?")
        return ''.join(self.parts).strip()

    def __repr__(self):
        parts = ['<ValidityEntry: ']
        if self:
            parts.append('"')
            parts.append(str(self))
            parts.append('"')
        else:
            parts.append('EMPTY')
        if self.anchor:
            parts.append(', anchor={}'.format('-'.join(self.anchor)))
        parts.append('>')
        return ''.join(parts)