Skip to content

Commit fe837d7

Browse files
committed
[DOCS] Minor updates to the website.
1 parent 3e10953 commit fe837d7

File tree

4 files changed

+128
-122
lines changed

4 files changed

+128
-122
lines changed

README.md

Lines changed: 94 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ the functionality provided by the official Python API, while at the same type be
1717
features. It is a work in progress and a project I started working on for my personal research purposes. Much of the API
1818
should be relatively stable by now, but things are still likely to change.
1919

20-
[![Chat Room](https://img.shields.io/badge/chat-examples-ed1965.svg?longCache=true&style=flat-square&logo=gitter)](https://gitter.im/eaplatanios/tensorflow_scala?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
20+
[![Chat Room](https://img.shields.io/badge/chat-gitter-ed1965.svg?longCache=true&style=flat-square&logo=gitter)](https://gitter.im/eaplatanios/tensorflow_scala?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
21+
2122
Please refer to the main website for documentation and tutorials. Here
2223
are a few useful links:
2324

@@ -28,97 +29,98 @@ are a few useful links:
2829

2930
## Main Features
3031

31-
- Easy manipulation of tensors and computations involving tensors (similar to NumPy in Python):
32-
33-
```scala
34-
val t1 = Tensor( 1.2, 4.5)
35-
val t2 = Tensor(-0.2, 1.1)
36-
t1 + t2 == Tensor(1.0, 5.6)
37-
```
38-
39-
- High-level API for creating, training, and using neural networks. For example, the following code shows how simple it
40-
is to train a multi-layer perceptron for MNIST using TensorFlow for Scala. Here we omit a lot of very powerful
41-
features such as summary and checkpoint savers, for simplicity, but these are also very simple to use.
42-
43-
```scala
44-
import org.platanios.tensorflow.api._
45-
import org.platanios.tensorflow.api.tf.learn._
46-
import org.platanios.tensorflow.api.ops.training.optimizers.GradientDescent
47-
import org.platanios.tensorflow.data.image.MNISTLoader
48-
49-
// Load and batch data using pre-fetching.
50-
val dataSet = MNISTLoader.load(Paths.get("/tmp"))
51-
val trainImages = tf.data.TensorSlicesDataset(dataSet.trainImages)
52-
val trainLabels = tf.data.TensorSlicesDataset(dataSet.trainLabels)
53-
val trainData =
54-
trainImages.zip(trainLabels)
55-
.repeat()
56-
.shuffle(10000)
57-
.batch(256)
58-
.prefetch(10)
59-
60-
// Create the MLP model.
61-
val input = Input(UINT8, Shape(-1, 28, 28))
62-
val trainInput = Input(UINT8, Shape(-1))
63-
val layer = Flatten("Input/Flatten") >> Cast(FLOAT32) >>
64-
Linear("Layer0", 128) >> ReLU("Layer0/Activation", 0.1f) >>
65-
Linear("Layer1", 64) >> ReLU("Layer1/Activation", 0.1f) >>
66-
Linear("Layer2", 32) >> ReLU("Layer2/Activation", 0.1f) >>
67-
Linear("OutputLayer", 10)
68-
val trainingInputLayer = Cast("TrainInput/Cast", INT64)
69-
val loss = SparseSoftmaxCrossEntropy("Loss/CrossEntropy") >> Mean("Loss/Mean")
70-
val optimizer = GradientDescent(1e-6)
71-
val model = Model(input, layer, trainInput, trainingInputLayer, loss, optimizer)
72-
73-
// Create an estimator and train the model.
74-
val estimator = Estimator(model)
75-
estimator.train(() => trainData, StopCriteria(maxSteps = Some(1000000)))
76-
```
77-
78-
And by changing a few lines to the following code, you can get checkpoint capability, summaries, and seamless
79-
integration with TensorBoard:
80-
81-
```scala
82-
loss = loss >> tf.learn.ScalarSummary("Loss/Summary", "Loss") // Collect loss summaries for plotting
83-
val summariesDir = Paths.get("/tmp/summaries") // Directory in which to save summaries and checkpoints
84-
val estimator = Estimator(model, Configuration(Some(summariesDir)))
85-
estimator.train(
86-
trainData, StopCriteria(maxSteps = Some(1000000)),
87-
Seq(
88-
SummarySaverHook(summariesDir, StepHookTrigger(100)), // Save summaries every 1000 steps
89-
CheckpointSaverHook(summariesDir, StepHookTrigger(1000))), // Save checkpoint every 1000 steps
90-
tensorBoardConfig = TensorBoardConfig(summariesDir)) // Launch TensorBoard server in the background
91-
```
92-
93-
If you now browse to `https://127.0.0.1:6006` while training, you can see the training progress:
94-
95-
<img src="https://eaplatanios.github.io/tensorflow_scala/img/tensorboard_mnist_example_plot.png" alt="tensorboard_mnist_example_plot" width="600px">
96-
97-
- Low-level graph construction API, similar to that of the Python API, but strongly typed wherever possible:
98-
99-
```scala
100-
import org.platanios.tensorflow.api._
101-
102-
val inputs = tf.placeholder(FLOAT32, Shape(-1, 10))
103-
val outputs = tf.placeholder(FLOAT32, Shape(-1, 10))
104-
val predictions = tf.createWith(nameScope = "Linear") {
105-
val weights = tf.variable("weights", FLOAT32, Shape(10, 1), tf.zerosInitializer)
106-
tf.matmul(inputs, weights)
107-
}
108-
val loss = tf.sum(tf.square(predictions - outputs))
109-
val optimizer = tf.train.AdaGrad(1.0)
110-
val trainOp = optimizer.minimize(loss)
111-
```
112-
113-
- Numpy-like indexing/slicing for tensors. For example:
114-
115-
```scala
116-
tensor(2 :: 5, ---, 1) // is equivalent to numpy's 'tensor[2:5, ..., 1]'
117-
```
118-
119-
- Efficient interaction with the native library that avoids unnecessary copying of data. All tensors are created and
120-
managed by the native TensorFlow library. When they are passed to the Scala API (e.g., fetched from a TensorFlow session), we use a combination of weak references and a disposing thread running in the background. Please refer to
121-
`tensorflow/src/main/scala/org/platanios/tensorflow/api/utilities/Disposer.scala`, for the implementation.
32+
- Easy manipulation of tensors and computations involving tensors (similar to NumPy in Python):
33+
34+
```scala
35+
val t1 = Tensor(1.2, 4.5)
36+
val t2 = Tensor(-0.2, 1.1)
37+
t1 + t2 == Tensor(1.0, 5.6)
38+
```
39+
40+
- Low-level graph construction API, similar to that of the Python API, but strongly typed wherever possible:
41+
42+
```scala
43+
val inputs = tf.placeholder[Float](Shape(-1, 10))
44+
val outputs = tf.placeholder[Float](Shape(-1, 10))
45+
val predictions = tf.nameScope("Linear") {
46+
val weights = tf.variable[Float]("weights", Shape(10, 1), tf.ZerosInitializer)
47+
tf.matmul(inputs, weights)
48+
}
49+
val loss = tf.sum(tf.square(predictions - outputs))
50+
val optimizer = tf.train.AdaGrad(1.0f)
51+
val trainOp = optimizer.minimize(loss)
52+
```
53+
54+
- Numpy-like indexing/slicing for tensors. For example:
55+
56+
```scala
57+
tensor(2 :: 5, ---, 1) // is equivalent to numpy's 'tensor[2:5, ..., 1]'
58+
```
59+
60+
- High-level API for creating, training, and using neural networks. For example, the following code shows how simple it
61+
is to train a multi-layer perceptron for MNIST using TensorFlow for Scala. Here we omit a lot of very powerful
62+
features such as summary and checkpoint savers, for simplicity, but these are also very simple to use.
63+
64+
```scala
65+
// Load and batch data using pre-fetching.
66+
val dataset = MNISTLoader.load(Paths.get("/tmp"))
67+
val trainImages = tf.data.datasetFromTensorSlices(dataset.trainImages.toFloat)
68+
val trainLabels = tf.data.datasetFromTensorSlices(dataset.trainLabels.toLong)
69+
val trainData =
70+
trainImages.zip(trainLabels)
71+
.repeat()
72+
.shuffle(10000)
73+
.batch(256)
74+
.prefetch(10)
75+
76+
// Create the MLP model.
77+
val input = Input(FLOAT32, Shape(-1, 28, 28))
78+
val trainInput = Input(INT64, Shape(-1))
79+
val layer = Flatten[Float]("Input/Flatten") >>
80+
Linear[Float]("Layer_0", 128) >> ReLU[Float]("Layer_0/Activation", 0.1f) >>
81+
Linear[Float]("Layer_1", 64) >> ReLU[Float]("Layer_1/Activation", 0.1f) >>
82+
Linear[Float]("Layer_2", 32) >> ReLU[Float]("Layer_2/Activation", 0.1f) >>
83+
Linear[Float]("OutputLayer", 10)
84+
val loss = SparseSoftmaxCrossEntropy[Float, Long, Float]("Loss") >>
85+
Mean("Loss/Mean")
86+
val optimizer = tf.train.GradientDescent(1e-6f)
87+
val model = Model.simpleSupervised(input, trainInput, layer, loss, optimizer)
88+
89+
// Create an estimator and train the model.
90+
val estimator = InMemoryEstimator(model)
91+
estimator.train(() => trainData, StopCriteria(maxSteps = Some(1000000)))
92+
```
93+
94+
And by changing a few lines to the following code, you can get checkpoint capability, summaries, and seamless
95+
integration with TensorBoard:
96+
97+
```scala
98+
val loss = SparseSoftmaxCrossEntropy[Float, Long, Float]("Loss") >>
99+
Mean("Loss/Mean") >>
100+
ScalarSummary(name = "Loss", tag = "Loss")
101+
val summariesDir = Paths.get("/tmp/summaries")
102+
val estimator = InMemoryEstimator(
103+
modelFunction = model,
104+
configurationBase = Configuration(Some(summariesDir)),
105+
trainHooks = Set(
106+
SummarySaver(summariesDir, StepHookTrigger(100)),
107+
CheckpointSaver(summariesDir, StepHookTrigger(1000))),
108+
tensorBoardConfig = TensorBoardConfig(summariesDir))
109+
estimator.train(() => trainData, StopCriteria(maxSteps = Some(100000)))
110+
```
111+
112+
If you now browse to `https://127.0.0.1:6006` while training, you can see the training progress:
113+
114+
<img src="assets/images/tensorboard_mnist_example_plot.png" alt="tensorboard_mnist_example_plot" width="600px">
115+
116+
- Efficient interaction with the native library that avoids unnecessary copying of data. All tensors are created and
117+
managed by the native TensorFlow library. When they are passed to the Scala API (e.g., fetched from a TensorFlow
118+
session), we use a combination of weak references and a disposing thread running in the background. Please refer to
119+
`tensorflow/src/main/scala/org/platanios/tensorflow/api/utilities/Disposer.scala`, for the implementation.
120+
121+
## Tutorials
122+
123+
- [Object Detection using Pre-Trained Models](https://brunk.io/deep-learning-in-scala-part-3-object-detection.html)
122124

123125
## Funding
124126

docs/src/main/paradox/guides/tensors.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@ define and run computations involving tensors. A tensor is
55
a generalization of vectors and matrices to potentially
66
higher dimensions. Internally, TensorFlow represents
77
tensors as `n`-dimensional arrays of some underlying data
8-
type. A @scaladoc[Tensor](org.platanios.tensorflow.api.Tensor)
9-
has a @scaladoc[DataType](org.platanios.tensorflow.api.DataType)
8+
type. A @scaladoc[Tensor](org.platanios.tensorflow.api.tensors.Tensor)
9+
has a @scaladoc[DataType](org.platanios.tensorflow.api.core.types.DataType)
1010
(e.g., `FLOAT32`, which corresponds to 32-bit floating
1111
point numbers) and a
12-
@scaladoc[Shape](org.platanios.tensorflow.api.Shape) (that
12+
@scaladoc[Shape](org.platanios.tensorflow.api.core.Shape) (that
1313
is, the number of dimensions it has and the size of each
1414
dimension -- e.g., `Shape(10, 2)` which corresponds to a
1515
matrix with 10 rows and 2 columns) associated with it. Each
1616
element in the
17-
@scaladoc[Tensor](org.platanios.tensorflow.api.Tensor) has
17+
@scaladoc[Tensor](org.platanios.tensorflow.api.tensors.Tensor) has
1818
the same data type. For example, the following code creates
1919
an integer tensor filled with zeros with shape `[2, 5]`
2020
(i.e., a two-dimensional array holding integer values, where the
@@ -29,7 +29,7 @@ You can print the contents of a tensor as follows:
2929
## Tensor Creation
3030

3131
Tensors can be created using various constructors defined in
32-
the @scaladoc[Tensor](org.platanios.tensorflow.api.Tensor)
32+
the @scaladoc[Tensor](org.platanios.tensorflow.api.tensors.Tensor)
3333
companion object. For example:
3434

3535
@@snip [Tensors.scala](/docs/src/main/scala/Tensors.scala) { #tensor_creation_examples }
@@ -40,7 +40,7 @@ As already mentioned, tensors have a data type. Various
4040
numeric data types are supported, as well as strings (i.e.,
4141
tensors containing strings are supported). It is not
4242
possible to have a
43-
@scaladoc[Tensor](org.platanios.tensorflow.api.Tensor) with
43+
@scaladoc[Tensor](org.platanios.tensorflow.api.tensors.Tensor) with
4444
more than one data type. It is possible, however, to
4545
serialize arbitrary data structures as strings and store
4646
those in tensors.
@@ -90,7 +90,7 @@ A tensor's data type can be inspected using:
9090

9191
In general, all tensor-supported operations can be accessed
9292
as direct methods/operators of the
93-
@scaladoc[Tensor](org.platanios.tensorflow.api.Tensor)
93+
@scaladoc[Tensor](org.platanios.tensorflow.api.tensors.Tensor)
9494
object, or as static methods defined in the
9595
@scaladoc[tfi](org.platanios.tensorflow.api.tfi) package,
9696
which stands for *TensorFlow Imperative*

docs/src/main/paradox/index.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ the functionality provided by the official Python API, while at the same type be
1313
features. It is a work in progress and a project I started working on for my personal research purposes. Much of the API
1414
should be relatively stable by now, but things are still likely to change.
1515

16-
[![Chat Room](https://img.shields.io/badge/chat-examples-ed1965.svg?longCache=true&style=flat-square&logo=gitter)](https://gitter.im/eaplatanios/tensorflow_scala?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
16+
[![Chat Room](https://img.shields.io/badge/chat-gitter-ed1965.svg?longCache=true&style=flat-square&logo=gitter)](https://gitter.im/eaplatanios/tensorflow_scala?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
1717

1818
@@@index
1919

@@ -72,3 +72,5 @@ Funding for the development of this library has been generously provided by the
7272
|:---------------------------------------:|:---------------------------------:|:-----------------------------------------------:|
7373
| awarded to Emmanouil Antonios Platanios | Grant #: IIS1250956 | Grant #: FA95501710218 |
7474
|<img src="assets/images/cmu_logo.svg" height="113px" width="150px" />|<img src="assets/images/nsf_logo.svg" height="150px" width="150px" />|<img src="assets/images/afosr_logo.gif" height="150px" width="150px" />|
75+
76+
TensorFlow, the TensorFlow logo, and any related marks are trademarks of Google Inc.

docs/src/main/scala/Index.scala

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,36 @@ import org.platanios.tensorflow.data.image.MNISTLoader
44

55
import java.nio.file.Paths
66

7-
trait Index {
7+
trait IndexTensorsExample {
88
// #tensors_example
99
val t1 = Tensor(1.2, 4.5)
1010
val t2 = Tensor(-0.2, 1.1)
1111
t1 + t2 == Tensor(1.0, 5.6)
1212
// #tensors_example
13+
}
14+
15+
trait IndexLowLevelExample {
16+
// #low_level_example
17+
val inputs = tf.placeholder[Float](Shape(-1, 10))
18+
val outputs = tf.placeholder[Float](Shape(-1, 10))
19+
val predictions = tf.nameScope("Linear") {
20+
val weights = tf.variable[Float]("weights", Shape(10, 1), tf.ZerosInitializer)
21+
tf.matmul(inputs, weights)
22+
}
23+
val loss = tf.sum(tf.square(predictions - outputs))
24+
val optimizer = tf.train.AdaGrad(1.0f)
25+
val trainOp = optimizer.minimize(loss)
26+
// #low_level_example
27+
}
28+
29+
trait IndexSliceExample {
30+
val tensor = Tensor.zeros[Float](Shape(10, 2, 3, 4, 5, 20))
31+
// #slice_example
32+
tensor(2 :: 5, ---, 1) // is equivalent to numpy's 'tensor[2:5, ..., 1]'
33+
// #slice_example
34+
}
1335

36+
trait IndexMNISTExample {
1437
// #mnist_example
1538
// Load and batch data using pre-fetching.
1639
val dataset = MNISTLoader.load(Paths.get("/tmp"))
@@ -58,24 +81,3 @@ trait IndexTensorBoard extends Index {
5881
estimator.train(() => trainData, StopCriteria(maxSteps = Some(100000)))
5982
// #tensorboard_example
6083
}
61-
62-
trait IndexLowLevelExample {
63-
// #low_level_example
64-
val inputs = tf.placeholder[Float](Shape(-1, 10))
65-
val outputs = tf.placeholder[Float](Shape(-1, 10))
66-
val predictions = tf.nameScope("Linear") {
67-
val weights = tf.variable[Float]("weights", Shape(10, 1), tf.ZerosInitializer)
68-
tf.matmul(inputs, weights)
69-
}
70-
val loss = tf.sum(tf.square(predictions - outputs))
71-
val optimizer = tf.train.AdaGrad(1.0f)
72-
val trainOp = optimizer.minimize(loss)
73-
// #low_level_example
74-
}
75-
76-
trait IndexSliceExample {
77-
val tensor = Tensor.zeros[Float](Shape(10, 2, 3, 4, 5, 20))
78-
// #slice_example
79-
tensor(2 :: 5, ---, 1) // is equivalent to numpy's 'tensor[2:5, ..., 1]'
80-
// #slice_example
81-
}

0 commit comments

Comments
 (0)