Skip to content

Commit adfb820

Browse files
committed
Add equal_nan argument,
for `**_close` in stdlib_math.
1 parent 6eb25dd commit adfb820

File tree

6 files changed

+109
-49
lines changed

6 files changed

+109
-49
lines changed

doc/specs/stdlib_math.md

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ is_close(a, b, rel_tol, abs_tol) = is_close(a%re, b%re, rel_tol, abs_tol) .and.
406406

407407
#### Syntax
408408

409-
`bool = [[stdlib_math(module):is_close(interface)]] (a, b [, rel_tol, abs_tol])`
409+
`bool = [[stdlib_math(module):is_close(interface)]] (a, b [, rel_tol, abs_tol, equal_nan])`
410410

411411
#### Status
412412

@@ -424,12 +424,15 @@ This argument is `intent(in)`.
424424
`b`: Shall be a `real/complex` scalar/array.
425425
This argument is `intent(in)`.
426426

427-
`rel_tol`: Shall be a `real` scalar.
427+
`rel_tol`: Shall be a `real` scalar/array.
428428
This argument is `intent(in)` and `optional`, which is `1.0e-9` by default.
429429

430-
`abs_tol`: Shall be a `real` scalar.
430+
`abs_tol`: Shall be a `real` scalar/array.
431431
This argument is `intent(in)` and `optional`, which is `0.0` by default.
432432

433+
`equal_nan`: Shall be a `logical` scalar/array.
434+
This argument is `intent(in)` and `optional`, which is `.false.` by default.
435+
433436
Note: All `real/complex` arguments must have same `kind`.
434437
If the value of `rel_tol/abs_tol` is negative (not recommended),
435438
it will be corrected to `abs(rel_tol/abs_tol)` by the internal process of `is_close`.
@@ -442,25 +445,34 @@ Returns a `logical` scalar/array.
442445

443446
```fortran
444447
program demo_math_is_close
448+
445449
use stdlib_math, only: is_close
446450
use stdlib_error, only: check
447-
real :: x(2) = [1, 2]
448-
print *, is_close(x,[real :: 1, 2.1]) !! [T, F]
449-
print *, is_close(2.0, 2.1, abs_tol=0.1) !! T
451+
real :: x(2) = [1, 2], y, NAN
452+
453+
y = -3
454+
NAN = sqrt(y)
455+
456+
print *, is_close(x,[real :: 1, 2.1]) !! [T, F]
457+
print *, is_close(2.0, 2.1, abs_tol=0.1) !! T
458+
print *, NAN, is_close(2.0, NAN), is_close(2.0, NAN, equal_nan=.true.) !! NAN, F, F
459+
print *, is_close(NAN, NAN), is_close(NAN, NAN, equal_nan=.true.) !! F, T
460+
450461
call check(all(is_close(x, [2.0, 2.0])), msg="all(is_close(x, [2.0, 2.0])) failed.", warn=.true.)
451-
!! all(is_close(x, [2.0, 2.0])) failed.
462+
!! all(is_close(x, [2.0, 2.0])) failed.
463+
452464
end program demo_math_is_close
453465
```
454466

455467
### `all_close`
456468

457469
#### Description
458470

459-
Returns a boolean scalar where two arrays are element-wise equal within a tolerance, behaves like `all(is_close(a, b [, rel_tol, abs_tol]))`.
471+
Returns a boolean scalar where two arrays are element-wise equal within a tolerance, behaves like `all(is_close(a, b [, rel_tol, abs_tol, equal_nan]))`.
460472

461473
#### Syntax
462474

463-
`bool = [[stdlib_math(module):all_close(interface)]] (a, b [, rel_tol, abs_tol])`
475+
`bool = [[stdlib_math(module):all_close(interface)]] (a, b [, rel_tol, abs_tol, equal_nan])`
464476

465477
#### Status
466478

@@ -484,6 +496,9 @@ This argument is `intent(in)` and `optional`, which is `1.0e-9` by default.
484496
`abs_tol`: Shall be a `real` scalar.
485497
This argument is `intent(in)` and `optional`, which is `0.0` by default.
486498

