Skip to content

test_psa_constant_names: support key agreement, better code structure #324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 157 additions & 86 deletions tests/scripts/test_psa_constant_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import argparse
from collections import namedtuple
import itertools
import os
import platform
Expand Down Expand Up @@ -60,12 +61,15 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
from exc_value

class Inputs:
# pylint: disable=too-many-instance-attributes
"""Accumulate information about macros to test.

This includes macro names as well as information about their arguments
when applicable.
"""

def __init__(self):
self.all_declared = set()
# Sets of names per type
self.statuses = set(['PSA_SUCCESS'])
self.algorithms = set(['0xffffffff'])
Expand All @@ -86,11 +90,30 @@ def __init__(self):
self.table_by_prefix = {
'ERROR': self.statuses,
'ALG': self.algorithms,
'CURVE': self.ecc_curves,
'GROUP': self.dh_groups,
'ECC_CURVE': self.ecc_curves,
'DH_GROUP': self.dh_groups,
'KEY_TYPE': self.key_types,
'KEY_USAGE': self.key_usage_flags,
}
# Test functions
self.table_by_test_function = {
# Any function ending in _algorithm also gets added to
# self.algorithms.
'key_type': [self.key_types],
'ecc_key_types': [self.ecc_curves],
'dh_key_types': [self.dh_groups],
'hash_algorithm': [self.hash_algorithms],
'mac_algorithm': [self.mac_algorithms],
'cipher_algorithm': [],
'hmac_algorithm': [self.mac_algorithms],
'aead_algorithm': [self.aead_algorithms],
'key_derivation_algorithm': [self.kdf_algorithms],
'key_agreement_algorithm': [self.ka_algorithms],
'asymmetric_signature_algorithm': [],
'asymmetric_signature_wildcard': [self.algorithms],
'asymmetric_encryption_algorithm': [],
'other_algorithm': [],
}
# macro name -> list of argument names
self.argspecs = {}
# argument name -> list of values
Expand All @@ -99,8 +122,20 @@ def __init__(self):
'tag_length': ['1', '63'],
}

def get_names(self, type_word):
"""Return the set of known names of values of the given type."""
return {
'status': self.statuses,
'algorithm': self.algorithms,
'ecc_curve': self.ecc_curves,
'dh_group': self.dh_groups,
'key_type': self.key_types,
'key_usage': self.key_usage_flags,
}[type_word]

def gather_arguments(self):
"""Populate the list of values for macro arguments.

Call this after parsing all the inputs.
"""
self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
Expand All @@ -118,6 +153,7 @@ def _format_arguments(name, arguments):

