Skip to content

[AutoDiff upstream] Add differentiable_function canonicalization. #30818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 6, 2020

Conversation

dan-zheng
Copy link
Contributor

Canonicalizes differentiable_function instructions by filling in missing
derivative function operands.

Derivative function emission rules, based on the original function value:

  • function_ref: look up differentiability witness with the exact or a minimal
    superset derivative configuration. Emit a differentiability_witness_function
    for the derivative function.
  • witness_method: emit a witness_method with the minimal superset derivative
    configuration for the derivative function.
  • class_method: emit a class_method with the minimal superset derivative
    configuration for the derivative function.

If an actual emitted derivative function has a superset derivative
configuration versus the desired derivative configuration, create a "subset
parameters thunk" to thunk the actual derivative to the desired type.

For differentiable_function instructions formed from curry thunk applications:
clone the curry thunk (with type (Self) -> (T, ...) -> U) and create a new
version with type (Self) -> @differentiable (T, ...) -> U.

Progress towards TF-1211.

Canonicalizes `differentiable_function` instructions by filling in missing
derivative function operands.

Derivative function emission rules, based on the original function value:

- `function_ref`: look up differentiability witness with the exact or a minimal
  superset derivative configuration. Emit a `differentiability_witness_function`
  for the derivative function.
- `witness_method`: emit a `witness_method` with the minimal superset derivative
  configuration for the derivative function.
- `class_method`: emit a `class_method` with the minimal superset derivative
  configuration for the derivative function.

If an *actual* emitted derivative function has a superset derivative
configuration versus the *desired* derivative configuration, create a "subset
parameters thunk" to thunk the actual derivative to the desired type.

For `differentiable_function` instructions formed from curry thunk applications:
clone the curry thunk (with type `(Self) -> (T, ...) -> U`) and create a new
version with type `(Self) -> @differentiable (T, ...) -> U`.

Progress towards TF-1211.
@dan-zheng dan-zheng requested review from rxwei and marcrasi April 5, 2020 23:38
Comment on lines +1140 to +1143
std::tie(thunk, interfaceSubs) =
getOrCreateSubsetParametersThunkForDerivativeFunction(
fb, origFnOperand, derivativeFn, derivativeFnKind, desiredIndices,
actualIndices);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: subset parameters thunk tests are blocked by derivative function emitters, which will be upstreamed next.