499+
`equal_nan`: Shall be a `logical` scalar.
500+
This argument is `intent(in)` and `optional`, which is `.false.` by default.
501+
487502
Note: All `real/complex` arguments must have same `kind`.
488503
If the value of `rel_tol/abs_tol` is negative (not recommended),
489504
it will be corrected to `abs(rel_tol/abs_tol)` by the internal process of `all_close`.
@@ -496,18 +511,23 @@ Returns a `logical` scalar.
496511

497512
```fortran
498513
program demo_math_all_close
514+
499515
use stdlib_math, only: all_close
500516
use stdlib_error, only: check
501-
real :: x(2) = [1, 2], random(4, 4)
517+
real :: x(2) = [1, 2], y, NAN
502518
complex :: z(4, 4)
503519
520+
y = -3
521+
NAN = sqrt(y)
522+
z = (1.0, 1.0)
523+
524+
print *, all_close(z+cmplx(1.0e-11, 1.0e-11), z) !! T
525+
print *, NAN, all_close([NAN], [NAN]), all_close([NAN], [NAN], equal_nan=.true.)
526+
!! NAN, F, T
527+
504528
call check(all_close(x, [2.0, 2.0], rel_tol=1.0e-6, abs_tol=1.0e-3), &
505529
msg="all_close(x, [2.0, 2.0]) failed.", warn=.true.)
506530
!! all_close(x, [2.0, 2.0]) failed.
507-
508-
call random_number(random(4, 4))
509-
z = 1.0
510-
print *, all_close(z+1.0e-11*random, z) !! T
511531
512532
end program demo_math_all_close
513533
```

src/stdlib_math.fypp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,10 @@ module stdlib_math
296296
interface is_close
297297
#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES
298298
#:for k1, t1 in RC_KINDS_TYPES
299-
elemental module function is_close_${t1[0]}$${k1}$(a, b, rel_tol, abs_tol) result(close)
299+
elemental module logical function is_close_${t1[0]}$${k1}$(a, b, rel_tol, abs_tol, equal_nan) result(close)
300300
${t1}$, intent(in) :: a, b
301301
real(${k1}$), intent(in), optional :: rel_tol, abs_tol
302-
logical :: close
302+
logical, intent(in), optional :: equal_nan
303303
end function is_close_${t1[0]}$${k1}$
304304
#:endfor
305305
end interface is_close
@@ -313,9 +313,10 @@ module stdlib_math
313313
#:set RANKS = range(1, MAXRANK + 1)
314314
#:for k1, t1 in RC_KINDS_TYPES
315315
#:for r1 in RANKS
316-
logical pure module function all_close_${r1}$_${t1[0]}$${k1}$(a, b, rel_tol, abs_tol) result(close)
316+
logical pure module function all_close_${r1}$_${t1[0]}$${k1}$(a, b, rel_tol, abs_tol, equal_nan) result(close)
317317
${t1}$, intent(in) :: a${ranksuffix(r1)}$, b${ranksuffix(r1)}$
318318
real(${k1}$), intent(in), optional :: rel_tol, abs_tol
319+
logical, intent(in), optional :: equal_nan
319320
end function all_close_${r1}$_${t1[0]}$${k1}$
320321
#:endfor
321322
#:endfor

src/stdlib_math_all_close.fypp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@ contains
1010

1111
#:for k1, t1 in RC_KINDS_TYPES
1212
#:for r1 in RANKS
13-
logical pure module function all_close_${r1}$_${t1[0]}$${k1}$(a, b, rel_tol, abs_tol) result(close)
13+
logical pure module function all_close_${r1}$_${t1[0]}$${k1}$(a, b, rel_tol, abs_tol, equal_nan) result(close)
1414

1515
${t1}$, intent(in) :: a${ranksuffix(r1)}$, b${ranksuffix(r1)}$
1616
real(${k1}$), intent(in), optional :: rel_tol, abs_tol
17+
logical, intent(in), optional :: equal_nan
1718

