-
Notifications
You must be signed in to change notification settings - Fork 149
Conversation
MNIST/main.swift
Outdated
@@ -64,7 +67,7 @@ struct Classifier: Layer { | |||
var layer1b = Dense<Float>(inputSize: 128, outputSize: 10, activation: softmax) | |||
|
|||
@differentiable | |||
func applied(to input: Tensor<Float>) -> Tensor<Float> { | |||
public func call(_ input: Input) -> Output { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove public
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done!
MNIST/main.swift
Outdated
@@ -83,7 +86,7 @@ let (images, numericLabels) = readMNIST(imagesFile: "train-images-idx3-ubyte", | |||
let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10) | |||
|
|||
var classifier = Classifier() | |||
let optimizer = RMSProp<Classifier, Float>() | |||
let optimizer = RMSProp(for:classifier) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let optimizer = RMSProp(for:classifier) | |
let optimizer = RMSProp(for: classifier) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
got it!
MNIST/main.swift
Outdated
@@ -53,6 +53,9 @@ func readMNIST(imagesFile: String, labelsFile: String) -> (images: Tensor<Float> | |||
|
|||
/// A classifier. | |||
struct Classifier: Layer { | |||
public typealias Input = Tensor<Float> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove public
everywhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done!
Moves all of `stdlib/TensorFlow` from apple/swift. Friend PR: swiftlang/swift#24452.
@rxwei if this is the correct pattern, I can bang out the other ones!