def distribute_arguments(self, name):
"""Generate macro calls with each tested argument set.

If name is a macro without arguments, just yield "name".
If name is a macro with arguments, yield a series of
"name(arg1,...,argN)" where each argument takes each possible
Expand Down Expand Up @@ -145,6 +181,9 @@ def distribute_arguments(self, name):
except BaseException as e:
raise Exception('distribute_arguments({})'.format(name)) from e

def generate_expressions(self, names):
return itertools.chain(*map(self.distribute_arguments, names))

_argument_split_re = re.compile(r' *, *')
@classmethod
def _argument_split(cls, arguments):
Expand All @@ -154,7 +193,7 @@ def _argument_split(cls, arguments):
# Groups: 1=macro name, 2=type, 3=argument list (optional).
_header_line_re = \
re.compile(r'#define +' +
r'(PSA_((?:KEY_)?[A-Z]+)_\w+)' +
r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
r'(?:\(([^\n()]*)\))?')
# Regex of macro names to exclude.
_excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
Expand All @@ -167,10 +206,6 @@ def _argument_split(cls, arguments):
# Auxiliary macro whose name doesn't fit the usual patterns for
# auxiliary macros.
'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH_CASE',
# PSA_ALG_ECDH and PSA_ALG_FFDH are excluded for now as the script
# currently doesn't support them.
'PSA_ALG_ECDH',
'PSA_ALG_FFDH',
# Deprecated aliases.
'PSA_ERROR_UNKNOWN_ERROR',
'PSA_ERROR_OCCUPIED_SLOT',
Expand All @@ -184,6 +219,7 @@ def parse_header_line(self, line):
if not m:
return
name = m.group(1)
self.all_declared.add(name)
if re.search(self._excluded_name_re, name) or \
name in self._excluded_names:
return
Expand All @@ -200,26 +236,34 @@ def parse_header(self, filename):
for line in lines:
self.parse_header_line(line)

_macro_identifier_re = r'[A-Z]\w+'
def generate_undeclared_names(self, expr):
for name in re.findall(self._macro_identifier_re, expr):
if name not in self.all_declared:
yield name

def accept_test_case_line(self, function, argument):
#pylint: disable=unused-argument
undeclared = list(self.generate_undeclared_names(argument))
if undeclared:
raise Exception('Undeclared names in test case', undeclared)
return True

def add_test_case_line(self, function, argument):
"""Parse a test case data line, looking for algorithm metadata tests."""
sets = []
if function.endswith('_algorithm'):
# As above, ECDH and FFDH algorithms are excluded for now.
# Support for them will be added in the future.
if 'ECDH' in argument or 'FFDH' in argument:
return
self.algorithms.add(argument)
if function == 'hash_algorithm':
self.hash_algorithms.add(argument)
elif function in ['mac_algorithm', 'hmac_algorithm']:
self.mac_algorithms.add(argument)
elif function == 'aead_algorithm':
self.aead_algorithms.add(argument)
elif function == 'key_type':
self.key_types.add(argument)
elif function == 'ecc_key_types':
self.ecc_curves.add(argument)
elif function == 'dh_key_types':
self.dh_groups.add(argument)
sets.append(self.algorithms)
if function == 'key_agreement_algorithm' and \
argument.startswith('PSA_ALG_KEY_AGREEMENT('):
# We only want *raw* key agreement algorithms as such, so
# exclude ones that are already chained with a KDF.
# Keep the expression as one to test as an algorithm.
function = 'other_algorithm'
sets += self.table_by_test_function[function]
if self.accept_test_case_line(function, argument):
for s in sets:
s.add(argument)

# Regex matching a *.data line containing a test function call and
# its arguments. The actual definition is partly positional, but this
Expand All @@ -233,9 +277,9 @@ def parse_test_cases(self, filename):
if m:
self.add_test_case_line(m.group(1), m.group(2))

def gather_inputs(headers, test_suites):
def gather_inputs(headers, test_suites, inputs_class=Inputs):
"""Read the list of inputs to test psa_constant_names with."""
inputs = Inputs()
inputs = inputs_class()
for header in headers:
inputs.parse_header(header)
for test_cases in test_suites:
Expand All @@ -252,8 +296,10 @@ def remove_file_if_exists(filename):
except OSError:
pass

def run_c(options, type_word, names):
"""Generate and run a program to print out numerical values for names."""
def run_c(type_word, expressions, include_path=None, keep_c=False):
"""Generate and run a program to print out numerical values for expressions."""
if include_path is None:
include_path = []
if type_word == 'status':
cast_to = 'long'
printf_format = '%ld'
Expand All @@ -278,18 +324,18 @@ def run_c(options, type_word, names):
int main(void)
{
''')
for name in names:
for expr in expressions:
c_file.write(' printf("{}\\n", ({}) {});\n'
.format(printf_format, cast_to, name))
.format(printf_format, cast_to, expr))
c_file.write(''' return 0;
}
''')
c_file.close()
cc = os.getenv('CC', 'cc')
subprocess.check_call([cc] +
['-I' + dir for dir in options.include] +
['-I' + dir for dir in include_path] +
['-o', exe_name, c_name])
if options.keep_c:
if keep_c:
sys.stderr.write('List of {} tests kept at {}\n'
.format(type_word, c_name))
else:
Expand All @@ -302,76 +348,101 @@ def run_c(options, type_word, names):
NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr):
"""Normalize the C expression so as not to care about trivial differences.

Currently "trivial differences" means whitespace.
"""
expr = re.sub(NORMALIZE_STRIP_RE, '', expr, len(expr))
return expr.strip().split('\n')

def do_test(options, inputs, type_word, names):
"""Test psa_constant_names for the specified type.
Run program on names.
Use inputs to figure out what arguments to pass to macros that
take arguments.
"""
names = sorted(itertools.chain(*map(inputs.distribute_arguments, names)))
values = run_c(options, type_word, names)
output = subprocess.check_output([options.program, type_word] + values)
outputs = output.decode('ascii').strip().split('\n')
errors = [(type_word, name, value, output)
for (name, value, output) in zip(names, values, outputs)
if normalize(name) != normalize(output)]
return len(names), errors

def report_errors(errors):
"""Describe each case where the output is not as expected."""
for type_word, name, value, output in errors:
print('For {} "{}", got "{}" (value: {})'
.format(type_word, name, output, value))

def run_tests(options, inputs):
"""Run psa_constant_names on all the gathered inputs.
Return a tuple (count, errors) where count is the total number of inputs
that were tested and errors is the list of cases where the output was
not as expected.
return re.sub(NORMALIZE_STRIP_RE, '', expr)

