summaryrefslogtreecommitdiff
path: root/codegen/vulkan/scripts/make_ext_dependency.py
blob: 51571286f398652dc9b90f053a6cc50def4edcbd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
#!/usr/bin/env python3
#
# Copyright 2017-2021 The Khronos Group Inc.
#
# SPDX-License-Identifier: Apache-2.0
"""Generate a mapping of extension name -> all required extension names for that extension.

This script generates, for each extension, the list of all extensions it
requires (directly or transitively), along with lists of all extensions and
of just KHR extensions. These are written to a Bash script and/or a Python
script, which can then be sourced or executed to set variables (e.g.
khrExts). Frontend scripts such as 'makeAllExts' and 'makeKHR' use this
information to set the EXTENSIONS Makefile variable when building the spec.

Sample Usage:

python3 scripts/make_ext_dependency.py -outscript=temp.sh
source temp.sh
make EXTENSIONS="$khrExts" html
rm temp.sh
"""

import argparse
import errno
import xml.etree.ElementTree as etree
from pathlib import Path

from vkconventions import VulkanConventions as APIConventions


def enQuote(key):
    """Return *key*, converted with str(), wrapped in single quotes."""
    return "'{!s}'".format(key)


def shList(names):
    """Return *names* as a double-quoted, space-separated string.

    The names are sorted first so the output stays stable, and each one is
    converted with str(). The result is suitable as the right-hand side of a
    Bash assignment, e.g. ``allExts="VK_A VK_B"``.

    :param names: any sortable iterable (list or set) of names.
    """
    ordered = sorted(names)
    return '"{}"'.format(' '.join(map(str, ordered)))


def pyList(names):
    """Return *names* as the source text of a Python list of quoted strings.

    Elements are sorted and single-quoted with enQuote(), then joined with
    ', ' inside '[ ' and ' ]'. A *names* of None is rendered as '[ ]'.

    :param names: sortable iterable (list or set) of names, or None.
    """
    if names is None:
        return '[ ]'
    quoted = [enQuote(key) for key in sorted(names)]
    return '[ ' + ', '.join(quoted) + ' ]'

class DiGraph:
    """A minimal directed graph whose nodes are arbitrary hashable values.

    The API is modelled on networkx.DiGraph from networkx-1.11, which stores
    graphs as nested dicts rather than lists. Differences from networkx:

        * No node or edge attribute data is stored; this script never needs
          it and leaving it out keeps the code small.

        * Methods return iterators/views instead of freshly built
          collections where possible, which is simpler and avoids copies.
    """

    def __init__(self):
        # Maps each node to the DiGraphNode holding its outgoing edges.
        self.__nodes = {}

    def add_node(self, node):
        """Add *node* to the graph; adding an existing node is a no-op."""
        if node in self.__nodes:
            return
        self.__nodes[node] = DiGraphNode()

    def add_edge(self, src, dest):
        """Add the directed edge src -> dest, creating either node if absent."""
        for endpoint in (src, dest):
            self.add_node(endpoint)
        self.__nodes[src].adj.add(dest)

    def nodes(self):
        """Return a view over every node currently in the graph."""
        return self.__nodes.keys()

    def descendants(self, node):
        """
        Yield every node reachable from *node* by following edges.

        The start node itself is never yielded, even if a cycle leads back to
        it, and every reachable node is yielded exactly once. The traversal
        uses an explicit LIFO stack, so the visiting order is depth-first-like
        rather than breadth-first; callers should not rely on the order.
        Because this is a generator, a KeyError for an unknown *node* is
        raised when iteration begins, not when this method is called.
        """
        graph = self.__nodes
        # A node is marked as seen when it is pushed, so it is never queued twice.
        seen = {node}
        stack = []

        def push_unseen_children(parent):
            for child in graph[parent].adj:
                if child not in seen:
                    seen.add(child)
                    stack.append(child)

        push_unseen_children(node)
        while stack:
            current = stack.pop()
            yield current
            push_unseen_children(current)


class DiGraphNode:
    """Per-node storage for DiGraph: the set of direct successors."""

    def __init__(self):
        # Destinations of this node's outgoing edges.
        self.adj = set()


def make_dir(fn):
    """Ensure that the parent directory of the file path *fn* exists.

    Missing intermediate directories are created. A parent that already
    exists as a directory is not an error. A parent path occupied by a
    non-directory raises FileExistsError here. The previous errno-based check
    silently ignored that case, and the following open() then failed with a
    less helpful error.

    :param fn: path (str or os.PathLike) of a file that is about to be written.
    :raises OSError: if the directory cannot be created.
    """
    Path(fn).parent.mkdir(parents=True, exist_ok=True)


# API conventions object
conventions = APIConventions()