18-
close = all(is_close(a, b, rel_tol, abs_tol))
19+
close = all(is_close(a, b, rel_tol, abs_tol, equal_nan))
1920

2021
end function all_close_${r1}$_${t1[0]}$${k1}$
2122
#:endfor

src/stdlib_math_is_close.fypp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,40 @@
22

33
submodule(stdlib_math) stdlib_math_is_close
44

5+
use, intrinsic :: ieee_arithmetic, only: ieee_is_nan
6+
implicit none
7+
58
contains
69

710
#! Determines whether the values of `a` and `b` are close.
811

912
#:for k1, t1 in REAL_KINDS_TYPES
10-
elemental module function is_close_${t1[0]}$${k1}$(a, b, rel_tol, abs_tol) result(close)
13+
elemental module logical function is_close_${t1[0]}$${k1}$(a, b, rel_tol, abs_tol, equal_nan) result(close)
1114
${t1}$, intent(in) :: a, b
1215
real(${k1}$), intent(in), optional :: rel_tol, abs_tol
13-
logical :: close
14-
15-
close = abs(a - b) <= max( abs(optval(rel_tol, 1.0e-9_${k1}$)*max(abs(a), abs(b))), &
16-
abs(optval(abs_tol, 0.0_${k1}$)) )
16+
logical, intent(in), optional :: equal_nan
17+
logical :: equal_nan_
18+
19+
equal_nan_ = optval(equal_nan, .false.)
20+
21+
if (ieee_is_nan(a) .or. ieee_is_nan(b)) then
22+
close = merge(.true., .false., equal_nan_ .and. ieee_is_nan(a) .and. ieee_is_nan(b))
23+
else
24+
close = abs(a - b) <= max( abs(optval(rel_tol, 1.0e-9_${k1}$)*max(abs(a), abs(b))), &
25+
abs(optval(abs_tol, 0.0_${k1}$)) )
26+
end if
1727

1828
end function is_close_${t1[0]}$${k1}$
1929
#:endfor
2030

2131
#:for k1, t1 in CMPLX_KINDS_TYPES
22-
elemental module function is_close_${t1[0]}$${k1}$(a, b, rel_tol, abs_tol) result(close)
32+
elemental module logical function is_close_${t1[0]}$${k1}$(a, b, rel_tol, abs_tol, equal_nan) result(close)
2333
${t1}$, intent(in) :: a, b
2434
real(${k1}$), intent(in), optional :: rel_tol, abs_tol
25-
logical :: close
35+
logical, intent(in), optional :: equal_nan
2636

27-
close = is_close_r${k1}$(a%re, b%re, rel_tol, abs_tol) .and. &
28-
is_close_r${k1}$(a%im, b%im, rel_tol, abs_tol)
37+
close = is_close_r${k1}$(a%re, b%re, rel_tol, abs_tol, equal_nan) .and. &
38+
is_close_r${k1}$(a%im, b%im, rel_tol, abs_tol, equal_nan)
2939

3040
end function is_close_${t1[0]}$${k1}$
3141
#:endfor

src/tests/math/test_math_all_close.f90

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@ program tester
33
use stdlib_math, only: all_close
44
use stdlib_error, only: check
55
implicit none
6-
6+
real :: y, NAN
7+
8+
y = -3
9+
NAN = sqrt(y)
10+
711
call test_math_all_close_real
812
call test_math_all_close_complex
913
print *, "All tests in `test_math_all_close` passed."
@@ -12,26 +16,29 @@ program tester
1216

1317
subroutine test_math_all_close_real
1418

15-
real :: x(4, 4), random(4, 4)
19+
real :: x(4, 4) = 1.0
1620

17-
call random_number(random)
18-
x = 1.0
21+
call check(all_close(x + 1.0e-11, x), msg="REAL: all_close(x+1.0e-11, x) failed.")
22+
call check(all_close(x + 1.0e-5, x), msg="REAL: all_close(x+1.0e-5 , x) failed. (expected)", warn=.true.)
1923

