-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff upstream] Add derivative function type calculation. #29218
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AutoDiff upstream] Add derivative function type calculation. #29218
Conversation
Add `AnyFunctionType::getAutoDiffDerivativeFunctionType`, which returns the derivative `AnyFunctionType` for an "original" `AnyFunctionType`, given: - Differentiability parameter indices - Differentiability result index - Derivative function kind - Derivative function generic signature (optional) - Other auxiliary parameters Add doc comments explaining typing rules, preconditions, and other details. Progress towards TF-828: upstream `@differentiable` attribute type-checking.
ba4a714
to
732e0f4
Compare
@swift-ci Please smoke test |
macOS flaky test failed (passes locally)
Rerunning CI now. |
@swift-ci Please smoke test macOS platform |
Merging to unblock progress. Happy to address review feedback later. |
resultTupleEltType->getAutoDiffTangentSpace(lookupConformance) | ||
->getType()); | ||
} else { | ||
assert(resultIndex == 0 && "resultIndex out of bounds"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor: this should match the assertion below:
assert(resultIndex == 0 &&
"Expected result index 0 for non-tuple result");
SmallVector<TupleTypeElt, 2> retElts; | ||
retElts.push_back(originalResult); | ||
retElts.push_back(linearMapType); | ||
auto retTy = TupleType::get(retElts, ctx); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor: I prefer using the name "result type" instead of "return type". retTy
-> resultTy
// Creates an `AnyFunctionType` from the given parameters, result type, | ||
// generic signature, and `ExtInfo`. | ||
static AnyFunctionType * | ||
makeFunctionType(ArrayRef<AnyFunctionType::Param> parameters, Type resultType, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a general utility for AST function types. Consider adding it to AnyFunctionType
, or just making it a closure inside getAutoDiffDerivativeFunctionType
.
Add
AnyFunctionType::getAutoDiffDerivativeFunctionType
, which returns thederivative
AnyFunctionType
for an "original"AnyFunctionType
, given:Add doc comments explaining typing rules, preconditions, and other details.
Progress towards TF-828: upstream
@differentiable
attribute type-checking.@differentiable
attribute type-checking mega-patch: #29091Tests exist on
tensorflow
branch and will be upstreamed later.Currently, not enough code has been upstreamed for meaningful testing.