def collect_values(inputs, type_word, include_path=None, keep_c=False):
"""Generate expressions using known macro names and calculate their values.

Return a list of pairs of (expr, value) where expr is an expression and
value is a string representation of its integer value.
"""
count = 0
errors = []
for type_word, names in [('status', inputs.statuses),
('algorithm', inputs.algorithms),
('ecc_curve', inputs.ecc_curves),
('dh_group', inputs.dh_groups),
('key_type', inputs.key_types),
('key_usage', inputs.key_usage_flags)]:
c, e = do_test(options, inputs, type_word, names)
count += c
errors += e
return count, errors
names = inputs.get_names(type_word)
expressions = sorted(inputs.generate_expressions(names))
values = run_c(type_word, expressions,
include_path=include_path, keep_c=keep_c)
return expressions, values

class Tests:
"""An object representing tests and their results."""

Error = namedtuple('Error',
['type', 'expression', 'value', 'output'])

def __init__(self, options):
self.options = options
self.count = 0
self.errors = []

def run_one(self, inputs, type_word):
"""Test psa_constant_names for the specified type.

Run the program on the names for this type.
Use the inputs to figure out what arguments to pass to macros that
take arguments.
"""
expressions, values = collect_values(inputs, type_word,
include_path=self.options.include,
keep_c=self.options.keep_c)
output = subprocess.check_output([self.options.program, type_word] +
values)
outputs = output.decode('ascii').strip().split('\n')
self.count += len(expressions)
for expr, value, output in zip(expressions, values, outputs):
if normalize(expr) != normalize(output):
self.errors.append(self.Error(type=type_word,
expression=expr,
value=value,
output=output))

def run_all(self, inputs):
"""Run psa_constant_names on all the gathered inputs."""
for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
'key_type', 'key_usage']:
self.run_one(inputs, type_word)

def report(self, out):
"""Describe each case where the output is not as expected.

Write the errors to ``out``.
Also write a total.
"""
for error in self.errors:
out.write('For {} "{}", got "{}" (value: {})\n'
.format(error.type, error.expression,
error.output, error.value))
out.write('{} test cases'.format(self.count))
if self.errors:
out.write(', {} FAIL\n'.format(len(self.errors)))
else:
out.write(' PASS\n')

HEADERS = ['psa/crypto.h', 'psa/crypto_extra.h', 'psa/crypto_values.h']
TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data']

def main():
parser = argparse.ArgumentParser(description=globals()['__doc__'])
parser.add_argument('--include', '-I',
action='append', default=['include'],
help='Directory for header files')
parser.add_argument('--program',
default='programs/psa/psa_constant_names',
help='Program to test')
parser.add_argument('--keep-c',
action='store_true', dest='keep_c', default=False,
help='Keep the intermediate C file')
parser.add_argument('--no-keep-c',
action='store_false', dest='keep_c',
help='Don\'t keep the intermediate C file (default)')
parser.add_argument('--program',
default='programs/psa/psa_constant_names',
help='Program to test')
options = parser.parse_args()
headers = [os.path.join(options.include[0], 'psa', h)
for h in ['crypto.h', 'crypto_extra.h', 'crypto_values.h']]
test_suites = ['tests/suites/test_suite_psa_crypto_metadata.data']
inputs = gather_inputs(headers, test_suites)
count, errors = run_tests(options, inputs)
report_errors(errors)
if errors == []:
print('{} test cases PASS'.format(count))
else:
print('{} test cases, {} FAIL'.format(count, len(errors)))
headers = [os.path.join(options.include[0], h) for h in HEADERS]
inputs = gather_inputs(headers, TEST_SUITES)
tests = Tests(options)
tests.run_all(inputs)
tests.report(sys.stdout)
if tests.errors:
exit(1)

if __name__ == '__main__':
Expand Down
Loading