20-
call check(all_close(x+1.0e-11*random, x), msg="REAL: all_close(x+1.0e-11*random, x) failed.")
21-
call check(all_close(x+1.0e-5 *random, x), msg="REAL: all_close(x+1.0e-5 *random, x) failed.", warn=.true.)
24+
!> Tests for NAN
25+
call check(all_close(x + NAN, x), msg="REAL: all_close(x+NAN, x) failed. (expected)", warn=.true.)
26+
call check(all_close(x + NAN, x + NAN, equal_nan=.true.), msg="REAL: all_close(x+NAN, x, equal_nan=.true.) failed.")
2227

2328
end subroutine test_math_all_close_real
2429

2530
subroutine test_math_all_close_complex
2631

27-
real :: random(4, 4)
28-
complex :: x(4, 4)
32+
complex :: x(4, 4) = cmplx(1.0, 1.0)
2933

30-
call random_number(random)
31-
x = 1.0
34+
call check(all_close(x + cmplx((1.0e-15, 1.0e-15)), x), msg="CMPLX: all_close(x+cmplx(1.0e-11, 1.0e-11), x)")
35+
call check(all_close(x + cmplx(1.0e-5, 1.0e-5), x), &
36+
msg="CMPLX: all_close(x+cmplx(1.0e-5 , 1.0e-5 ), x) failed. (expected)", warn=.true.)
3237

33-
call check(all_close(x+1.0e-11*random, x), msg="CMPLX: all_close(x+1.0e-11*random, x)")
34-
call check(all_close(x+1.0e-5 *random, x), msg="CMPLX: all_close(x+1.0e-5 *random, x) failed.", warn=.true.)
38+
!> Tests for NAN
39+
call check(all_close(x + cmplx(NAN, NAN), x), msg="REAL: all_close(x+cmplx(NAN, NAN), x) failed. (expected)", warn=.true.)
40+
call check(all_close(x + cmplx(NAN, NAN), x + cmplx(NAN, NAN), equal_nan=.true.), &
41+
msg="REAL: all_close(x+cmplx(NAN, NAN), x, equal_nan=.true.) failed.")
3542

3643
end subroutine test_math_all_close_complex
3744

src/tests/math/test_math_is_close.f90

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
program test_math_is_close
22

3+
implicit none
4+
5+
real :: x, NAN
6+
x = -3
7+
NAN = sqrt(x)
8+
39
call test_math_is_close_real
410
call test_math_is_close_complex
511
print *, "All tests in `test_math_is_close` passed."
@@ -12,7 +18,7 @@ subroutine test_math_is_close_real
1218

