8
8
"""
9
9
10
10
import argparse
11
+ from collections import namedtuple
11
12
import itertools
12
13
import os
13
14
import platform
@@ -60,12 +61,15 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
60
61
from exc_value
61
62
62
63
class Inputs :
64
+ # pylint: disable=too-many-instance-attributes
63
65
"""Accumulate information about macros to test.
66
+
64
67
This includes macro names as well as information about their arguments
65
68
when applicable.
66
69
"""
67
70
68
71
def __init__ (self ):
72
+ self .all_declared = set ()
69
73
# Sets of names per type
70
74
self .statuses = set (['PSA_SUCCESS' ])
71
75
self .algorithms = set (['0xffffffff' ])
@@ -86,11 +90,30 @@ def __init__(self):
86
90
self .table_by_prefix = {
87
91
'ERROR' : self .statuses ,
88
92
'ALG' : self .algorithms ,
89
- 'CURVE ' : self .ecc_curves ,
90
- 'GROUP ' : self .dh_groups ,
93
+ 'ECC_CURVE ' : self .ecc_curves ,
94
+ 'DH_GROUP ' : self .dh_groups ,
91
95
'KEY_TYPE' : self .key_types ,
92
96
'KEY_USAGE' : self .key_usage_flags ,
93
97
}
98
+ # Test functions
99
+ self .table_by_test_function = {
100
+ # Any function ending in _algorithm also gets added to
101
+ # self.algorithms.
102
+ 'key_type' : [self .key_types ],
103
+ 'ecc_key_types' : [self .ecc_curves ],
104
+ 'dh_key_types' : [self .dh_groups ],
105
+ 'hash_algorithm' : [self .hash_algorithms ],
106
+ 'mac_algorithm' : [self .mac_algorithms ],
107
+ 'cipher_algorithm' : [],
108
+ 'hmac_algorithm' : [self .mac_algorithms ],
109
+ 'aead_algorithm' : [self .aead_algorithms ],
110
+ 'key_derivation_algorithm' : [self .kdf_algorithms ],
111
+ 'key_agreement_algorithm' : [self .ka_algorithms ],
112
+ 'asymmetric_signature_algorithm' : [],
113
+ 'asymmetric_signature_wildcard' : [self .algorithms ],
114
+ 'asymmetric_encryption_algorithm' : [],
115
+ 'other_algorithm' : [],
116
+ }
94
117
# macro name -> list of argument names
95
118
self .argspecs = {}
96
119
# argument name -> list of values
@@ -99,8 +122,20 @@ def __init__(self):
99
122
'tag_length' : ['1' , '63' ],
100
123
}
101
124
125
+ def get_names (self , type_word ):
126
+ """Return the set of known names of values of the given type."""
127
+ return {
128
+ 'status' : self .statuses ,
129
+ 'algorithm' : self .algorithms ,
130
+ 'ecc_curve' : self .ecc_curves ,
131
+ 'dh_group' : self .dh_groups ,
132
+ 'key_type' : self .key_types ,
133
+ 'key_usage' : self .key_usage_flags ,
134
+ }[type_word ]
135
+
102
136
def gather_arguments (self ):
103
137
"""Populate the list of values for macro arguments.
138
+
104
139
Call this after parsing all the inputs.
105
140
"""
106
141
self .arguments_for ['hash_alg' ] = sorted (self .hash_algorithms )
@@ -118,6 +153,7 @@ def _format_arguments(name, arguments):
118
153
119
154
def distribute_arguments (self , name ):
120
155
"""Generate macro calls with each tested argument set.
156
+
121
157
If name is a macro without arguments, just yield "name".
122
158
If name is a macro with arguments, yield a series of
123
159
"name(arg1,...,argN)" where each argument takes each possible
@@ -145,6 +181,9 @@ def distribute_arguments(self, name):
145
181
except BaseException as e :
146
182
raise Exception ('distribute_arguments({})' .format (name )) from e
147
183
184
+ def generate_expressions (self , names ):
185
+ return itertools .chain (* map (self .distribute_arguments , names ))
186
+
148
187
_argument_split_re = re .compile (r' *, *' )
149
188
@classmethod
150
189
def _argument_split (cls , arguments ):
@@ -154,7 +193,7 @@ def _argument_split(cls, arguments):
154
193
# Groups: 1=macro name, 2=type, 3=argument list (optional).
155
194
_header_line_re = \
156
195
re .compile (r'#define +' +
157
- r'(PSA_((?:KEY_ )?[A-Z]+)_\w+)' +
196
+ r'(PSA_((?:(?:DH|ECC|KEY)_ )?[A-Z]+)_\w+)' +
158
197
r'(?:\(([^\n()]*)\))?' )
159
198
# Regex of macro names to exclude.
160
199
_excluded_name_re = re .compile (r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z' )
@@ -167,10 +206,6 @@ def _argument_split(cls, arguments):
167
206
# Auxiliary macro whose name doesn't fit the usual patterns for
168
207
# auxiliary macros.
169
208
'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH_CASE' ,
170
- # PSA_ALG_ECDH and PSA_ALG_FFDH are excluded for now as the script
171
- # currently doesn't support them.
172
- 'PSA_ALG_ECDH' ,
173
- 'PSA_ALG_FFDH' ,
174
209
# Deprecated aliases.
175
210
'PSA_ERROR_UNKNOWN_ERROR' ,
176
211
'PSA_ERROR_OCCUPIED_SLOT' ,
@@ -184,6 +219,7 @@ def parse_header_line(self, line):
184
219
if not m :
185
220
return
186
221
name = m .group (1 )
222
+ self .all_declared .add (name )
187
223
if re .search (self ._excluded_name_re , name ) or \
188
224
name in self ._excluded_names :
189
225
return
@@ -200,26 +236,34 @@ def parse_header(self, filename):
200
236
for line in lines :
201
237
self .parse_header_line (line )
202
238
239
+ _macro_identifier_re = r'[A-Z]\w+'
240
+ def generate_undeclared_names (self , expr ):
241
+ for name in re .findall (self ._macro_identifier_re , expr ):
242
+ if name not in self .all_declared :
243
+ yield name
244
+
245
+ def accept_test_case_line (self , function , argument ):
246
+ #pylint: disable=unused-argument
247
+ undeclared = list (self .generate_undeclared_names (argument ))
248
+ if undeclared :
249
+ raise Exception ('Undeclared names in test case' , undeclared )
250
+ return True
251
+
203
252
def add_test_case_line (self , function , argument ):
204
253
"""Parse a test case data line, looking for algorithm metadata tests."""
254
+ sets = []
205
255
if function .endswith ('_algorithm' ):
206
- # As above, ECDH and FFDH algorithms are excluded for now.
207
- # Support for them will be added in the future.
208
- if 'ECDH' in argument or 'FFDH' in argument :
209
- return
210
- self .algorithms .add (argument )
211
- if function == 'hash_algorithm' :
212
- self .hash_algorithms .add (argument )
213
- elif function in ['mac_algorithm' , 'hmac_algorithm' ]:
214
- self .mac_algorithms .add (argument )
215
- elif function == 'aead_algorithm' :
216
- self .aead_algorithms .add (argument )
217
- elif function == 'key_type' :
218
- self .key_types .add (argument )
219
- elif function == 'ecc_key_types' :
220
- self .ecc_curves .add (argument )
221
- elif function == 'dh_key_types' :
222
- self .dh_groups .add (argument )
256
+ sets .append (self .algorithms )
257
+ if function == 'key_agreement_algorithm' and \
258
+ argument .startswith ('PSA_ALG_KEY_AGREEMENT(' ):
259
+ # We only want *raw* key agreement algorithms as such, so
260
+ # exclude ones that are already chained with a KDF.
261
+ # Keep the expression as one to test as an algorithm.
262
+ function = 'other_algorithm'
263
+ sets += self .table_by_test_function [function ]
264
+ if self .accept_test_case_line (function , argument ):
265
+ for s in sets :
266
+ s .add (argument )
223
267
224
268
# Regex matching a *.data line containing a test function call and
225
269
# its arguments. The actual definition is partly positional, but this
@@ -233,9 +277,9 @@ def parse_test_cases(self, filename):
233
277
if m :
234
278
self .add_test_case_line (m .group (1 ), m .group (2 ))
235
279
236
- def gather_inputs (headers , test_suites ):
280
+ def gather_inputs (headers , test_suites , inputs_class = Inputs ):
237
281
"""Read the list of inputs to test psa_constant_names with."""
238
- inputs = Inputs ()
282
+ inputs = inputs_class ()
239
283
for header in headers :
240
284
inputs .parse_header (header )
241
285
for test_cases in test_suites :
@@ -252,8 +296,10 @@ def remove_file_if_exists(filename):
252
296
except OSError :
253
297
pass
254
298
255
- def run_c (options , type_word , names ):
256
- """Generate and run a program to print out numerical values for names."""
299
+ def run_c (type_word , expressions , include_path = None , keep_c = False ):
300
+ """Generate and run a program to print out numerical values for expressions."""
301
+ if include_path is None :
302
+ include_path = []
257
303
if type_word == 'status' :
258
304
cast_to = 'long'
259
305
printf_format = '%ld'
@@ -278,18 +324,18 @@ def run_c(options, type_word, names):
278
324
int main(void)
279
325
{
280
326
''' )
281
- for name in names :
327
+ for expr in expressions :
282
328
c_file .write (' printf("{}\\ n", ({}) {});\n '
283
- .format (printf_format , cast_to , name ))
329
+ .format (printf_format , cast_to , expr ))
284
330
c_file .write (''' return 0;
285
331
}
286
332
''' )
287
333
c_file .close ()
288
334
cc = os .getenv ('CC' , 'cc' )
289
335
subprocess .check_call ([cc ] +
290
- ['-I' + dir for dir in options . include ] +
336
+ ['-I' + dir for dir in include_path ] +
291
337
['-o' , exe_name , c_name ])
292
- if options . keep_c :
338
+ if keep_c :
293
339
sys .stderr .write ('List of {} tests kept at {}\n '
294
340
.format (type_word , c_name ))
295
341
else :
@@ -302,76 +348,101 @@ def run_c(options, type_word, names):
302
348
NORMALIZE_STRIP_RE = re .compile (r'\s+' )
303
349
def normalize (expr ):
304
350
"""Normalize the C expression so as not to care about trivial differences.
351
+
305
352
Currently "trivial differences" means whitespace.
306
353
"""
307
- expr = re .sub (NORMALIZE_STRIP_RE , '' , expr , len (expr ))
308
- return expr .strip ().split ('\n ' )
309
-
310
- def do_test (options , inputs , type_word , names ):
311
- """Test psa_constant_names for the specified type.
312
- Run program on names.
313
- Use inputs to figure out what arguments to pass to macros that
314
- take arguments.
315
- """
316
- names = sorted (itertools .chain (* map (inputs .distribute_arguments , names )))
317
- values = run_c (options , type_word , names )
318
- output = subprocess .check_output ([options .program , type_word ] + values )
319
- outputs = output .decode ('ascii' ).strip ().split ('\n ' )
320
- errors = [(type_word , name , value , output )
321
- for (name , value , output ) in zip (names , values , outputs )
322
- if normalize (name ) != normalize (output )]
323
- return len (names ), errors
324
-
325
- def report_errors (errors ):
326
- """Describe each case where the output is not as expected."""
327
- for type_word , name , value , output in errors :
328
- print ('For {} "{}", got "{}" (value: {})'
329
- .format (type_word , name , output , value ))
330
-
331
- def run_tests (options , inputs ):
332
- """Run psa_constant_names on all the gathered inputs.
333
- Return a tuple (count, errors) where count is the total number of inputs
334
- that were tested and errors is the list of cases where the output was
335
- not as expected.
354
+ return re .sub (NORMALIZE_STRIP_RE , '' , expr )
355
+
356
+ def collect_values (inputs , type_word , include_path = None , keep_c = False ):
357
+ """Generate expressions using known macro names and calculate their values.
358
+
359
+ Return a list of pairs of (expr, value) where expr is an expression and
360
+ value is a string representation of its integer value.
336
361
"""
337
- count = 0
338
- errors = []
339
- for type_word , names in [('status' , inputs .statuses ),
340
- ('algorithm' , inputs .algorithms ),
341
- ('ecc_curve' , inputs .ecc_curves ),
342
- ('dh_group' , inputs .dh_groups ),
343
- ('key_type' , inputs .key_types ),
344
- ('key_usage' , inputs .key_usage_flags )]:
345
- c , e = do_test (options , inputs , type_word , names )
346
- count += c
347
- errors += e
348
- return count , errors
362
+ names = inputs .get_names (type_word )
363
+ expressions = sorted (inputs .generate_expressions (names ))
364
+ values = run_c (type_word , expressions ,
365
+ include_path = include_path , keep_c = keep_c )
366
+ return expressions , values
367
+
368
+ class Tests :
369
+ """An object representing tests and their results."""
370
+
371
+ Error = namedtuple ('Error' ,
372
+ ['type' , 'expression' , 'value' , 'output' ])
373
+
374
+ def __init__ (self , options ):
375
+ self .options = options
376
+ self .count = 0
377
+ self .errors = []
378
+
379
+ def run_one (self , inputs , type_word ):
380
+ """Test psa_constant_names for the specified type.
381
+
382
+ Run the program on the names for this type.
383
+ Use the inputs to figure out what arguments to pass to macros that
384
+ take arguments.
385
+ """
386
+ expressions , values = collect_values (inputs , type_word ,
387
+ include_path = self .options .include ,
388
+ keep_c = self .options .keep_c )
389
+ output = subprocess .check_output ([self .options .program , type_word ] +
390
+ values )
391
+ outputs = output .decode ('ascii' ).strip ().split ('\n ' )
392
+ self .count += len (expressions )
393
+ for expr , value , output in zip (expressions , values , outputs ):
394
+ if normalize (expr ) != normalize (output ):
395
+ self .errors .append (self .Error (type = type_word ,
396
+ expression = expr ,
397
+ value = value ,
398
+ output = output ))
399
+
400
+ def run_all (self , inputs ):
401
+ """Run psa_constant_names on all the gathered inputs."""
402
+ for type_word in ['status' , 'algorithm' , 'ecc_curve' , 'dh_group' ,
403
+ 'key_type' , 'key_usage' ]:
404
+ self .run_one (inputs , type_word )
405
+
406
+ def report (self , out ):
407
+ """Describe each case where the output is not as expected.
408
+
409
+ Write the errors to ``out``.
410
+ Also write a total.
411
+ """
412
+ for error in self .errors :
413
+ out .write ('For {} "{}", got "{}" (value: {})\n '
414
+ .format (error .type , error .expression ,
415
+ error .output , error .value ))
416
+ out .write ('{} test cases' .format (self .count ))
417
+ if self .errors :
418
+ out .write (', {} FAIL\n ' .format (len (self .errors )))
419
+ else :
420
+ out .write (' PASS\n ' )
421
+
422
+ HEADERS = ['psa/crypto.h' , 'psa/crypto_extra.h' , 'psa/crypto_values.h' ]
423
+ TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data' ]
349
424
350
425
def main ():
351
426
parser = argparse .ArgumentParser (description = globals ()['__doc__' ])
352
427
parser .add_argument ('--include' , '-I' ,
353
428
action = 'append' , default = ['include' ],
354
429
help = 'Directory for header files' )
355
- parser .add_argument ('--program' ,
356
- default = 'programs/psa/psa_constant_names' ,
357
- help = 'Program to test' )
358
430
parser .add_argument ('--keep-c' ,
359
431
action = 'store_true' , dest = 'keep_c' , default = False ,
360
432
help = 'Keep the intermediate C file' )
361
433
parser .add_argument ('--no-keep-c' ,
362
434
action = 'store_false' , dest = 'keep_c' ,
363
435
help = 'Don\' t keep the intermediate C file (default)' )
436
+ parser .add_argument ('--program' ,
437
+ default = 'programs/psa/psa_constant_names' ,
438
+ help = 'Program to test' )
364
439
options = parser .parse_args ()
365
- headers = [os .path .join (options .include [0 ], 'psa' , h )
366
- for h in ['crypto.h' , 'crypto_extra.h' , 'crypto_values.h' ]]
367
- test_suites = ['tests/suites/test_suite_psa_crypto_metadata.data' ]
368
- inputs = gather_inputs (headers , test_suites )
369
- count , errors = run_tests (options , inputs )
370
- report_errors (errors )
371
- if errors == []:
372
- print ('{} test cases PASS' .format (count ))
373
- else :
374
- print ('{} test cases, {} FAIL' .format (count , len (errors )))
440
+ headers = [os .path .join (options .include [0 ], h ) for h in HEADERS ]
441
+ inputs = gather_inputs (headers , TEST_SUITES )
442
+ tests = Tests (options )
443
+ tests .run_all (inputs )
444
+ tests .report (sys .stdout )
445
+ if tests .errors :
375
446
exit (1 )
376
447
377
448
if __name__ == '__main__' :
0 commit comments