Skip to content

[TF] Add bullet operator (•) for matrix multiplication #17173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 14, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions stdlib/public/TensorFlow/Ops.swift
Original file line number Diff line number Diff line change
Expand Up @@ -265,21 +265,14 @@ public func matmul<Scalar : Numeric>(
return Raw.matMul(left, right)
}

infix operator : MultiplicationPrecedence
infix operator : MultiplicationPrecedence

public extension Tensor where Scalar : Numeric {
@_inlineable @inline(__always)
@available(*, renamed: "matmul(_:_:)")
func dot(_ other: Tensor) -> Tensor {
return matmul(self, other)
}

/// Performs matrix multiplication between two tensors and produces the
/// result.
@_inlineable @inline(__always)
@available(*, renamed: "matmul(_:_:)")
static func ⊗ (lhs: Tensor, rhs: Tensor) -> Tensor {
return lhs.dot(rhs)
static func • (lhs: Tensor, rhs: Tensor) -> Tensor {
return matmul(lhs, rhs)
}
}

Expand Down
12 changes: 6 additions & 6 deletions test/TensorFlow/crashers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public func postdom_crash1(w1: Tensor<Float>, inputBatch: Tensor<Float>) {
// expected-warning @-2 {{'inputBatch' implicitly copied to the accelerator}}
let iterationCount = 1000
for _ in 0..<iterationCount {
_ = inputBatch w1 // expected-note 2 {{value used here}}
_ = inputBatch w1 // expected-note 2 {{value used here}}
}
}

Expand Down Expand Up @@ -91,10 +91,10 @@ public func testStraightLineXORTraining() {

// Training loop
for _ in 0..<iterationCount {
let mmul1 = inputBatch w1
let mmul1 = inputBatch w1
let l1 = mmul1 + b1
let o1 = sigmoid(l1)
let mmul2 = o1 w2
let mmul2 = o1 w2
let l2 = mmul2 + b2
let pred = sigmoid(l2)

Expand All @@ -109,15 +109,15 @@ public func testStraightLineXORTraining() {
let dL2 = dPred * pred * (1 - pred)
let dMmul2 = dL2
let dB2 = dL2
let dO1 = dMmul2 w2.transposed(withPermutations: 1, 0)
let dW2 = o1.transposed(withPermutations: 1, 0) dMmul2
let dO1 = dMmul2 w2.transposed(withPermutations: 1, 0)
let dW2 = o1.transposed(withPermutations: 1, 0) dMmul2
let dL1 = dO1 * l1 * (1 - l1)
let dMmul1 = dL1
let dB1 = dL1

// Statically detected shape mismatch!
// expected-error @+1 {{(op: 'MatMul') with input shapes: [4,2], [4,4]}}
let dW1 = inputBatch dMmul1
let dW1 = inputBatch dMmul1

// Descent
w1 -= (dW1 * learningRate)
Expand Down
14 changes: 7 additions & 7 deletions test/TensorFlow/no_copy.swift
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ struct Classifier {
var b2 = Tensor<Float>(zeros: [1, 10])

func prediction(for input: Tensor<Float>) -> Tensor<Float> {
let h1 = sigmoid(input w1 + b1)
return sigmoid(h1 w2 + b2)
let h1 = sigmoid(input w1 + b1)
return sigmoid(h1 w2 + b2)
}

mutating func train(images: Tensor<Float>, labels: Tensor<Float>,
Expand All @@ -177,17 +177,17 @@ struct Classifier {
var epochCount = epochCount
repeat {
// Forward pass
let z1 = images w1 + b1
let z1 = images w1 + b1
let h1 = sigmoid(z1)
let z2 = h1 w2 + b2
let z2 = h1 w2 + b2
let pred = sigmoid(z2)

// Backward pass
let dz2 = pred - labels
let dw2 = h1.transposed(withPermutations: 1, 0) dz2
let dw2 = h1.transposed(withPermutations: 1, 0) dz2
let db2 = dz2.sum(squeezingAxes: 0)
let dz1 = dz2.dot(w2.transposed(withPermutations: 1, 0)) * h1 * (1 - h1)
let dw1 = images.transposed(withPermutations: 1, 0) dz1
let dz1 = matmul(dz2, w2.transposed(withPermutations: 1, 0)) * h1 * (1 - h1)
let dw1 = images.transposed(withPermutations: 1, 0) dz1
let db1 = dz1.sum(squeezingAxes: 0)

// Gradient descent
Expand Down
2 changes: 1 addition & 1 deletion test/TensorFlowRuntime/tensor_debuglog.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ TensorTests.testAllBackends("XWPlusB") {
// Shape: 2
let b = Tensor<Float>([0.5, 0.5])
// Do xW+b!
let result = x w + b
let result = x w + b
expectEqual([1, 2], result.shape)
expectEqual([12.5, 6.5], result.scalars)
}
Expand Down
2 changes: 1 addition & 1 deletion test/TensorFlowRuntime/tensor_xla_debuglog.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ XLATests.test("XWPlusB_XLA") {
// Shape: 2
let b = Tensor<Float>([0.5, 0.5])
// Do xW+b!
let result = x w + b
let result = x w + b
expectEqual([1, 2], result.shape)
expectEqual([12.5, 6.5], result.scalars)
#endif
Expand Down