@@ -5548,31 +5548,86 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
5548
5548
getSubsetParameters (parameterIndices, diffParams,
5549
5549
/* reverseCurryLevels*/ !makeSelfParamFirst);
5550
5550
5551
- // Get the original semantic result type .
5551
+ // Get the original non-inout semantic result types .
5552
5552
SmallVector<AutoDiffSemanticFunctionResultType, 1 > originalResults;
5553
5553
autodiff::getFunctionSemanticResultTypes (this , originalResults);
5554
5554
// Error if no original semantic results.
5555
5555
if (originalResults.empty ())
5556
5556
return llvm::make_error<DerivativeFunctionTypeError>(
5557
5557
this , DerivativeFunctionTypeError::Kind::NoSemanticResults);
5558
- // Error if multiple original semantic results.
5559
- // TODO(TF-1250): Support functions with multiple semantic results.
5560
- if (originalResults.size () > 1 )
5561
- return llvm::make_error<DerivativeFunctionTypeError>(
5562
- this , DerivativeFunctionTypeError::Kind::MultipleSemanticResults);
5563
- auto originalResult = originalResults.front ();
5564
- auto originalResultType = originalResult.type ;
5565
-
5566
- // Get the original semantic result type's `TangentVector` associated type.
5567
- auto resultTan =
5568
- originalResultType->getAutoDiffTangentSpace (lookupConformance);
5569
- // Error if original semantic result has no tangent space.
5570
- if (!resultTan) {
5558
+ // Accumulate non-inout result tangent spaces.
5559
+ SmallVector<Type, 1 > resultTanTypes;
5560
+ bool hasInoutResult = false ;
5561
+ for (auto i : range (originalResults.size ())) {
5562
+ auto originalResult = originalResults[i];
5563
+ auto originalResultType = originalResult.type ;
5564
+ // Voids currently have a defined tangent vector, so ignore them.
5565
+ if (originalResultType->isVoid ())
5566
+ continue ;
5567
+ if (originalResult.isInout ) {
5568
+ hasInoutResult = true ;
5569
+ continue ;
5570
+ }
5571
+ // Get the original semantic result type's `TangentVector` associated type.
5572
+ auto resultTan =
5573
+ originalResultType->getAutoDiffTangentSpace (lookupConformance);
5574
+ if (!resultTan)
5575
+ continue ;
5576
+ auto resultTanType = resultTan->getType ();
5577
+ resultTanTypes.push_back (resultTanType);
5578
+ }
5579
+ // Append non-wrt inout result tangent spaces.
5580
+ // This uses the logic from getSubsetParameters(), only operating over all
5581
+ // parameter indices and looking for non-wrt indices.
5582
+ SmallVector<AnyFunctionType *, 2 > curryLevels;
5583
+ // An inlined version of unwrapCurryLevels().
5584
+ AnyFunctionType *fnTy = this ;
5585
+ while (fnTy != nullptr ) {
5586
+ curryLevels.push_back (fnTy);
5587
+ fnTy = fnTy->getResult ()->getAs <AnyFunctionType>();
5588
+ }
5589
+
5590
+ SmallVector<unsigned , 2 > curryLevelParameterIndexOffsets (curryLevels.size ());
5591
+ unsigned currentOffset = 0 ;
5592
+ for (unsigned curryLevelIndex : llvm::reverse (indices (curryLevels))) {
5593
+ curryLevelParameterIndexOffsets[curryLevelIndex] = currentOffset;
5594
+ currentOffset += curryLevels[curryLevelIndex]->getNumParams ();
5595
+ }
5596
+
5597
+ if (!makeSelfParamFirst) {
5598
+ std::reverse (curryLevels.begin (), curryLevels.end ());
5599
+ std::reverse (curryLevelParameterIndexOffsets.begin (),
5600
+ curryLevelParameterIndexOffsets.end ());
5601
+ }
5602
+
5603
+ for (unsigned curryLevelIndex : indices (curryLevels)) {
5604
+ auto *curryLevel = curryLevels[curryLevelIndex];
5605
+ unsigned parameterIndexOffset =
5606
+ curryLevelParameterIndexOffsets[curryLevelIndex];
5607
+ for (unsigned paramIndex : range (curryLevel->getNumParams ())) {
5608
+ if (parameterIndices->contains (parameterIndexOffset + paramIndex))
5609
+ continue ;
5610
+
5611
+ auto param = curryLevel->getParams ()[paramIndex];
5612
+ if (param.isInOut ()) {
5613
+ auto resultType = param.getPlainType ();
5614
+ if (resultType->isVoid ())
5615
+ continue ;
5616
+ auto resultTan = resultType->getAutoDiffTangentSpace (lookupConformance);
5617
+ if (!resultTan)
5618
+ continue ;
5619
+ auto resultTanType = resultTan->getType ();
5620
+ resultTanTypes.push_back (resultTanType);
5621
+ }
5622
+ }
5623
+ }
5624
+
5625
+ // Error if no semantic result has a tangent space.
5626
+ if (resultTanTypes.empty () && !hasInoutResult) {
5571
5627
return llvm::make_error<DerivativeFunctionTypeError>(
5572
5628
this , DerivativeFunctionTypeError::Kind::NonDifferentiableResult,
5573
- std::make_pair (originalResultType , /* index*/ 0 ));
5629
+ std::make_pair (originalResults. front (). type , /* index*/ 0 ));
5574
5630
}
5575
- auto resultTanType = resultTan->getType ();
5576
5631
5577
5632
// Compute the result linear map function type.
5578
5633
FunctionType *linearMapType;
@@ -5592,7 +5647,6 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
5592
5647
// - Original: `(T0, inout T1, ...) -> Void`
5593
5648
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
5594
5649
SmallVector<AnyFunctionType::Param, 4 > differentialParams;
5595
- bool hasInoutDiffParameter = false ;
5596
5650
for (auto i : range (diffParams.size ())) {
5597
5651
auto diffParam = diffParams[i];
5598
5652
auto paramType = diffParam.getPlainType ();
@@ -5607,11 +5661,22 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
5607
5661
}
5608
5662
differentialParams.push_back (AnyFunctionType::Param (
5609
5663
paramTan->getType (), Identifier (), diffParam.getParameterFlags ()));
5610
- if (diffParam.isInOut ())
5611
- hasInoutDiffParameter = true ;
5612
5664
}
5613
- auto differentialResult =
5614
- hasInoutDiffParameter ? Type (ctx.TheEmptyTupleType ) : resultTanType;
5665
+ Type differentialResult;
5666
+ if (resultTanTypes.empty ()) {
5667
+ differentialResult = ctx.TheEmptyTupleType ;
5668
+ } else if (resultTanTypes.size () == 1 ) {
5669
+ differentialResult = resultTanTypes.front ();
5670
+ } else {
5671
+ SmallVector<TupleTypeElt, 2 > differentialResults;
5672
+ for (auto i : range (resultTanTypes.size ())) {
5673
+ auto resultTanType = resultTanTypes[i];
5674
+ differentialResults.push_back (
5675
+ TupleTypeElt (resultTanType, Identifier ()));
5676
+ }
5677
+ differentialResult = TupleType::get (differentialResults, ctx);
5678
+ }
5679
+
5615
5680
// FIXME: Verify ExtInfo state is correct, not working by accident.
5616
5681
FunctionType::ExtInfo info;
5617
5682
linearMapType =
@@ -5629,11 +5694,11 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
5629
5694
// - Original: `(T0, inout T1, ...) -> Void`
5630
5695
// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
5631
5696
//
5632
- // Case 3: original function has a wrt `inout` parameter .
5633
- // - Original: `(T0, inout T1, ...) -> Void `
5634
- // - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
5697
+ // Case 3: original function has wrt `inout` parameters .
5698
+ // - Original: `(T0, inout T1, ...) -> R `
5699
+ // - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)`
5635
5700
SmallVector<TupleTypeElt, 4 > pullbackResults;
5636
- bool hasInoutDiffParameter = false ;
5701
+ SmallVector<AnyFunctionType::Param, 2 > inoutParams ;
5637
5702
for (auto i : range (diffParams.size ())) {
5638
5703
auto diffParam = diffParams[i];
5639
5704
auto paramType = diffParam.getPlainType ();
@@ -5647,7 +5712,9 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
5647
5712
std::make_pair (paramType, i));
5648
5713
}
5649
5714
if (diffParam.isInOut ()) {
5650
- hasInoutDiffParameter = true ;
5715
+ if (paramType->isVoid ())
5716
+ continue ;
5717
+ inoutParams.push_back (diffParam);
5651
5718
continue ;
5652
5719
}
5653
5720
pullbackResults.emplace_back (paramTan->getType ());
@@ -5660,12 +5727,27 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
5660
5727
} else {
5661
5728
pullbackResult = TupleType::get (pullbackResults, ctx);
5662
5729
}
5663
- auto flags = ParameterTypeFlags ().withInOut (hasInoutDiffParameter);
5664
- auto pullbackParam =
5665
- AnyFunctionType::Param (resultTanType, Identifier (), flags);
5730
+ // First accumulate non-inout results as pullback parameters.
5731
+ SmallVector<FunctionType::Param, 2 > pullbackParams;
5732
+ for (auto i : range (resultTanTypes.size ())) {
5733
+ auto resultTanType = resultTanTypes[i];
5734
+ auto flags = ParameterTypeFlags ().withInOut (false );
5735
+ pullbackParams.push_back (AnyFunctionType::Param (
5736
+ resultTanType, Identifier (), flags));
5737
+ }
5738
+ // Then append inout parameters.
5739
+ for (auto i : range (inoutParams.size ())) {
5740
+ auto inoutParam = inoutParams[i];
5741
+ auto inoutParamType = inoutParam.getPlainType ();
5742
+ auto inoutParamTan =
5743
+ inoutParamType->getAutoDiffTangentSpace (lookupConformance);
5744
+ auto flags = ParameterTypeFlags ().withInOut (true );
5745
+ pullbackParams.push_back (AnyFunctionType::Param (
5746
+ inoutParamTan->getType (), Identifier (), flags));
5747
+ }
5666
5748
// FIXME: Verify ExtInfo state is correct, not working by accident.
5667
5749
FunctionType::ExtInfo info;
5668
- linearMapType = FunctionType::get ({pullbackParam} , pullbackResult, info);
5750
+ linearMapType = FunctionType::get (pullbackParams , pullbackResult, info);
5669
5751
break ;
5670
5752
}
5671
5753
}
0 commit comments