1319
call check(is_close(2.5, 2.5, rel_tol=1.0e-5), msg="is_close(2.5, 2.5, rel_tol=1.0e-5) failed.")
1420
call check(all(is_close([2.5, 3.2], [2.5, 10.0], rel_tol=1.0e-5)), &
15-
msg="all(is_close([2.5, 3.2], [2.5, 10.0], rel_tol=1.0e-5)) failed (expected).", warn=.true.)
21+
msg="all(is_close([2.5, 3.2], [2.5, 10.0], rel_tol=1.0e-5)) failed. (expected)", warn=.true.)
1622
call check(all(is_close(reshape([2.5, 3.2, 2.2, 1.0], [2, 2]), reshape([2.5, 3.2001, 2.25, 1.1], [2, 2]), &
1723
abs_tol=1.0e-5, rel_tol=0.1)), &
1824
msg="all(is_close(reshape([2.5, 3.2, 2.2, 1.0],[2,2]), reshape([2.5, 3.2001, 2.25, 1.1],[2,2]), &
@@ -21,26 +27,41 @@ subroutine test_math_is_close_real
2127
!> Tests for zeros
2228
call check(is_close(0.0, -0.0), msg="is_close(0.0, -0.0) failed.")
2329

30+
!> Tests for NaN
31+
call check(is_close(NAN, NAN), msg="is_close(NAN, NAN) failed.", warn=.true.)
32+
call check(is_close(NAN, NAN, equal_nan=.true.), msg="is_close(NAN, NAN, equal_nan=.true.) failed.")
33+
2434
end subroutine test_math_is_close_real
2535

2636
subroutine test_math_is_close_complex
2737
use stdlib_math, only: is_close
2838
use stdlib_error, only: check
2939

30-
call check(is_close((2.5,1.2), (2.5,1.2), rel_tol=1.0e-5), &
40+
call check(is_close((2.5, 1.2), (2.5, 1.2), rel_tol=1.0e-5), &
3141
msg="is_close((2.5,1.2), (2.5,1.2), rel_tol=1.0e-5) failed.")
32-
call check(all(is_close([(2.5,1.2), (3.2,1.2)], [(2.5,1.2), (10.0,1.2)], rel_tol=1.0e-5)), &
33-
msg="all(is_close([(2.5,1.2), (3.2,1.2)], [(2.5,1.2), (10.0,1.2)], rel_tol=1.0e-5)) failed (expected).", &
42+
call check(all(is_close([(2.5, 1.2), (3.2, 1.2)], [(2.5, 1.2), (10.0, 1.2)], rel_tol=1.0e-5)), &
43+
msg="all(is_close([(2.5,1.2), (3.2,1.2)], [(2.5,1.2), (10.0,1.2)], rel_tol=1.0e-5)) failed. (expected)", &
3444
warn=.true.)
35-
call check(all(is_close(reshape([(2.5,1.2009), (3.2,1.199999)], [1, 2]), reshape([(2.4,1.2009), (3.15,1.199999)], [1, 2]), &
36-
abs_tol=1.0e-5, rel_tol=0.1)), &
37-
msg="all(is_close(reshape([(2.5,1.2009), (3.2,1.199999)], [1, 2]), &
38-
&reshape([(2.4,1.2009), (3.15,1.199999)], [1, 2]), &
39-
&rel_tol=1.0e-5, abs_tol=0.1)) failed.")
45+
call check(all(is_close(reshape([(2.5, 1.2009), (3.2, 1.199999)], [1, 2]), &
46+
reshape([(2.4, 1.2009), (3.15, 1.199999)], [1, 2]), &
47+
abs_tol=1.0e-5, rel_tol=0.1)), &
48+
msg="all(is_close(reshape([(2.5,1.2009), (3.2,1.199999)], [1, 2]), &
49+
&reshape([(2.4,1.2009), (3.15,1.199999)], [1, 2]), &
50+
&rel_tol=1.0e-5, abs_tol=0.1)) failed.")
4051

4152
!> Tests for zeros
4253
call check(is_close((0.0, -0.0), (-0.0, 0.0)), msg="is_close((0.0, -0.0), (-0.0, 0.0)) failed.")
4354

55+
!> Tests for NaN
56+
call check(is_close(cmplx(NAN, NAN), cmplx(NAN, NAN)), &
57+
msg="is_close(cmplx(NAN, NAN), cmplx(NAN, NAN)) failed. (expected)", warn=.true.)
58+
call check(is_close(cmplx(NAN, NAN), cmplx(NAN, NAN), equal_nan=.true.), &
59+
msg="is_close(cmplx(NAN, NAN), cmplx(NAN, NAN), equal_nan=.true.) failed.")
60+
call check(is_close(cmplx(NAN, 1.0), cmplx(NAN, 1.0)), &
61+
msg="is_close(cmplx(NAN, NAN), cmplx(NAN, NAN)) failed. (expected)", warn=.true.)
62+
call check(is_close(cmplx(NAN, 1.0), cmplx(NAN, 1.0), equal_nan=.true.), &
63+
msg="is_close(cmplx(NAN, NAN), cmplx(NAN, NAN), equal_nan=.ture.) failed.")
64+
4465
end subroutine test_math_is_close_complex
4566

4667
end program test_math_is_close

0 commit comments

Comments
 (0)