Skip to content

Commit 0276f89

Browse files
selvavmIvanUkhov
andauthored
Made ggev functions as non-scalar functions and support for lapack-src v0.13 (#28)
* Made ggev functions as non-scalar functions and fixed the issue with lsame Arguments are now in lower case alpha in larfg is corrected as a scalar * Unify the rules in is_scalar * Remove redundant formatting * Bump the version number * Exclude lsame * Remove a redundancy * Update lapack-sys * Regenerate the functions * Refactor the generator * Unify the rules in is_scalar * Fix vl and vr * Fix k+ * Fix dif Co-authored-by: Ivan Ukhov <[email protected]>
1 parent 4d31631 commit 0276f89

File tree

6 files changed

+27257
-21315
lines changed

6 files changed

+27257
-21315
lines changed

Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
[package]
22
name = "lapack"
3-
version = "0.18.0"
3+
version = "0.19.0"
44
license = "Apache-2.0/MIT"
55
authors = [
66
"Andrew Straw <[email protected]>",
77
"Crozet Sébastien <[email protected]>",
88
"David Greenberg <[email protected]>",
99
"Ivan Ukhov <[email protected]>",
1010
"Pavel Potocek <[email protected]>",
11+
"Selvavignesh Vedamanickam <[email protected]>",
1112
"Toshiki Teramura <[email protected]>",
1213
]
1314
description = "The package provides wrappers for LAPACK (Fortran)."
@@ -26,5 +27,5 @@ version = "0.4"
2627
default-features = false
2728

2829
[dependencies.lapack-sys]
29-
version = "0.12"
30+
version = "0.14"
3031
default-features = false

bin/function.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
return_re = re.compile('(?:\s*->\s*([^;]+))?')
66

77

8-
class Function(object):
8+
class Function():
99

1010
def __init__(self, name, args, ret):
1111
self.name = name
@@ -25,8 +25,6 @@ def parse(line):
2525
arg, aty, line = pull_argument(line)
2626
if arg is None:
2727
break
28-
if arg == 'matrix_layout':
29-
arg = 'layout'
3028
args.append((arg, aty))
3129
line = line.strip()
3230

@@ -55,7 +53,7 @@ def pull_return(s):
5553
return match.group(1), s[match.end(1):]
5654

5755

58-
def read_functions(path):
56+
def read(path):
5957
lines = []
6058
with open(path) as file:
6159
append = False

bin/generate.py

Lines changed: 90 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import re
66

77
from function import Function
8-
from function import read_functions
8+
from function import read
99

1010
select_re = re.compile('LAPACK_(\w)_SELECT(\d)')
1111

1212

1313
def is_scalar(name, cty, f):
14-
return (
14+
return ( \
1515
'c_char' in cty or
1616
name in [
1717
'abnrm',
@@ -20,7 +20,6 @@ def is_scalar(name, cty, f):
2020
'anorm',
2121
'bbnrm',
2222
'colcnd',
23-
'dif',
2423
'ihi',
2524
'il',
2625
'ilo',
@@ -45,29 +44,80 @@ def is_scalar(name, cty, f):
4544
'tryrac',
4645
'vu',
4746
] or
48-
name == 'q' and 'lapack_int' in cty or
49-
not (
47+
name in [
48+
'alpha',
49+
] and (
50+
'larfg' in f.name
51+
) or
52+
name in [
53+
'dif',
54+
] and not (
55+
'tgsen' in f.name or
56+
'tgsna' in f.name
57+
) or
58+
name in [
59+
'p',
60+
] and not (
61+
'tgevc' in f.name
62+
) or
63+
name in [
64+
'q'
65+
] and (
66+
'lapack_int' in cty
67+
) or
68+
name in [
69+
'vl',
70+
'vr',
71+
] and not (
5072
'geev' in f.name or
73+
'ggev' in f.name or
74+
'hsein' in f.name or
75+
'tgevc' in f.name or
5176
'tgsna' in f.name or
77+
'trevc' in f.name or
5278
'trsna' in f.name
53-
) and name in [
54-
'vl',
55-
'vr',
56-
] or
57-
not ('tgevc' in f.name) and name in [
58-
'p',
59-
] or
60-
name.startswith('alpha') or
61-
name.startswith('beta') or
79+
) or
80+
name.startswith('k') and not (
81+
'lapmr' in f.name or
82+
'lapmt' in f.name
83+
) or
6284
name.startswith('inc') or
63-
name.startswith('k') or
6485
name.startswith('ld') or
6586
name.startswith('tol') or
6687
name.startswith('vers')
6788
)
6889

6990

70-
def translate_argument(name, cty, f):
91+
def translate_name(name):
92+
return name.lower()
93+
94+
95+
def translate_base_type(cty):
96+
cty = cty.replace('__BindgenComplex<f32>', 'lapack_complex_float')
97+
cty = cty.replace('__BindgenComplex<f64>', 'lapack_complex_double')
98+
cty = cty.replace('lapack_float_return', 'c_float')
99+
cty = cty.replace('f32', 'c_float')
100+
cty = cty.replace('f64', 'c_double')
101+
102+
if 'c_char' in cty:
103+
return 'u8'
104+
elif 'c_int' in cty:
105+
return 'i32'
106+
elif 'c_float' in cty:
107+
return 'f32'
108+
elif 'c_double' in cty:
109+
return 'f64'
110+
elif 'lapack_complex_float' in cty:
111+
return 'c32'
112+
elif 'lapack_complex_double' in cty:
113+
return 'c64'
114+
elif 'size_t' in cty:
115+
return 'size_t'
116+
117+
assert False, 'cannot translate `{}`'.format(cty)
118+
119+
120+
def translate_signature_type(name, cty, f):
71121
m = select_re.match(cty)
72122
if m is not None:
73123
if m.group(1) == 'S':
@@ -79,7 +129,7 @@ def translate_argument(name, cty, f):
79129
elif m.group(1) == 'Z':
80130
return 'Select{}C64'.format(m.group(2))
81131

82-
base = translate_type_base(cty)
132+
base = translate_base_type(cty)
83133
if '*const' in cty:
84134
if is_scalar(name, cty, f):
85135
return base
@@ -94,30 +144,6 @@ def translate_argument(name, cty, f):
94144
return base
95145

96146

97-
def translate_type_base(cty):
98-
cty = cty.replace('__BindgenComplex<f32>', 'lapack_complex_float')
99-
cty = cty.replace('__BindgenComplex<f64>', 'lapack_complex_double')
100-
cty = cty.replace('f32', 'c_float')
101-
cty = cty.replace('f64', 'c_double')
102-
103-
if 'c_char' in cty:
104-
return 'u8'
105-
elif 'c_int' in cty:
106-
return 'i32'
107-
elif 'c_float' in cty:
108-
return 'f32'
109-
elif 'c_double' in cty:
110-
return 'f64'
111-
elif 'lapack_complex_float' in cty:
112-
return 'c32'
113-
elif 'lapack_complex_double' in cty:
114-
return 'c64'
115-
elif 'size_t' in cty:
116-
return 'libc::c_ulong'
117-
118-
assert False, 'cannot translate `{}`'.format(cty)
119-
120-
121147
def translate_body_argument(name, rty):
122148
if rty.startswith('Select'):
123149
return 'transmute({})'.format(name)
@@ -154,66 +180,56 @@ def translate_body_argument(name, rty):
154180
elif rty.startswith('&mut [c'):
155181
return '{}.as_mut_ptr() as *mut _'.format(name)
156182

157-
elif rty.startswith('libc::'):
158-
return '&{}'.format(name)
183+
elif rty == 'size_t':
184+
return name
159185

160186
assert False, 'cannot translate `{}: {}`'.format(name, rty)
161187

162188

163-
def translate_return_type(cty):
164-
cty = cty.replace('lapack_float_return', 'c_float')
165-
cty = cty.replace('f64', 'c_double')
166-
167-
if cty == 'c_int':
168-
return 'i32'
169-
elif cty == 'c_float':
170-
return 'f32'
171-
elif cty == 'c_double':
172-
return 'f64'
173-
174-
assert False, 'cannot translate `{}`'.format(cty)
175-
176-
177-
def format_header(f):
178-
args = format_header_arguments(f)
189+
def format_signature(f):
190+
args = format_signature_arguments(f)
179191
if f.ret is None:
180192
return 'pub unsafe fn {}({})'.format(f.name, args)
181193
else:
182194
return 'pub unsafe fn {}({}) -> {}'.format(f.name, args,
183-
translate_return_type(f.ret))
184-
185-
186-
def format_body(f):
187-
return 'ffi::{}_({})'.format(f.name, format_body_arguments(f))
195+
translate_base_type(f.ret))
188196

189197

190-
def format_header_arguments(f):
198+
def format_signature_arguments(f):
191199
s = []
192-
for arg in f.args:
193-
s.append('{}: {}'.format(arg[0], translate_argument(*arg, f=f)))
200+
for name, cty in f.args:
201+
name = translate_name(name)
202+
s.append('{}: {}'.format(name, translate_signature_type(name, cty, f)))
194203
return ', '.join(s)
195204

196205

206+
def format_body(f):
207+
return 'ffi::{}_({})'.format(f.name, format_body_arguments(f))
208+
209+
197210
def format_body_arguments(f):
198211
s = []
199-
for arg in f.args:
200-
rty = translate_argument(*arg, f=f)
201-
s.append(translate_body_argument(arg[0], rty))
212+
for name, cty in f.args:
213+
name = translate_name(name)
214+
rty = translate_signature_type(name, cty, f)
215+
s.append(translate_body_argument(name, rty))
202216
return ', '.join(s)
203217

204218

205-
def prepare(code):
219+
def process(code):
206220
lines = filter(lambda line: not re.match(r'^\s*//.*', line),
207221
code.split('\n'))
208222
lines = re.sub(r'\s+', ' ', ''.join(lines)).strip().split(';')
209223
lines = filter(lambda line: not re.match(r'^\s*$', line), lines)
210224
return [Function.parse(line) for line in lines]
211225

212226

213-
def do(functions):
227+
def write(functions):
214228
for f in functions:
229+
if f.name in ['lsame']:
230+
continue
215231
print('\n#[inline]')
216-
print(format_header(f) + ' {')
232+
print(format_signature(f) + ' {')
217233
print(' ' + format_body(f) + '\n}')
218234

219235

@@ -222,4 +238,4 @@ def do(functions):
222238
parser.add_argument('--sys', default='lapack-sys')
223239
arguments = parser.parse_args()
224240
path = os.path.join(arguments.sys, 'src', 'lapack.rs')
225-
do(prepare(read_functions(path)))
241+
write(process(read(path)))

0 commit comments

Comments
 (0)