Skip to content

[AutoDiff upstream] Add derivative function type calculation. #29218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

dan-zheng
Copy link
Contributor

Add AnyFunctionType::getAutoDiffDerivativeFunctionType, which returns the
derivative AnyFunctionType for an "original" AnyFunctionType, given:

  • Differentiability parameter indices
  • Differentiability result index
  • Derivative function kind
  • Derivative function generic signature (optional)
  • Other auxiliary parameters

Add doc comments explaining typing rules, preconditions, and other details.


Progress towards TF-828: upstream @differentiable attribute type-checking.
@differentiable attribute type-checking mega-patch: #29091


Tests exist on tensorflow branch and will be upstreamed later.
Currently, not enough code has been upstreamed for meaningful testing.

@dan-zheng dan-zheng requested review from rxwei and DougGregor January 15, 2020 11:16
Add `AnyFunctionType::getAutoDiffDerivativeFunctionType`, which returns the
derivative `AnyFunctionType` for an "original" `AnyFunctionType`, given:
- Differentiability parameter indices
- Differentiability result index
- Derivative function kind
- Derivative function generic signature (optional)
- Other auxiliary parameters

Add doc comments explaining typing rules, preconditions, and other details.

Progress towards TF-828: upstream `@differentiable` attribute type-checking.
@dan-zheng dan-zheng force-pushed the autodiff-upstream-derivative-type branch from ba4a714 to 732e0f4 Compare January 15, 2020 11:17
@dan-zheng
Copy link
Contributor Author

@swift-ci Please smoke test

@dan-zheng
Copy link
Contributor Author

macOS flaky test failed (passes locally)

03:59:27 Failing Tests (1):
03:59:27     Swift(macosx-x86_64) :: ModuleInterface/ModuleCache/prebuilt-module-cache-forwarding.swift

Rerunning CI now.

@dan-zheng
Copy link
Contributor Author

@swift-ci Please smoke test macOS platform

@dan-zheng
Copy link
Contributor Author

Merging to unblock progress. Happy to address review feedback later.

@dan-zheng dan-zheng merged commit a1fe532 into swiftlang:master Jan 15, 2020
@dan-zheng dan-zheng deleted the autodiff-upstream-derivative-type branch January 15, 2020 17:42
resultTupleEltType->getAutoDiffTangentSpace(lookupConformance)
->getType());
} else {
assert(resultIndex == 0 && "resultIndex out of bounds");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: this should match the assertion below:

      assert(resultIndex == 0 &&
             "Expected result index 0 for non-tuple result");

SmallVector<TupleTypeElt, 2> retElts;
retElts.push_back(originalResult);
retElts.push_back(linearMapType);
auto retTy = TupleType::get(retElts, ctx);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: I prefer using the name "result type" instead of "return type". retTy -> resultTy

// Creates an `AnyFunctionType` from the given parameters, result type,
// generic signature, and `ExtInfo`.
static AnyFunctionType *
makeFunctionType(ArrayRef<AnyFunctionType::Param> parameters, Type resultType,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a general utility for AST function types. Consider adding it to AnyFunctionType, or just making it a closure inside getAutoDiffDerivativeFunctionType.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants