Skip to content

[AutoDiff] [Serialization] Fix '@differentiable(linear)' SIL function type serialization. #27659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -4284,6 +4284,10 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
}

// SWIFT_ENABLE_TENSORFLOW
DifferentiabilityKind getDifferentiabilityKind() const {
return getExtInfo().getDifferentiabilityKind();
}

bool isDifferentiable() const {
return getExtInfo().isDifferentiable();
}
Expand Down
38 changes: 30 additions & 8 deletions lib/Serialization/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4401,6 +4401,23 @@ Optional<swift::ParameterConvention> getActualParameterConvention(uint8_t raw) {
return None;
}

// SWIFT_ENABLE_TENSORFLOW
/// Translate from the serialization DifferentiabilityKind enumerators,
/// which are guaranteed to be stable, to the AST ones.
static Optional<swift::DifferentiabilityKind>
getActualDifferentiabilityKind(uint8_t raw) {
switch (serialization::DifferentiabilityKind(raw)) {
#define CASE(ID) \
case serialization::DifferentiabilityKind::ID: \
return swift::DifferentiabilityKind::ID;
CASE(NonDifferentiable)
CASE(Normal)
CASE(Linear)
#undef CASE
}
return None;
}

/// Translate from the serialization SILParameterDifferentiability enumerators,
/// which are guaranteed to be stable, to the AST ones.
static Optional<swift::SILParameterDifferentiability>
Expand All @@ -4412,8 +4429,8 @@ getActualSILParameterDifferentiability(uint8_t raw) {
CASE(DifferentiableOrNotApplicable)
CASE(NotDifferentiable)
}
return None;
#undef CASE
return None;
}

