Skip to content

Commit 6e75eec

Browse files
authored
[mlir][spirv] Remove code for de-duplicating symbols in SPIR-V grammar (#111778)
SPIR-V grammar was updated in upstream to have an "aliases" field instead of duplicating symbols with same values. See KhronosGroup/SPIRV-Headers#447 for details.
1 parent 67c4857 commit 6e75eec

File tree

1 file changed

+10
-91
lines changed

1 file changed

+10
-91
lines changed

mlir/utils/spirv/gen_spirv_dialect.py

Lines changed: 10 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -127,44 +127,6 @@ def split_list_into_sublists(items):
127127
return chuncks
128128

129129

130-
def uniquify_enum_cases(lst):
131-
"""Prunes duplicate enum cases from the list.
132-
133-
Arguments:
134-
- lst: List whose elements are to be uniqued. Assumes each element is a
135-
(symbol, value) pair and elements already sorted according to value.
136-
137-
Returns:
138-
- A list with all duplicates removed. The elements are sorted according to
139-
value and, for each value, uniqued according to symbol.
140-
original list,
141-
- A map from deduplicated cases to the uniqued case.
142-
"""
143-
cases = lst
144-
uniqued_cases = []
145-
duplicated_cases = {}
146-
147-
# First sort according to the value
148-
cases.sort(key=lambda x: x[1])
149-
150-
# Then group them according to the value
151-
for _, groups in itertools.groupby(cases, key=lambda x: x[1]):
152-
# For each value, sort according to the enumerant symbol.
153-
sorted_group = sorted(groups, key=lambda x: x[0])
154-
# Keep the "smallest" case, which is typically the symbol without extension
155-
# suffix. But we have special cases that we want to fix.
156-
case = sorted_group[0]
157-
for i in range(1, len(sorted_group)):
158-
duplicated_cases[sorted_group[i][0]] = case[0]
159-
if case[0] == "HlslSemanticGOOGLE":
160-
assert len(sorted_group) == 2, "unexpected new variant for HlslSemantic"
161-
case = sorted_group[1]
162-
duplicated_cases[sorted_group[0][0]] = case[0]
163-
uniqued_cases.append(case)
164-
165-
return uniqued_cases, duplicated_cases
166-
167-
168130
def toposort(dag, sort_fn):
169131
"""Topologically sorts the given dag.
170132
@@ -197,14 +159,12 @@ def get_next_batch(dag):
197159
return sorted_nodes
198160

199161

200-
def toposort_capabilities(all_cases, capability_mapping):
162+
def toposort_capabilities(all_cases):
201163
"""Returns topologically sorted capability (symbol, value) pairs.
202164
203165
Arguments:
204166
- all_cases: all capability cases (containing symbol, value, and implied
205167
capabilities).
206-
- capability_mapping: mapping from duplicated capability symbols to the
207-
canonicalized symbol chosen for SPIRVBase.td.
208168
209169
Returns:
210170
A list containing topologically sorted capability (symbol, value) pairs.
@@ -215,50 +175,23 @@ def toposort_capabilities(all_cases, capability_mapping):
215175
# Get the current capability.
216176
cur = case["enumerant"]
217177
name_to_value[cur] = case["value"]
218-
# Ignore duplicated symbols.
219-
if cur in capability_mapping:
220-
continue
221178

222179
# Get capabilities implied by the current capability.
223180
prev = case.get("capabilities", [])
224-
uniqued_prev = set([capability_mapping.get(c, c) for c in prev])
181+
uniqued_prev = set(prev)
225182
dag[cur] = uniqued_prev
226183

227184
sorted_caps = toposort(dag, lambda x: name_to_value[x])
228185
# Attach the capability's value as the second component of the pair.
229186
return [(c, name_to_value[c]) for c in sorted_caps]
230187

231188

232-
def get_capability_mapping(operand_kinds):
233-
"""Returns the capability mapping from duplicated cases to canonicalized ones.
234-
235-
Arguments:
236-
- operand_kinds: all operand kinds' grammar spec
237-
238-
Returns:
239-
- A map mapping from duplicated capability symbols to the canonicalized
240-
symbol chosen for SPIRVBase.td.
241-
"""
242-
# Find the operand kind for capability
243-
cap_kind = {}
244-
for kind in operand_kinds:
245-
if kind["kind"] == "Capability":
246-
cap_kind = kind
247-
248-
kind_cases = [(case["enumerant"], case["value"]) for case in cap_kind["enumerants"]]
249-
_, capability_mapping = uniquify_enum_cases(kind_cases)
250-
251-
return capability_mapping
252-
253-
254-
def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
189+
def get_availability_spec(enum_case, for_op, for_cap):
255190
"""Returns the availability specification string for the given enum case.
256191
257192
Arguments:
258193
- enum_case: the enum case to generate availability spec for. It may contain
259194
'version', 'lastVersion', 'extensions', or 'capabilities'.
260-
- capability_mapping: mapping from duplicated capability symbols to the
261-
canonicalized symbol chosen for SPIRVBase.td.
262195
- for_op: bool value indicating whether this is the availability spec for an
263196
op itself.
264197
- for_cap: bool value indicating whether this is the availability spec for
@@ -313,10 +246,7 @@ def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
313246
if caps:
314247
canonicalized_caps = []
315248
for c in caps:
316-
if c in capability_mapping:
317-
canonicalized_caps.append(capability_mapping[c])
318-
else:
319-
canonicalized_caps.append(c)
249+
canonicalized_caps.append(c)
320250
prefixed_caps = [
321251
"SPIRV_C_{}".format(c) for c in sorted(set(canonicalized_caps))
322252
]
@@ -357,7 +287,7 @@ def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
357287
return "{}{}{}".format(implies, "\n " if implies and avail else "", avail)
358288

359289

360-
def gen_operand_kind_enum_attr(operand_kind, capability_mapping):
290+
def gen_operand_kind_enum_attr(operand_kind):
361291
"""Generates the TableGen EnumAttr definition for the given operand kind.
362292
363293
Returns:
@@ -388,13 +318,12 @@ def get_case_symbol(kind_name, case_name):
388318
# Special treatment for capability cases: we need to sort them topologically
389319
# because a capability can refer to another via the 'implies' field.
390320
kind_cases = toposort_capabilities(
391-
operand_kind["enumerants"], capability_mapping
321+
operand_kind["enumerants"]
392322
)
393323
else:
394324
kind_cases = [
395325
(case["enumerant"], case["value"]) for case in operand_kind["enumerants"]
396326
]
397-
kind_cases, _ = uniquify_enum_cases(kind_cases)
398327
max_len = max([len(symbol) for (symbol, _) in kind_cases])
399328

400329
# Generate the definition for each enum case
@@ -412,7 +341,6 @@ def get_case_symbol(kind_name, case_name):
412341
value = int(case_pair[1])
413342
avail = get_availability_spec(
414343
name_to_case_dict[name],
415-
capability_mapping,
416344
False,
417345
kind_name == "Capability",
418346
)
@@ -648,11 +576,9 @@ def update_td_enum_attrs(path, operand_kinds, filter_list):
648576
]
649577
filter_list.extend(existing_kinds)
650578

