Skip to content

Commit 4eca19b

Browse files
Merge pull request #324 from gilles-peskine-arm/psa-test_psa_constant_names-refactor_and_ka
test_psa_constant_names: support key agreement, better code structure
2 parents 2e6cbcd + 8fa1348 commit 4eca19b

File tree

3 files changed

+187
-86
lines changed

3 files changed

+187
-86
lines changed

tests/scripts/test_psa_constant_names.py

Lines changed: 157 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
import argparse
11+
from collections import namedtuple
1112
import itertools
1213
import os
1314
import platform
@@ -60,12 +61,15 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
6061
from exc_value
6162

6263
class Inputs:
64+
# pylint: disable=too-many-instance-attributes
6365
"""Accumulate information about macros to test.
66+
6467
This includes macro names as well as information about their arguments
6568
when applicable.
6669
"""
6770

6871
def __init__(self):
72+
self.all_declared = set()
6973
# Sets of names per type
7074
self.statuses = set(['PSA_SUCCESS'])
7175
self.algorithms = set(['0xffffffff'])
@@ -86,11 +90,30 @@ def __init__(self):
8690
self.table_by_prefix = {
8791
'ERROR': self.statuses,
8892
'ALG': self.algorithms,
89-
'CURVE': self.ecc_curves,
90-
'GROUP': self.dh_groups,
93+
'ECC_CURVE': self.ecc_curves,
94+
'DH_GROUP': self.dh_groups,
9195
'KEY_TYPE': self.key_types,
9296
'KEY_USAGE': self.key_usage_flags,
9397
}
98+
# Test functions
99+
self.table_by_test_function = {
100+
# Any function ending in _algorithm also gets added to
101+
# self.algorithms.
102+
'key_type': [self.key_types],
103+
'ecc_key_types': [self.ecc_curves],
104+
'dh_key_types': [self.dh_groups],
105+
'hash_algorithm': [self.hash_algorithms],
106+
'mac_algorithm': [self.mac_algorithms],
107+
'cipher_algorithm': [],
108+
'hmac_algorithm': [self.mac_algorithms],
109+
'aead_algorithm': [self.aead_algorithms],
110+
'key_derivation_algorithm': [self.kdf_algorithms],
111+
'key_agreement_algorithm': [self.ka_algorithms],
112+
'asymmetric_signature_algorithm': [],
113+
'asymmetric_signature_wildcard': [self.algorithms],
114+
'asymmetric_encryption_algorithm': [],
115+
'other_algorithm': [],
116+
}
94117
# macro name -> list of argument names
95118
self.argspecs = {}
96119
# argument name -> list of values
@@ -99,8 +122,20 @@ def __init__(self):
99122
'tag_length': ['1', '63'],
100123
}
101124

125+
def get_names(self, type_word):
126+
"""Return the set of known names of values of the given type."""
127+
return {
128+
'status': self.statuses,
129+
'algorithm': self.algorithms,
130+
'ecc_curve': self.ecc_curves,
131+
'dh_group': self.dh_groups,
132+
'key_type': self.key_types,
133+
'key_usage': self.key_usage_flags,
134+
}[type_word]
135+
102136
def gather_arguments(self):
103137
"""Populate the list of values for macro arguments.
138+
104139
Call this after parsing all the inputs.
105140
"""
106141
self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
@@ -118,6 +153,7 @@ def _format_arguments(name, arguments):
118153