/// Translate from the serialization ResultConvention enumerators,
Expand Down Expand Up @@ -5010,7 +5027,7 @@ class swift::TypeDeserializer {
bool pseudogeneric = false;
bool noescape;
// SWIFT_ENABLE_TENSORFLOW
bool differentiable;
uint8_t rawDifferentiabilityKind;
bool hasErrorResult;
unsigned numParams;
unsigned numYields;
Expand All @@ -5025,7 +5042,7 @@ class swift::TypeDeserializer {
pseudogeneric,
noescape,
// SWIFT_ENABLE_TENSORFLOW
differentiable,
rawDifferentiabilityKind,
hasErrorResult,
numParams,
numYields,
Expand All @@ -5038,9 +5055,12 @@ class swift::TypeDeserializer {
= getActualSILFunctionTypeRepresentation(rawRepresentation);
if (!representation.hasValue())
MF.fatal();
auto kind = DifferentiabilityKind((unsigned)differentiable);
auto differentiabilityKind =
getActualDifferentiabilityKind(rawDifferentiabilityKind);
if (!differentiabilityKind.hasValue())
MF.fatal();
SILFunctionType::ExtInfo extInfo(*representation, pseudogeneric,
noescape, kind);
noescape, *differentiabilityKind);

// Process the coroutine kind.
auto coroutineKind = getActualSILCoroutineKind(rawCoroutineKind);
Expand All @@ -5065,7 +5085,7 @@ class swift::TypeDeserializer {
// SWIFT_ENABLE_TENSORFLOW
auto paramDiff =
swift::SILParameterDifferentiability::DifferentiableOrNotApplicable;
if (differentiable) {
if (differentiabilityKind != DifferentiabilityKind::NonDifferentiable) {
auto paramDiffOpt =
getActualSILParameterDifferentiability(rawParamDiff);
if (!paramDiffOpt) {
Expand Down Expand Up @@ -5102,7 +5122,9 @@ class swift::TypeDeserializer {

// Bounds check. FIXME: overflow
// SWIFT_ENABLE_TENSORFLOW
unsigned entriesPerParam = differentiable ? 3 : 2;
unsigned entriesPerParam =
differentiabilityKind != DifferentiabilityKind::NonDifferentiable
? 3 : 2;
if (entriesPerParam * numParams + 2 * numResults +
2 * unsigned(hasErrorResult) >
variableData.size()) {
Expand All @@ -5119,7 +5141,7 @@ class swift::TypeDeserializer {
auto rawConvention = variableData[nextVariableDataIndex++];
// SWIFT_ENABLE_TENSORFLOW
uint64_t paramDiff = 0;
if (differentiable)
if (differentiabilityKind != DifferentiabilityKind::NonDifferentiable)
paramDiff = variableData[nextVariableDataIndex++];
auto param = processParameter(typeID, rawConvention, paramDiff);
if (!param)
Expand Down
13 changes: 11 additions & 2 deletions lib/Serialization/ModuleFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
/// describe what change you made. The content of this comment isn't important;
/// it just ensures a conflict if two people change the module format.
/// Don't worry about adhering to the 80-column limit for this line.
const uint16_t SWIFTMODULE_VERSION_MINOR = 521; // remove order from 'differentiation_function' layout
const uint16_t SWIFTMODULE_VERSION_MINOR = 522; // Add SIL function type DifferentiabilityKind field

/// A standard hash seed used for all string hashes in a serialized module.
///
Expand Down Expand Up @@ -327,6 +327,15 @@ enum class ParameterConvention : uint8_t {
using ParameterConventionField = BCFixed<4>;

// SWIFT_ENABLE_TENSORFLOW
// These IDs must \em not be renumbered or reordered without incrementing the
// module version.
enum class DifferentiabilityKind : uint8_t {
NonDifferentiable = 0,
Normal = 1,
Linear = 2
};
using DifferentiabilityKindField = BCFixed<2>;

// These IDs must \em not be renumbered or reordered without incrementing
// the module version.
enum class SILParameterDifferentiability : uint8_t {
Expand Down Expand Up @@ -951,7 +960,7 @@ namespace decls_block {
BCFixed<1>, // pseudogeneric?
BCFixed<1>, // noescape?
// SWIFT_ENABLE_TENSORFLOW
BCFixed<1>, // differentiable?
DifferentiabilityKindField, // differentiability kind
BCFixed<1>, // error result?
BCVBR<6>, // number of parameters
BCVBR<5>, // number of yields
Expand Down
20 changes: 18 additions & 2 deletions lib/Serialization/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3661,6 +3661,19 @@ static uint8_t getRawStableSILCoroutineKind(
llvm_unreachable("bad kind");
}

// SWIFT_ENABLE_TENSORFLOW
/// Translate from the AST differentiability kind enum to the Serialization enum
/// values, which are guaranteed to be stable.
static uint8_t getRawStableDifferentiabilityKind(
swift::DifferentiabilityKind kind) {
switch (kind) {
SIMPLE_CASE(DifferentiabilityKind, NonDifferentiable)
SIMPLE_CASE(DifferentiabilityKind, Normal)
SIMPLE_CASE(DifferentiabilityKind, Linear)
}
llvm_unreachable("bad differentiability kind");
}

/// Translate from the AST ownership enum to the Serialization enum
/// values, which are guaranteed to be stable.
static uint8_t
Expand Down Expand Up @@ -4015,8 +4028,11 @@ class Serializer::TypeSerializer : public TypeVisitor<TypeSerializer> {
using namespace decls_block;

auto representation = fnTy->getRepresentation();
// SWIFT_ENABLE_TENSORFLOW
auto stableRepresentation =
getRawStableSILFunctionTypeRepresentation(representation);
getRawStableSILFunctionTypeRepresentation(representation);
auto stableDifferentiabilityKind =
getRawStableDifferentiabilityKind(fnTy->getDifferentiabilityKind());

SmallVector<TypeID, 8> variableData;
for (auto param : fnTy->getParameters()) {
Expand Down Expand Up @@ -4059,7 +4075,7 @@ class Serializer::TypeSerializer : public TypeVisitor<TypeSerializer> {
stableCoroutineKind, stableCalleeConvention,
stableRepresentation, fnTy->isPseudogeneric(), fnTy->isNoEscape(),
// SWIFT_ENABLE_TENSORFLOW
fnTy->isDifferentiable(), fnTy->hasErrorResult(),
stableDifferentiabilityKind, fnTy->hasErrorResult(),
fnTy->getParameters().size(), fnTy->getNumYields(),
fnTy->getNumResults(), S.addGenericSignatureRef(sig), variableData);

Expand Down
29 changes: 29 additions & 0 deletions test/AutoDiff/differentiable_func_type.sil
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: %empty-directory(%t)
// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name differentiable_func_type
// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name differentiable_func_type
// RUN: %target-sil-opt %t/tmp.2.sib -module-name differentiable_func_type | %FileCheck %s

sil_stage raw

import Swift

sil @takeAndReturnLinear : $@convention(thin) (@differentiable(linear) (Float) -> Float) -> @differentiable(linear) (Float) -> Float {
bb0(%0 : $@differentiable(linear) (Float) -> Float):
return %0 : $@differentiable(linear) (Float) -> Float
}

// CHECK-LABEL: sil @takeAndReturnLinear : $@convention(thin) (@differentiable(linear) (Float) -> Float) -> @differentiable(linear) (Float) -> Float {
// CHECK: bb0([[ARG:%.*]] : $@differentiable(linear) (Float) -> Float):
// CHECK: return [[ARG]] : $@differentiable(linear) (Float) -> Float
// CHECK: }


sil @takeAndReturnDifferentiable : $@convention(thin) (@differentiable (Float) -> Float) -> @differentiable (Float) -> Float {
bb0(%0 : $@differentiable (Float) -> Float):
return %0 : $@differentiable (Float) -> Float
}

// CHECK-LABEL: sil @takeAndReturnDifferentiable : $@convention(thin) (@differentiable (Float) -> Float) -> @differentiable (Float) -> Float {
// CHECK: bb0([[ARG:%.*]] : $@differentiable (Float) -> Float):
// CHECK: return [[ARG]] : $@differentiable (Float) -> Float
// CHECK: }
29 changes: 29 additions & 0 deletions test/AutoDiff/differentiable_function_inst.sil
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,35 @@ sil_stage raw
import Swift
import Builtin

sil @examplefunc : $@convention(thin) (Float, Float, Float) -> Float
sil @examplemethod : $@convention(method) (Float, Float, Float) -> Float

// CHECK-LABEL: sil @test
sil @test : $@convention(thin) () -> () {
bb0:
%0 = function_ref @examplefunc : $@convention(thin) (Float, Float, Float) -> Float
%1 = differentiable_function [wrt 0 1 2] %0 : $@convention(thin) (Float, Float, Float) -> Float

// CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float
%2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float
%3 = differentiable_function [wrt 0] %0 : $@convention(thin) (Float, Float, Float) -> Float

// CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float
%4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float
%5 = function_ref @examplemethod : $@convention(method) (Float, Float, Float) -> Float
%6 = differentiable_function [wrt 0 1 2] %5 : $@convention(method) (Float, Float, Float) -> Float

// CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float
%7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float
%8 = differentiable_function [wrt 0] %5 : $@convention(method) (Float, Float, Float) -> Float

// CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float
%9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float

%ret = tuple ()
return %ret : $()
}

// The adjoint function emitted by the compiler. Parameter are a vector, as in
// vector-Jacobian products, and pullback values. The function is partially
// applied to a pullback struct to form a pullback, which takes a vector and
Expand Down
46 changes: 0 additions & 46 deletions test/AutoDiff/differentiable_sil_function_type_parse.sil

This file was deleted.