This repository was archived by the owner on Jul 1, 2023. It is now read-only.

[Layer] 'func call' -> 'func callAsFunction'. #172

Merged (1 commit, Jun 4, 2019)
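The rename tracks Swift's "callable values" feature: the accepted version of SE-0253 reserves the method name `callAsFunction` as the compiler's hook for applying a value with function-call syntax, replacing the `call` spelling used while the proposal was under review. A minimal, library-free sketch (the `Doubler` type is hypothetical, and a toolchain with SE-0253 support is assumed):

struct Doubler {
    // Declaring `callAsFunction` makes instances of Doubler callable.
    func callAsFunction(_ x: Int) -> Int {
        return x * 2
    }
}

let double = Doubler()
print(double(21))  // 42; sugar for double.callAsFunction(21)

Every `Layer` in the diff below changes in lockstep so that `model(input)` keeps compiling on toolchains that recognize only the new name.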
README.md (2 changes: 1 addition, 1 deletion)
@@ -29,7 +29,7 @@ struct Model: Layer {
     var layer3 = Dense<Float>(inputSize: hiddenSize, outputSize: 3, activation: identity)
 
     @differentiable
-    func call(_ input: Tensor<Float>) -> Tensor<Float> {
+    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
         return input.sequenced(through: layer1, layer2, layer3)
     }
 }
Sources/TensorFlow/Layer.swift (2 changes: 1 addition, 1 deletion)
@@ -31,7 +31,7 @@ public protocol Layer: Differentiable & KeyPathIterable
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    func call(_ input: Input) -> Output
+    func callAsFunction(_ input: Input) -> Output
 }
 
 public extension Layer {
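Because the protocol requirement is now the compiler-recognized name, every conforming layer and model is applied with plain call syntax. A sketch of a conformance (the model, sizes, and random input are made up for illustration; it mirrors the README example above):

import TensorFlow

struct TwoLayerModel: Layer {
    var dense1 = Dense<Float>(inputSize: 4, outputSize: 8, activation: relu)
    var dense2 = Dense<Float>(inputSize: 8, outputSize: 2, activation: identity)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        return input.sequenced(through: dense1, dense2)
    }
}

let model = TwoLayerModel()
let output = model(Tensor<Float>(randomNormal: [1, 4]))  // sugar for model.callAsFunction(...)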
Sources/TensorFlow/Layers/Convolutional.swift (8 changes: 4 additions, 4 deletions)
@@ -59,7 +59,7 @@ public struct Conv1D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer `[batchCount, width, inputChannels]`.
     /// - Returns: The output `[batchCount, newWidth, outputChannels]`.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         let conv2D = input.expandingShape(at: 1).convolved2D(
             withFilter: filter.expandingShape(at: 0), strides: (1, 1, stride, 1), padding: padding)
         return activation(conv2D.squeezingShape(at: 1) + bias)
@@ -177,7 +177,7 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return activation(input.convolved2D(withFilter: filter,
                                             strides: (1, strides.0, strides.1, 1),
                                             padding: padding) + bias)
@@ -293,7 +293,7 @@ public struct Conv3D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return activation(input.convolved3D(withFilter: filter,
                                             strides: (1, strides.0, strides.1, strides.2, 1),
                                             padding: padding) + bias)
@@ -411,7 +411,7 @@ public struct TransposedConv2D: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Float>) -> Tensor<Float> {
+    public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
         let batchSize = input.shape[0]
         let w = (input.shape[1] - (1 * paddingIndex)) *
             strides.0 + (filter.shape[0] * paddingIndex)
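One detail worth noting in the Conv1D hunk: the 1-D convolution is implemented by inserting a dummy height axis, running a 2-D convolution with a height-1 kernel, and squeezing the axis back out. A standalone sketch with made-up sizes (the `randomNormal` initializer and `.same` padding are assumptions, not part of this diff):

import TensorFlow

let input = Tensor<Float>(randomNormal: [8, 32, 3])    // [batchCount, width, inputChannels]
let filter = Tensor<Float>(randomNormal: [5, 3, 16])   // [kernelWidth, inChannels, outChannels]
let conv2D = input.expandingShape(at: 1).convolved2D(  // input becomes rank 4: [8, 1, 32, 3]
    withFilter: filter.expandingShape(at: 0),          // kernel height 1: [1, 5, 3, 16]
    strides: (1, 1, 1, 1), padding: .same)
let output = conv2D.squeezingShape(at: 1)              // back to rank 3: [8, 32, 16]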
Sources/TensorFlow/Layers/Core.swift (8 changes: 4 additions, 4 deletions)
@@ -53,7 +53,7 @@ public struct Dropout<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable(vjp: _vjpApplied(to:))
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         switch Context.local.learningPhase {
         case .training:
             return applyingTraining(to: input)
@@ -92,7 +92,7 @@ public struct Flatten<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         let batchSize = input.shape[0]
         let remaining = input.shape[1..<input.rank].contiguousSize
         return input.reshaped(to: [batchSize, remaining])
@@ -128,7 +128,7 @@ public struct Reshape<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.reshaped(toShape: shape)
     }
 }
@@ -163,7 +163,7 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return activation(matmul(input, weight) + bias)
     }
 }
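For reference, the `Dense` body above is the usual affine map followed by the activation: `activation(matmul(input, weight) + bias)`. A standalone sketch with made-up shapes (`relu` stands in for the stored `activation`, and `randomNormal` is an assumed initializer):

import TensorFlow

let weight = Tensor<Float>(randomNormal: [3, 2])   // [inputSize, outputSize]
let bias = Tensor<Float>(zeros: [2])
let input = Tensor<Float>(randomNormal: [1, 3])    // one example with 3 features
let output = relu(matmul(input, weight) + bias)    // shape [1, 2]; exactly Dense's body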
Sources/TensorFlow/Layers/Normalization.swift (4 changes: 2 additions, 2 deletions)
@@ -90,7 +90,7 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable(vjp: _vjpApplied(to:))
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         switch Context.local.learningPhase {
         case .training:
             return applyingTraining(to: input)
@@ -185,7 +185,7 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         let mean = input.mean(alongAxes: axis)
         let variance = input.variance(alongAxes: axis)
         let inv = rsqrt(variance + epsilon) * scale
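The LayerNorm body above computes standard layer normalization: center and scale along `axis`, then apply the trainable parameters. As a free-function sketch (the `offset` parameter is assumed to be the layer's other trainable tensor, which the truncated hunk does not show):

import TensorFlow

// What LayerNorm's callAsFunction computes, written as a free function.
func layerNorm(_ input: Tensor<Float>, axis: Int, epsilon: Float,
               scale: Tensor<Float>, offset: Tensor<Float>) -> Tensor<Float> {
    let mean = input.mean(alongAxes: axis)
    let variance = input.variance(alongAxes: axis)
    // Normalize, then apply the learned affine transform.
    return (input - mean) * rsqrt(variance + epsilon) * scale + offset
}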
Sources/TensorFlow/Layers/Pooling.swift (24 changes: 12 additions, 12 deletions)
@@ -43,7 +43,7 @@ public struct MaxPool1D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.expandingShape(at: 1).maxPooled2D(
             kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding
         ).squeezingShape(at: 1)
@@ -77,7 +77,7 @@ public struct MaxPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.maxPooled2D(
             kernelSize: poolSize, strides: strides, padding: padding)
     }
@@ -124,7 +124,7 @@ public struct MaxPool3D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.maxPooled3D(kernelSize: poolSize, strides: strides, padding: padding)
     }
 }
@@ -184,7 +184,7 @@ public struct AvgPool1D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.expandingShape(at: 1).averagePooled2D(
             kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding
         ).squeezingShape(at: 1)
@@ -218,7 +218,7 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.averagePooled2D(kernelSize: poolSize, strides: strides, padding: padding)
     }
 }
@@ -264,7 +264,7 @@ public struct AvgPool3D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.averagePooled3D(kernelSize: poolSize, strides: strides, padding: padding)
     }
 }
@@ -304,7 +304,7 @@ public struct GlobalAvgPool1D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.mean(squeezingAxes: 1)
     }
 }
@@ -320,7 +320,7 @@ public struct GlobalAvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.mean(squeezingAxes: [1, 2])
     }
 }
@@ -336,7 +336,7 @@ public struct GlobalAvgPool3D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.mean(squeezingAxes: [1, 2, 3])
     }
 }
@@ -355,7 +355,7 @@ public struct GlobalMaxPool1D<Scalar: TensorFlowFloatingPoint>: Layer {
     ///   phase.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.max(squeezingAxes: 1)
     }
 }
@@ -371,7 +371,7 @@ public struct GlobalMaxPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.max(squeezingAxes: [1, 2])
     }
 }
@@ -387,7 +387,7 @@ public struct GlobalMaxPool3D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.max(squeezingAxes: [1, 2, 3])
     }
 }
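All of the pooling layers gain the same call sugar. A usage sketch (hypothetical sizes; the tuple-based `MaxPool2D` initializer and the `randomNormal` initializer are assumptions about this era of the API):

import TensorFlow

let pool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2), padding: .valid)
let image = Tensor<Float>(randomNormal: [1, 28, 28, 3])  // [batch, height, width, channels]
let pooled = pool(image)  // sugar for pool.callAsFunction(image); shape [1, 14, 14, 3]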
Sources/TensorFlow/Layers/Recurrent.swift (10 changes: 5 additions, 5 deletions)
@@ -62,7 +62,7 @@ public extension RNNCell {
     ///   - previousState: The previous state of the RNN cell.
     /// - Returns: The output.
     @differentiable
-    func call(input: TimeStepInput, state: State) -> RNNCellOutput<TimeStepOutput, State> {
+    func callAsFunction(input: TimeStepInput, state: State) -> RNNCellOutput<TimeStepOutput, State> {
         return self(RNNCellInput(input: input, state: state))
     }
 }
@@ -113,7 +113,7 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNumeric
     /// - Parameter input: The input to the layer.
     /// - Returns: The hidden state.
     @differentiable
-    public func call(_ input: Input) -> Output {
+    public func callAsFunction(_ input: Input) -> Output {
         let concatenatedInput = input.input.concatenated(with: input.state.value, alongAxis: 1)
         let newState = State(tanh(matmul(concatenatedInput, weight) + bias))
         return Output(output: newState, state: newState)
@@ -175,7 +175,7 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNumeric
     /// - Parameter input: The input to the layer.
     /// - Returns: The hidden state.
     @differentiable
-    public func call(_ input: Input) -> Output {
+    public func callAsFunction(_ input: Input) -> Output {
         let gateInput = input.input.concatenated(with: input.state.hidden, alongAxis: 1)
 
         let inputGate = sigmoid(matmul(gateInput, inputWeight) + inputBias)
@@ -203,7 +203,7 @@ public struct RNN<Cell: RNNCell>: Layer {
     }
 
     @differentiable(wrt: (self, input), vjp: _vjpCall(_:initialState:))
-    public func call(_ input: [Cell.TimeStepInput],
+    public func callAsFunction(_ input: [Cell.TimeStepInput],
                      initialState: Cell.State) -> [Cell.TimeStepOutput] {
         var currentHiddenState = initialState
         var timeStepOutputs: [Cell.TimeStepOutput] = []
@@ -253,7 +253,7 @@ public struct RNN<Cell: RNNCell>: Layer {
     }
 
     @differentiable(wrt: (self, inputs))
-    public func call(_ inputs: [Cell.TimeStepInput]) -> [Cell.TimeStepOutput] {
+    public func callAsFunction(_ inputs: [Cell.TimeStepInput]) -> [Cell.TimeStepOutput] {
         return self(inputs, initialState: cell.zeroState.withoutDerivative())
     }
 
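Note that the first hunk in this file keeps its argument labels: `callAsFunction(input:state:)` lets a cell be stepped without building an `RNNCellInput` by hand. A single-timestep sketch (the `inputSize:hiddenSize:` initializer and the sizes are assumptions; `zeroState` is the property the `RNN` code above uses):

import TensorFlow

let cell = SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 8)
var state = cell.zeroState
let x = Tensor<Float>(randomNormal: [1, 4])
let step = cell(input: x, state: state)  // sugar for cell.callAsFunction(input: x, state: state)
state = step.state                       // carry the new hidden state to the next timestep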
Sources/TensorFlow/Layers/Upsampling.swift (6 changes: 3 additions, 3 deletions)
@@ -29,7 +29,7 @@ public struct UpSampling1D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         let shape = input.shape
         let (batchSize, timesteps, channels) = (shape[0], shape[1], shape[2])
         let scaleOnes = Tensor<Scalar>(ones: [1, 1, size, 1])
@@ -55,7 +55,7 @@ public struct UpSampling2D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         let shape = input.shape
         let (batchSize, height, width, channels) = (shape[0], shape[1], shape[2], shape[3])
         let scaleOnes = Tensor<Scalar>(ones: [1, 1, size, 1, size, 1])
@@ -107,7 +107,7 @@ public struct UpSampling3D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Parameter input: The input to the layer.
     /// - Returns: The output.
     @differentiable
-    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
         var result = repeatingElements(input, alongAxis: 1, count: size)
         result = repeatingElements(result, alongAxis: 2, count: size)
         result = repeatingElements(result, alongAxis: 3, count: size)
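A side note on the UpSampling2D hunk: upsampling is done by broadcasting against a ones tensor, which turns each pixel into a size×size block, and then reshaping the blocks back into the spatial axes. A shape walk-through with made-up sizes (the reshape steps after `scaleOnes` are cut off in the hunk above, so this reconstruction is an assumption):

import TensorFlow

let size = 2
let input = Tensor<Float>(randomNormal: [1, 4, 4, 3])         // [batch, height, width, channels]
let scaleOnes = Tensor<Float>(ones: [1, 1, size, 1, size, 1])
// [1, 4, 1, 4, 1, 3] * [1, 1, 2, 1, 2, 1] broadcasts to [1, 4, 2, 4, 2, 3].
let blocks = input.reshaped(to: [1, 4, 1, 4, 1, 3]) * scaleOnes
let output = blocks.reshaped(to: [1, 4 * size, 4 * size, 3])  // [1, 8, 8, 3]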
Tests/TensorFlowTests/SequentialTests.swift (2 changes: 1 addition, 1 deletion)
@@ -24,7 +24,7 @@ final class SequentialTests: XCTestCase {
                                   seed: (0xeffeffe, 0xfffe))
 
         @differentiable
-        func call(_ input: Tensor<Float>) -> Tensor<Float> {
+        func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
             return input.sequenced(through: dense1, dense2)
         }
     }
Tests/TensorFlowTests/TrivialModelTests.swift (2 changes: 1 addition, 1 deletion)
@@ -34,7 +34,7 @@ final class TrivialModelTests: XCTestCase {
             )
         }
         @differentiable
-        func call(_ input: Tensor<Float>) -> Tensor<Float> {
+        func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
            let h1 = l1(input)
            return l2(h1)
        }