119154
def distribute_arguments(self, name):
120155
"""Generate macro calls with each tested argument set.
156+
121157
If name is a macro without arguments, just yield "name".
122158
If name is a macro with arguments, yield a series of
123159
"name(arg1,...,argN)" where each argument takes each possible
@@ -145,6 +181,9 @@ def distribute_arguments(self, name):
145181
except BaseException as e:
146182
raise Exception('distribute_arguments({})'.format(name)) from e
147183

184+
def generate_expressions(self, names):
185+
return itertools.chain(*map(self.distribute_arguments, names))
186+
148187
_argument_split_re = re.compile(r' *, *')
149188
@classmethod
150189
def _argument_split(cls, arguments):
@@ -154,7 +193,7 @@ def _argument_split(cls, arguments):
154193
# Groups: 1=macro name, 2=type, 3=argument list (optional).
155194
_header_line_re = \
156195
re.compile(r'#define +' +
157-
r'(PSA_((?:KEY_)?[A-Z]+)_\w+)' +
196+
r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
158197
r'(?:\(([^\n()]*)\))?')
159198
# Regex of macro names to exclude.
160199
_excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
@@ -167,10 +206,6 @@ def _argument_split(cls, arguments):
167206
# Auxiliary macro whose name doesn't fit the usual patterns for
168207
# auxiliary macros.
169208
'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH_CASE',
170-
# PSA_ALG_ECDH and PSA_ALG_FFDH are excluded for now as the script
171-
# currently doesn't support them.
172-
'PSA_ALG_ECDH',
173-
'PSA_ALG_FFDH',
174209
# Deprecated aliases.
175210
'PSA_ERROR_UNKNOWN_ERROR',
176211
'PSA_ERROR_OCCUPIED_SLOT',
@@ -184,6 +219,7 @@ def parse_header_line(self, line):
184219
if not m:
185220
return
186221
name = m.group(1)
222+
self.all_declared.add(name)
187223
if re.search(self._excluded_name_re, name) or \
188224
name in self._excluded_names:
189225
return
@@ -200,26 +236,34 @@ def parse_header(self, filename):
200236
for line in lines:
201237
self.parse_header_line(line)
202238

239+
_macro_identifier_re = r'[A-Z]\w+'
240+
def generate_undeclared_names(self, expr):
241+
for name in re.findall(self._macro_identifier_re, expr):
242+
if name not in self.all_declared:
243+
yield name
244+
245+
def accept_test_case_line(self, function, argument):
246+
#pylint: disable=unused-argument
247+
undeclared = list(self.generate_undeclared_names(argument))
248+
if undeclared:
249+
raise Exception('Undeclared names in test case', undeclared)
250+
return True
251+
203252
def add_test_case_line(self, function, argument):
204253
"""Parse a test case data line, looking for algorithm metadata tests."""
254+
sets = []
205255
if function.endswith('_algorithm'):
206-
# As above, ECDH and FFDH algorithms are excluded for now.
207-
# Support for them will be added in the future.
208-
if 'ECDH' in argument or 'FFDH' in argument:
209-
return
210-
self.algorithms.add(argument)
211-
if function == 'hash_algorithm':
212-
self.hash_algorithms.add(argument)
213-
elif function in ['mac_algorithm', 'hmac_algorithm']:
214-
self.mac_algorithms.add(argument)
215-
elif function == 'aead_algorithm':
216-
self.aead_algorithms.add(argument)
217-
elif function == 'key_type':
218-
self.key_types.add(argument)
219-
elif function == 'ecc_key_types':
220-
self.ecc_curves.add(argument)
221-
elif function == 'dh_key_types':
222-
self.dh_groups.add(argument)
256+
sets.append(self.algorithms)
257+
if function == 'key_agreement_algorithm' and \
258+
argument.startswith('PSA_ALG_KEY_AGREEMENT('):
259+
# We only want *raw* key agreement algorithms as such, so
260+
# exclude ones that are already chained with a KDF.
261+
# Keep the expression as one to test as an algorithm.
262+
function = 'other_algorithm'
263+
sets += self.table_by_test_function[function]
264+
if self.accept_test_case_line(function, argument):
265+
for s in sets:
266+
s.add(argument)
223267

224268
# Regex matching a *.data line containing a test function call and
225269
# its arguments. The actual definition is partly positional, but this
@@ -233,9 +277,9 @@ def parse_test_cases(self, filename):
233277
if m:
234278
self.add_test_case_line(m.group(1), m.group(2))
235279

236-
def gather_inputs(headers, test_suites):
280+
def gather_inputs(headers, test_suites, inputs_class=Inputs):
237281
"""Read the list of inputs to test psa_constant_names with."""
238-
inputs = Inputs()
282+
inputs = inputs_class()
239283
for header in headers:
240284
inputs.parse_header(header)
241285
for test_cases in test_suites:
@@ -252,8 +296,10 @@ def remove_file_if_exists(filename):
252296
except OSError:
253297
pass
254298

255-
def run_c(options, type_word, names):
256-
"""Generate and run a program to print out numerical values for names."""
299+
def run_c(type_word, expressions, include_path=None, keep_c=False):
300+
"""Generate and run a program to print out numerical values for expressions."""
301+
if include_path is None:
302+
include_path = []
257303
if type_word == 'status':
258304
cast_to = 'long'
259305
printf_format = '%ld'
@@ -278,18 +324,18 @@ def run_c(options, type_word, names):
278324
int main(void)
279325
{
280326
''')
281-
for name in names:
327+
for expr in expressions:
282328
c_file.write(' printf("{}\\n", ({}) {});\n'
283-
.format(printf_format, cast_to, name))
329+
.format(printf_format, cast_to, expr))
284330
c_file.write(''' return 0;
285331
}
286332
''')
287333
c_file.close()
288334
cc = os.getenv('CC', 'cc')
289335
subprocess.check_call([cc] +
290-
['-I' + dir for dir in options.include] +
336+
['-I' + dir for dir in include_path] +
291337
['-o', exe_name, c_name])
292-
if options.keep_c:
338+
if keep_c:
293339
sys.stderr.write('List of {} tests kept at {}\n'
294340
.format(type_word, c_name))
295341
else:
@@ -302,76 +348,101 @@ def run_c(options, type_word, names):
302348
NORMALIZE_STRIP_RE = re.compile(r'\s+')
303349
def normalize(expr):
304350
"""Normalize the C expression so as not to care about trivial differences.
351+
305352
Currently "trivial differences" means whitespace.
306353
"""
307-
expr = re.sub(NORMALIZE_STRIP_RE, '', expr, len(expr))
308-
return expr.strip().split('\n')
309-
310-
def do_test(options, inputs, type_word, names):
311-
"""Test psa_constant_names for the specified type.
312-
Run program on names.
313-
Use inputs to figure out what arguments to pass to macros that
314-
take arguments.
315-
"""
316-
names = sorted(itertools.chain(*map(inputs.distribute_arguments, names)))
317-
values = run_c(options, type_word, names)
318-
output = subprocess.check_output([options.program, type_word] + values)
319-
outputs = output.decode('ascii').strip().split('\n')
320-
errors = [(type_word, name, value, output)
321-
for (name, value, output) in zip(names, values, outputs)
322-
if normalize(name) != normalize(output)]
323-
return len(names), errors
324-
325-
def report_errors(errors):
326-
"""Describe each case where the output is not as expected."""
327-
for type_word, name, value, output in errors:
328-
print('For {} "{}", got "{}" (value: {})'
329-
.format(type_word, name, output, value))
330-
331-
def run_tests(options, inputs):
332-
"""Run psa_constant_names on all the gathered inputs.
333-
Return a tuple (count, errors) where count is the total number of inputs
334-
that were tested and errors is the list of cases where the output was
335-
not as expected.
354+
return re.sub(NORMALIZE_STRIP_RE, '', expr)
355+
356+
def collect_values(inputs, type_word, include_path=None, keep_c=False):
357+
"""Generate expressions using known macro names and calculate their values.
358+
359+
Return a list of pairs of (expr, value) where expr is an expression and
360+
value is a string representation of its integer value.
336361
"""
337-
count = 0
338-
errors = []
339-
for type_word, names in [('status', inputs.statuses),
340-
('algorithm', inputs.algorithms),
341-
('ecc_curve', inputs.ecc_curves),
342-
('dh_group', inputs.dh_groups),
343-
('key_type', inputs.key_types),
344-
('key_usage', inputs.key_usage_flags)]:
345-
c, e = do_test(options, inputs, type_word, names)
346-
count += c
347-
errors += e
348-
return count, errors
362+
names = inputs.get_names(type_word)
363+
expressions = sorted(inputs.generate_expressions(names))
364+
values = run_c(type_word, expressions,
365+
include_path=include_path, keep_c=keep_c)
366+
return expressions, values
367+
368+
class Tests:
369+
"""An object representing tests and their results."""
370+
371+
Error = namedtuple('Error',
372+
['type', 'expression', 'value', 'output'])
373+
374+
def __init__(self, options):
375+
self.options = options
376+
self.count = 0
377+
self.errors = []
378+
379+
def run_one(self, inputs, type_word):
380+
"""Test psa_constant_names for the specified type.
381+
382+
Run the program on the names for this type.
383+
Use the inputs to figure out what arguments to pass to macros that
384+
take arguments.
385+
"""
386+
expressions, values = collect_values(inputs, type_word,
387+
include_path=self.options.include,
388+
keep_c=self.options.keep_c)
389+
output = subprocess.check_output([self.options.program, type_word] +
390+
values)
391+
outputs = output.decode('ascii').strip().split('\n')
392+
self.count += len(expressions)
393+
for expr, value, output in zip(expressions, values, outputs):
394+
if normalize(expr) != normalize(output):
395+
self.errors.append(self.Error(type=type_word,
396+
expression=expr,
397+
value=value,
398+
output=output))
399+
400+
def run_all(self, inputs):
401+
"""Run psa_constant_names on all the gathered inputs."""
402+
for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
403+
'key_type', 'key_usage']:
404+
self.run_one(inputs, type_word)
405+
406+
def report(self, out):
407+
"""Describe each case where the output is not as expected.
408+
409+
Write the errors to ``out``.
410+
Also write a total.
411+
"""
412+
for error in self.errors:
413+
out.write('For {} "{}", got "{}" (value: {})\n'
414+
.format(error.type, error.expression,
415+
error.output, error.value))
416+
out.write('{} test cases'.format(self.count))
417+
if self.errors:
418+
out.write(', {} FAIL\n'.format(len(self.errors)))
419+
else:
420+
out.write(' PASS\n')
421+
422+
HEADERS = ['psa/crypto.h', 'psa/crypto_extra.h', 'psa/crypto_values.h']
423+
TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data']
349424