651-
capability_mapping = get_capability_mapping(operand_kinds)
652-
653579
# Generate definitions for all enums in filter list
654580
defs = [
655-
gen_operand_kind_enum_attr(kind, capability_mapping)
581+
gen_operand_kind_enum_attr(kind)
656582
for kind in operand_kinds
657583
if kind["kind"] in filter_list
658584
]
@@ -762,7 +688,7 @@ def get_description(text, appendix):
762688

763689

764690
def get_op_definition(
765-
instruction, opname, doc, existing_info, capability_mapping, settings
691+
instruction, opname, doc, existing_info, settings
766692
):
767693
"""Generates the TableGen op definition for the given SPIR-V instruction.
768694
@@ -771,8 +697,6 @@ def get_op_definition(
771697
- doc: the instruction's SPIR-V HTML doc
772698
- existing_info: a dict containing potential manually specified sections for
773699
this instruction
774-
- capability_mapping: mapping from duplicated capability symbols to the
775-
canonicalized symbol chosen for SPIRVBase.td
776700
777701
Returns:
778702
- A string containing the TableGen op definition
@@ -840,7 +764,7 @@ def get_op_definition(
840764
operands = instruction.get("operands", [])
841765

842766
# Op availability
843-
avail = get_availability_spec(instruction, capability_mapping, True, False)
767+
avail = get_availability_spec(instruction, True, False)
844768
if avail:
845769
avail = "\n\n {0}".format(avail)
846770

@@ -1021,7 +945,7 @@ def extract_td_op_info(op_def):
1021945

1022946

1023947
def update_td_op_definitions(
1024-
path, instructions, docs, filter_list, inst_category, capability_mapping, settings
948+
path, instructions, docs, filter_list, inst_category, settings
1025949
):
1026950
"""Updates SPIRVOps.td with newly generated op definition.
1027951
@@ -1030,8 +954,6 @@ def update_td_op_definitions(
1030954
- instructions: SPIR-V JSON grammar for all instructions
1031955
- docs: SPIR-V HTML doc for all instructions
1032956
- filter_list: a list containing new opnames to include
1033-
- capability_mapping: mapping from duplicated capability symbols to the
1034-
canonicalized symbol chosen for SPIRVBase.td.
1035957
1036958
Returns:
1037959
- A string containing all the TableGen op definitions
@@ -1079,7 +1001,6 @@ def update_td_op_definitions(
10791001
opname,
10801002
docs[fixed_opname],
10811003
op_info_dict.get(opname, {"inst_category": inst_category}),
1082-
capability_mapping,
10831004
settings,
10841005
)
10851006
)
@@ -1186,14 +1107,12 @@ def update_td_op_definitions(
11861107
if args.new_inst is not None:
11871108
assert args.op_td_path is not None
11881109
docs = get_spirv_doc_from_html_spec(ext_html_url, args)
1189-
capability_mapping = get_capability_mapping(operand_kinds)
11901110
update_td_op_definitions(
11911111
args.op_td_path,
11921112
instructions,
11931113
docs,
11941114
args.new_inst,
11951115
args.inst_category,
1196-
capability_mapping,
11971116
args,
11981117
)
11991118
print("Done. Note that this script just generates a template; ", end="")

0 commit comments

Comments
 (0)