def _build_dependency_graph(tree):
    """Collect the extensions supported by this API from a parsed registry.

    Walks every 'extensions/extension' element of *tree*. Extensions are kept
    only if their comma-separated 'supported' attribute names
    conventions.xml_api_name. Extensions without a 'supported' attribute are
    skipped instead of crashing.

    Returns a tuple (graph, allExts, khrExts):
        graph   -- DiGraph with a node per supported extension and an edge
                   to every name in its comma-separated 'requires' attribute
        allExts -- set of supported extension names
        khrExts -- subset of allExts whose names contain 'KHR'
    """
    allExts = set()
    khrExts = set()
    graph = DiGraph()

    for elem in tree.findall('extensions/extension'):
        name = elem.get('name')

        # This works for the present form of the 'supported' attribute,
        # which is a comma-separated list of XML API names.
        supported = (elem.get('supported') or '').split(',')
        if conventions.xml_api_name not in supported:
            # Skip extensions not supported by this API
            continue

        allExts.add(name)
        if 'KHR' in name:
            khrExts.add(name)

        # Add the node explicitly so extensions with no dependencies
        # still appear in the graph.
        graph.add_node(name)
        deps = elem.get('requires')
        if deps:
            for dep in deps.split(','):
                graph.add_edge(name, dep)

    return graph, allExts, khrExts


def _write_shell_script(path, graph, allExts, khrExts):
    """Write a Bash script describing the extension dependencies.

    The script declares an associative array 'extensions'. Each extension
    with at least one dependency maps to a double-quoted, space-separated
    list of everything it transitively requires. The script also defines
    'allExts' and 'khrExts'. Missing parent directories of *path* are
    created, and the file is closed even if writing fails.
    """
    make_dir(path)
    with open(path, 'w', encoding='utf-8') as fp:
        print('#!/bin/bash', file=fp)
        print('# Generated from make_ext_dependency.py', file=fp)
        print('# Specify maps of all extensions required by an enabled extension', file=fp)
        print('', file=fp)
        print('declare -A extensions', file=fp)

        # When printing lists of extensions, sort them so that the output script
        # remains as stable as possible as extensions are added to the API XML.
        for ext in sorted(graph.nodes()):
            children = list(graph.descendants(ext))

            # Only emit an array entry if an extension has dependencies
            if children:
                print('extensions[' + ext + ']=' + shList(children), file=fp)

        print('', file=fp)
        print('# Define lists of all extensions and KHR extensions', file=fp)
        print('allExts=' + shList(allExts), file=fp)
        print('khrExts=' + shList(khrExts), file=fp)


def _write_python_script(path, graph, allExts, khrExts):
    """Write a Python script describing the extension dependencies.

    The script assigns extensions['NAME'] for every node in *graph*, including
    required extensions that are not themselves supported. The value is a
    list of all extensions NAME transitively requires, possibly empty. The
    script also defines 'allExts' and 'khrExts' lists. Missing parent
    directories of *path* are created, and the file is closed even if
    writing fails.
    """
    make_dir(path)
    with open(path, 'w', encoding='utf-8') as fp:
        print('#!/usr/bin/env python', file=fp)
        print('# Generated from make_ext_dependency.py', file=fp)
        print('# Specify maps of all extensions required by an enabled extension', file=fp)
        print('', file=fp)
        print('extensions = {}', file=fp)

        # When printing lists of extensions, sort them so that the output script
        # remains as stable as possible as extensions are added to the API XML.
        for ext in sorted(graph.nodes()):
            children = list(graph.descendants(ext))
            print("extensions['" + ext + "'] = " + pyList(children), file=fp)

        print('', file=fp)
        print('# Define lists of all extensions and KHR extensions', file=fp)
        print('allExts = ' + pyList(allExts), file=fp)
        print('khrExts = ' + pyList(khrExts), file=fp)


# Command-line entry point: parse the registry, then write whichever of the
# Bash (-outscript) and Python (-outpy) outputs were requested.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-registry', action='store',
                        default=conventions.registry_path,
                        help='Use specified registry file instead of ' + conventions.registry_path)
    parser.add_argument('-outscript', action='store',
                        default=None,
                        help='Shell script to create')
    parser.add_argument('-outpy', action='store',
                        default=None,
                        help='Python script to create')
    # NOTE(review): -test and -quiet are accepted but not currently used
    # anywhere in this script.
    parser.add_argument('-test', action='store',
                        default=None,
                        help='Specify extension to find dependencies of')
    parser.add_argument('-quiet', action='store_true', default=False,
                        help='Suppress script output during normal execution.')

    args = parser.parse_args()

    tree = etree.parse(args.registry)

    # Build a digraph of the extension dependencies from the 'requires'
    # attribute, and track the sets of all extensions and all KHR extensions.
    g, allExts, khrExts = _build_dependency_graph(tree)

    if args.outscript:
        _write_shell_script(args.outscript, g, allExts, khrExts)

    if args.outpy:
        _write_python_script(args.outpy, g, allExts, khrExts)