350425
def main():
351426
parser = argparse.ArgumentParser(description=globals()['__doc__'])
352427
parser.add_argument('--include', '-I',
353428
action='append', default=['include'],
354429
help='Directory for header files')
355-
parser.add_argument('--program',
356-
default='programs/psa/psa_constant_names',
357-
help='Program to test')
358430
parser.add_argument('--keep-c',
359431
action='store_true', dest='keep_c', default=False,
360432
help='Keep the intermediate C file')
361433
parser.add_argument('--no-keep-c',
362434
action='store_false', dest='keep_c',
363435
help='Don\'t keep the intermediate C file (default)')
436+
parser.add_argument('--program',
437+
default='programs/psa/psa_constant_names',
438+
help='Program to test')
364439
options = parser.parse_args()
365-
headers = [os.path.join(options.include[0], 'psa', h)
366-
for h in ['crypto.h', 'crypto_extra.h', 'crypto_values.h']]
367-
test_suites = ['tests/suites/test_suite_psa_crypto_metadata.data']
368-
inputs = gather_inputs(headers, test_suites)
369-
count, errors = run_tests(options, inputs)
370-
report_errors(errors)
371-
if errors == []:
372-
print('{} test cases PASS'.format(count))
373-
else:
374-
print('{} test cases, {} FAIL'.format(count, len(errors)))
440+
headers = [os.path.join(options.include[0], h) for h in HEADERS]
441+
inputs = gather_inputs(headers, TEST_SUITES)
442+
tests = Tests(options)
443+
tests.run_all(inputs)
444+
tests.report(sys.stdout)
445+
if tests.errors:
375446
exit(1)
376447

377448
if __name__ == '__main__':

0 commit comments

Comments
 (0)