-
Notifications
You must be signed in to change notification settings - Fork 137
Added support for the 'log1mexp' op and its VJP. #147
Conversation
Co-Authored-By: Richard Wei <[email protected]>
@rxwei this is ready other than the function not being differentiable. If we figure out a |
Let's add a global |
@rxwei I implemented that, but I get the following error:
This appears when I make the following function differentiable: @inlinable
@differentiable
public func log1mexp<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
let isTooSmall = withoutDerivative(at: -x .< T(log(2.0)))
// This `replacing` will ultimately be a no-op because we will not select this code-path
// whenever we use the surrogate `-Tensor(onesLike: x)`.
let xSafe = x.replacing(with: -Tensor(onesLike: x), where: isTooSmall)
return log1p(-exp(xSafe)).replacing(with: log(-expm1(x)), where: isTooSmall)
} |
How is it defined? |
@rxwei It's just this: @_semantics("autodiff.nonvarying")
public func withoutDerivative<T>(at x: T) -> T {
return x
} I just pushed the changes so you can see them. |
@rxwei I made a small reproducible example. I cannot make the following function differentiable: @inlinable
@differentiable
public func example<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
let ones = withoutDerivative(at: Tensor<T>(ones: x.shape))
return x.replacing(with: ones, where: Tensor<Bool>(true))
} I get a similar stacktrace as above and I believe that this error does not have to do with @inlinable
@differentiable
public func example<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
return x.replacing(with: Tensor<T>(ones: x.shape), where: Tensor<Bool>(true))
} |
Using // FIXME: The logic for resolving `assocRef` does not reapply function
// conversions, which is problematic if `assocFn` is a `partial_apply`
// instruction.
SILValue assocRef;
if (auto *assocFnRef =
peerThroughFunctionConversions<FunctionRefInst>(assocFn)) {
auto *assoc = assocFnRef->getReferencedFunctionOrNull();
assocRef = builder.createFunctionRef(loc, assoc);
} else if (auto *assocMethodInst =
peerThroughFunctionConversions<WitnessMethodInst>(assocFn)) {
assocRef = builder.createWitnessMethod(
loc, assocMethodInst->getLookupType(),
assocMethodInst->getConformance(), assocMethodInst->getMember(),
thunk->mapTypeIntoContext(assocMethodInst->getType()));
}
assert(assocRef && "Expected associated function to be resolved"); Confirmed this is related to this FIXME. cc @dan-zheng |
Disregard my previous response. I think what is going on here is that the subset parameters thunk generation logic is not handling cases where the original JVP/JVP is already producing a subset-parameters linear map. A proper fix would be to pass in a index subset with respect to the linear map |
I see. Could you please explain a bit more detail what the index subsets represent and what the linear map is exactly? That will help me a lot in understanding what’s going on and attempting a fix. |
I have a PR that will fix this. |
Fixed in swiftlang/swift#25699. |
Co-Authored-By: Richard Wei <[email protected]>
Co-Authored-By: Richard Wei <[email protected]>
Richard thanks a lot for debugging this and helping me also understand what's going on. |
…25699) The subset parameters thunk generation logic is not handling cases where the original JVP/JVP is already producing a subset-parameters linear map. This patch fixes it by computing indices w.r.t. the linear map from indices w.r.t. the original function. Resolves [TF-594](https://bugs.swift.org/browse/TF-594) and unblocks tensorflow/swift-apis#147.
It looks like kokoro does not re-run tests even though I added the |
Nvm, it just started running them now. |
@rxwei the CI seems to still be using a toolchain before the auto-diff fix. Could you please re-run tests once the CI toolchain is updated? |
@bgogul, how long does a new toolchain take to get to Kokoro? A new one was built last night so I assumed that it had propagated to CI. |
Sorry I believe the toolchain was updated after all but this PR was not. I will merge with master and re-run tests. |
This requires #145 and #146.