/// Check whether the given requirements are satisfied, with the given
/// derivative generic signature (containing requirements), and substitution
/// map. Returns true if error is emitted.
static bool diagnoseUnsatisfiedRequirements(ADContext &context,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Todo: upstream differentiation_transform_diagnostics.swift, which tests diagnoseUnsatisfiedRequirements.

@dan-zheng
Copy link
Contributor Author

@gottesmm: would you like to help review this patch, as a SILOptimizer code owner?

@dan-zheng
Copy link
Contributor Author

@swift-ci Please test

@swift-ci
Copy link
Contributor

swift-ci commented Apr 6, 2020

Build failed
Swift Test Linux Platform
Git Sha - 5430142

@dan-zheng
Copy link
Contributor Author

@swift-ci Please test Linux

@dan-zheng
Copy link
Contributor Author

Merging to unblock progress. Happy to address any feedback later!

@dan-zheng dan-zheng merged commit 146c11e into swiftlang:master Apr 6, 2020
@dan-zheng dan-zheng deleted the differentiation-transform branch April 6, 2020 03:22
Copy link
Contributor

@gottesmm gottesmm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dan-zheng some comments. Please fix:

Also, are these used anywhere but the differentiation pass? If not, we should just put it in the differentiation pass until we find multiple uses. A bunch of the helpers that you added need to be refactored. Please fix them. I may look at this again in a bit.

/// Emit a zero value into the given buffer access by calling
/// `AdditiveArithmetic.zero`. The given type must conform to
/// `AdditiveArithmetic`.
void emitZeroIntoBuffer(SILBuilder &builder, CanType type,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please fix these APIs to not take a SILBuilder. Instead, it should take an insertion point and a SILBuilderContext.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change seems important for correctness. It can probably help clean up this code:
https://github.com/apple/swift/blob/13d5a8addbe3605984edc4ce7c6cbf1f0649a9e1/lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp#L486-L497

I started looking into this but ran into some difficulties. The user above wants to insert at the end of a basic block (SILBasicBlock::iterator), but no appropriate constructor exists. Only SILBuilderWithScope(SILInstruction *I, SILBuilderContext &C) is recommended.

SILBuilderWithScope(SILBasicBlock::iterator I, SILBuilderContext &C) would work but doesn't exist. Can we add it, or are there other considerations?

@@ -33,6 +34,12 @@ namespace autodiff {
/// This is being used to print short debug messages within the AD pass.
raw_ostream &getADDebugStream();

/// Given a function call site, gathers all of its actual results (both direct
/// and indirect) in an order defined by its result type.
void collectAllActualResultsInTypeOrder(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a number of problems with this. I am pretty sure we already have something like this.

auto diffResultFnTy = resultFnTy->getWithExtInfo(
resultFnTy->getExtInfo().withDifferentiabilityKind(
DifferentiabilityKind::Normal));
auto newThunkResult = thunkResult.getWithInterfaceType(diffResultFnTy);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is a common pattern here, maybe put the construction of this function onto a helper on SILOptFunctionBuilder? The hope with that was to hide the low level SILFunction creation functions in favor of higher level routines so we don't have to have huge invocations of SILFunctionType::get.

Copy link
Contributor Author

@dan-zheng dan-zheng Apr 11, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code here takes a $(Self) -> (T, ...) -> U function (curry thunk) and clones it into a $(Self) -> @differentiable (T, ...) -> U function.

SILOptFunctionBuilder looks very minimal, so I'm not sure adding an entry point makes sense. We should add a type calculation helper to SILFunctionType, but that seems ad-hoc because this is the only user.

Let me know if you have further suggestions!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dan-zheng I wrote that code and the reason why it is minimal now is that I was trying to create the scaffolding for what I just mentioned. Just no one has had the time to fix it up old uses. That being said, we should not make the problem worse.

}
}

SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really needed? Could we put this onto SILBuilder or something? Seems wrong to have such helper utilities.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would give it a name that makes it explicit that you are using a tuple. Otherwise, how do I know I am getting back a tuple?

Copy link
Contributor Author

@dan-zheng dan-zheng Apr 11, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressing in #30970 (comment), let's continue discussion there.

return builder.createTuple(loc, elements);
}

void extractAllElements(SILValue value, SILBuilder &builder,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is void SILBuilder::emitDestructureValueOperation(SILLocation loc, SILValue operand,
SmallVectorImpl &result);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressing in #30970 (comment), let's continue discussion there.

results.push_back(builder.createTupleExtract(value.getLoc(), value, i));
}

void emitZeroIntoBuffer(SILBuilder &builder, CanType type,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This again doesn't feel like a method that has anything to do with differentiation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, this code currently emits AdditiveArithmetic.zero, but we plan to change it to emit Differentiable.zeroTangentVector sometime (TF-1005), so I'd prefer to keep it differentiation-specific.

Callers of this function could be deduped though: PullbackEmitter::emitZeroIndirect and JVPEmitter::emitZeroIndirect. I can look into that, but I'm not sure how to best use SILBuilderContext & since a SILBasicBlock::iterator is needed as an insertion point: #30818 (comment)

@marcrasi
Copy link

Also, are these used anywhere but the differentiation pass? If not, we should just put it in the differentiation pass until we find multiple uses.

Would moving all the contents of include/SILOptimizer/Utils/Differentiation and lib/SILOptimizer/Utils/Differentiation to a new directory named lib/SILOptimizer/Mandatory/Differentiation be an appropriate way to do this?

I would very much like to avoid moving everything into a single .cpp file because differentiation has many distinct bits of analysis and synthesis that are easier to comprehend separately. (Also having multiple files is really nice for incremental compilation time.)

@dan-zheng
Copy link
Contributor Author

Thanks @gottesmm! I'll create a PR addressing your feedback, we can continue discussion there.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants