@@ -5549,32 +5549,43 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
5549
5549
getSubsetParameters (parameterIndices, diffParams,
5550
5550
/* reverseCurryLevels*/ !makeSelfParamFirst);
5551
5551
5552
- // Get the original semantic result type .
5552
+ // Get the original non-inout semantic result types .
5553
5553
SmallVector<AutoDiffSemanticFunctionResultType, 1 > originalResults;
5554
- autodiff::getFunctionSemanticResultTypes (this , originalResults);
5554
+ autodiff::getFunctionSemanticResults (this , parameterIndices , originalResults);
5555
5555
// Error if no original semantic results.
5556
5556
if (originalResults.empty ())
5557
5557
return llvm::make_error<DerivativeFunctionTypeError>(
5558
5558
this , DerivativeFunctionTypeError::Kind::NoSemanticResults);
5559
- // Error if multiple original semantic results.
5560
- // TODO(TF-1250): Support functions with multiple semantic results.
5561
- if (originalResults.size () > 1 )
5562
- return llvm::make_error<DerivativeFunctionTypeError>(
5563
- this , DerivativeFunctionTypeError::Kind::MultipleSemanticResults);
5564
- auto originalResult = originalResults.front ();
5565
- auto originalResultType = originalResult.type ;
5566
-
5567
- // Get the original semantic result type's `TangentVector` associated type.
5568
- auto resultTan =
5569
- originalResultType->getAutoDiffTangentSpace (lookupConformance);
5570
- // Error if original semantic result has no tangent space.
5571
- if (!resultTan) {
5572
- return llvm::make_error<DerivativeFunctionTypeError>(
5559
+
5560
+ // Accumulate non-inout result tangent spaces.
5561
+ SmallVector<Type, 1 > resultTanTypes, inoutTanTypes;
5562
+ for (auto i : range (originalResults.size ())) {
5563
+ auto originalResult = originalResults[i];
5564
+ auto originalResultType = originalResult.type ;
5565
+
5566
+ // Voids currently have a defined tangent vector, so ignore them.
5567
+ if (originalResultType->isVoid ())
5568
+ continue ;
5569
+
5570
+ // Get the original semantic result type's `TangentVector` associated type.
5571
+ // Error if a semantic result has no tangent space.
5572
+ auto resultTan =
5573
+ originalResultType->getAutoDiffTangentSpace (lookupConformance);
5574
+ if (!resultTan)
5575
+ return llvm::make_error<DerivativeFunctionTypeError>(
5573
5576
this , DerivativeFunctionTypeError::Kind::NonDifferentiableResult,
5574
- std::make_pair (originalResultType, /* index*/ 0 ));
5577
+ std::make_pair (originalResultType, unsigned (originalResult.index )));
5578
+
5579
+ if (!originalResult.isInout )
5580
+ resultTanTypes.push_back (resultTan->getType ());
5581
+ else if (originalResult.isInout && !originalResult.isWrtParam )
5582
+ inoutTanTypes.push_back (resultTan->getType ());
5575
5583
}
5576
- auto resultTanType = resultTan->getType ();
5577
5584
5585
+ // Treat non-wrt inouts as semantic results for functions returning Void
5586
+ if (resultTanTypes.empty ())
5587
+ resultTanTypes = inoutTanTypes;
5588
+
5578
5589
// Compute the result linear map function type.
5579
5590
FunctionType *linearMapType;
5580
5591
switch (kind) {
@@ -5587,32 +5598,42 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
5587
5598
//
5588
5599
// Case 2: original function has a non-wrt `inout` parameter.
5589
5600
// - Original: `(T0, inout T1, ...) -> Void`
5590
- // - Differential: `(T0.Tan, ...) -> T1.Tan`
5601
+ // - Differential: `(T0.Tan, ...) -> T1.Tan`
5591
5602
//
5592
5603
// Case 3: original function has a wrt `inout` parameter.
5593
- // - Original: `(T0, inout T1, ...) -> Void`
5594
- // - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
5604
+ // - Original: `(T0, inout T1, ...) -> Void`
5605
+ // - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
5595
5606
SmallVector<AnyFunctionType::Param, 4 > differentialParams;
5596
- bool hasInoutDiffParameter = false ;
5597
5607
for (auto i : range (diffParams.size ())) {
5598
5608
auto diffParam = diffParams[i];
5599
5609
auto paramType = diffParam.getPlainType ();
5600
5610
auto paramTan = paramType->getAutoDiffTangentSpace (lookupConformance);
5601
5611
// Error if parameter has no tangent space.
5602
- if (!paramTan) {
5612
+ if (!paramTan)
5603
5613
return llvm::make_error<DerivativeFunctionTypeError>(
5604
5614
this ,
5605
5615
DerivativeFunctionTypeError::Kind::
5606
5616
NonDifferentiableDifferentiabilityParameter,
5607
5617
std::make_pair (paramType, i));
5608
- }
5618
+
5609
5619
differentialParams.push_back (AnyFunctionType::Param (
5610
5620
paramTan->getType (), Identifier (), diffParam.getParameterFlags ()));
5611
- if (diffParam.isInOut ())
5612
- hasInoutDiffParameter = true ;
5613
5621
}
5614
- auto differentialResult =
5615
- hasInoutDiffParameter ? Type (ctx.TheEmptyTupleType ) : resultTanType;
5622
+ Type differentialResult;
5623
+ if (resultTanTypes.empty ()) {
5624
+ differentialResult = ctx.TheEmptyTupleType ;
5625
+ } else if (resultTanTypes.size () == 1 ) {
5626
+ differentialResult = resultTanTypes.front ();
5627
+ } else {
5628
+ SmallVector<TupleTypeElt, 2 > differentialResults;
5629
+ for (auto i : range (resultTanTypes.size ())) {
5630
+ auto resultTanType = resultTanTypes[i];
5631
+ differentialResults.push_back (
5632
+ TupleTypeElt (resultTanType, Identifier ()));
5633
+ }
5634
+ differentialResult = TupleType::get (differentialResults, ctx);
5635
+ }
5636
+
5616
5637
// FIXME: Verify ExtInfo state is correct, not working by accident.
5617
5638
FunctionType::ExtInfo info;
5618
5639
linearMapType =
@@ -5630,25 +5651,27 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
5630
5651
// - Original: `(T0, inout T1, ...) -> Void`
5631
5652
// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
5632
5653
//
5633
- // Case 3: original function has a wrt `inout` parameter .
5634
- // - Original: `(T0, inout T1, ...) -> Void `
5635
- // - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
5654
+ // Case 3: original function has wrt `inout` parameters .
5655
+ // - Original: `(T0, inout T1, ...) -> R `
5656
+ // - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)`
5636
5657
SmallVector<TupleTypeElt, 4 > pullbackResults;
5637
- bool hasInoutDiffParameter = false ;
5658
+ SmallVector<AnyFunctionType::Param, 2 > inoutParams ;
5638
5659
for (auto i : range (diffParams.size ())) {
5639
5660
auto diffParam = diffParams[i];
5640
5661
auto paramType = diffParam.getPlainType ();
5641
5662
auto paramTan = paramType->getAutoDiffTangentSpace (lookupConformance);
5642
5663
// Error if parameter has no tangent space.
5643
- if (!paramTan) {
5664
+ if (!paramTan)
5644
5665
return llvm::make_error<DerivativeFunctionTypeError>(
5645
5666
this ,
5646
5667
DerivativeFunctionTypeError::Kind::
5647
5668
NonDifferentiableDifferentiabilityParameter,
5648
5669
std::make_pair (paramType, i));
5649
- }
5670
+
5650
5671
if (diffParam.isInOut ()) {
5651
- hasInoutDiffParameter = true ;
5672
+ if (paramType->isVoid ())
5673
+ continue ;
5674
+ inoutParams.push_back (diffParam);
5652
5675
continue ;
5653
5676
}
5654
5677
pullbackResults.emplace_back (paramTan->getType ());
@@ -5661,12 +5684,27 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
5661
5684
} else {
5662
5685
pullbackResult = TupleType::get (pullbackResults, ctx);
5663
5686
}
5664
- auto flags = ParameterTypeFlags ().withInOut (hasInoutDiffParameter);
5665
- auto pullbackParam =
5666
- AnyFunctionType::Param (resultTanType, Identifier (), flags);
5687
+ // First accumulate non-inout results as pullback parameters.
5688
+ SmallVector<FunctionType::Param, 2 > pullbackParams;
5689
+ for (auto i : range (resultTanTypes.size ())) {
5690
+ auto resultTanType = resultTanTypes[i];
5691
+ auto flags = ParameterTypeFlags ().withInOut (false );
5692
+ pullbackParams.push_back (AnyFunctionType::Param (
5693
+ resultTanType, Identifier (), flags));
5694
+ }
5695
+ // Then append inout parameters.
5696
+ for (auto i : range (inoutParams.size ())) {
5697
+ auto inoutParam = inoutParams[i];
5698
+ auto inoutParamType = inoutParam.getPlainType ();
5699
+ auto inoutParamTan =
5700
+ inoutParamType->getAutoDiffTangentSpace (lookupConformance);
5701
+ auto flags = ParameterTypeFlags ().withInOut (true );
5702
+ pullbackParams.push_back (AnyFunctionType::Param (
5703
+ inoutParamTan->getType (), Identifier (), flags));
5704
+ }
5667
5705
// FIXME: Verify ExtInfo state is correct, not working by accident.
5668
5706
FunctionType::ExtInfo info;
5669
- linearMapType = FunctionType::get ({pullbackParam} , pullbackResult, info);
5707
+ linearMapType = FunctionType::get (pullbackParams , pullbackResult, info);
5670
5708
break ;
5671
5709
}
5672
5710
}
0 commit comments