Skip to content

Commit 6d5c4d6

Browse files
authored
Flag derivative WRT map itself as not implemented in chain rule (#198)
1 parent 50f5bf9 commit 6d5c4d6

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

src/LinearMaps.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using SparseArrays
1111

1212
import Statistics: mean
1313

14-
using ChainRulesCore: unthunk, NoTangent, @thunk
14+
using ChainRulesCore: unthunk, NoTangent, @thunk, @not_implemented
1515
import ChainRulesCore: rrule
1616

1717
using Base: require_one_based_indexing

src/chainrules.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ function rrule(::typeof(*), A::LinearMap, x::AbstractVector)
22
y = A*x
33
function pullback(dy)
44
DY = unthunk(dy)
5-
# Because A is an abstract map, the product is only differentiable w.r.t the input
6-
return NoTangent(), NoTangent(), @thunk(A' * DY)
5+
return NoTangent(), @not_implemented("Gradient with respect to linear map itself not implemented."), @thunk(A' * DY)
76
end
87
return y, pullback
98
end
@@ -12,8 +11,7 @@ function rrule(A::LinearMap, x::AbstractVector)
1211
y = A*x
1312
function pullback(dy)
1413
DY = unthunk(dy)
15-
# Because A is an abstract map, the product is only differentiable w.r.t the input
16-
return NoTangent(), @thunk(A' * DY)
14+
return @not_implemented("Gradient with respect to linear map itself not implemented."), @thunk(A' * DY)
1715
end
1816
return y, pullback
1917
end

0 commit comments

Comments
 (0)