Commit ee3a621

Merge pull request #1 from intel-ai-tce/Updating-JAX-README.md
Initial draft for JAX README.md
2 parents 47c231e + f843ecf commit ee3a621

4 files changed: 77 additions and 181 deletions

AI-and-Analytics/Getting-Started-Samples/IntelJAX_GettingStarted/README.md

Lines changed: 63 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -1,59 +1,48 @@
1-
# `TensorFlow* Getting Started` Sample
1+
# `JAX Getting Started` Sample
22

3-
The `TensorFlow* Getting Started` sample demonstrates how to train a TensorFlow* model and run inference on Intel® hardware.
3+
The `JAX Getting Started` sample demonstrates how to train a JAX model and run inference on Intel® hardware.
44
| Property | Description
55
|:--- |:---
66
| Category | Getting Started Sample
7-
| What you will learn | How to start using TensorFlow* on Intel® hardware.
7+
| What you will learn | How to start using JAX* on Intel® hardware.
88
| Time to complete | 10 minutes
99

1010
## Purpose
1111

12-
TensorFlow* is a widely-used machine learning framework in the deep learning arena, demanding efficient computational resource utilization. To take full advantage of Intel® architecture and to extract maximum performance, the TensorFlow* framework has been optimized using Intel® oneDNN primitives. This sample demonstrates how to train an example neural network and shows how Intel-optimized TensorFlow* enables Intel® oneDNN calls by default. Intel-optimized TensorFlow* is available as part of the Intel® AI Tools.
12+
JAX is a high-performance numerical computing library that enables automatic differentiation. It provides features like just-in-time compilation and efficient parallelization for machine learning and scientific computing tasks.
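
As a minimal sketch of the two features named above (automatic differentiation and just-in-time compilation), the snippet below differentiates a small function and compiles the result. The function `sum_of_squares` and the input values are hypothetical and are not part of the sample.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    # Hypothetical scalar-valued function used only for illustration.
    return jnp.sum(x ** 2)


# jax.grad builds a function that returns d(sum_of_squares)/dx;
# jax.jit compiles that function with XLA on first call.
grad_fn = jax.jit(jax.grad(sum_of_squares))

x = jnp.array([1.0, 2.0, 3.0])
g = grad_fn(x)  # analytically, the gradient is 2 * x
print(g)
```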
1313

14-
This sample code shows how to get started with TensorFlow*. It implements an example neural network with one convolution layer and one ReLU layer. You can build and train a TensorFlow* neural network using a simple Python code. Also, by controlling the build-in environment variable, this sample attempts to demonstrate explicitly how Intel® oneDNN Primitives are called and shows their performance during the neural network training.
14+
This sample shows how to get started with JAX on CPU. It defines a simple neural network that trains on the MNIST dataset, using JAX to parallelize computation across multiple CPU cores. The network trains over multiple epochs, evaluates accuracy after each epoch, and updates its parameters using stochastic gradient descent across devices.
1515

1616
## Prerequisites
1717

1818
| Optimized for | Description
1919
|:--- |:---
2020
| OS | Ubuntu* 22.04 and newer
2121
| Hardware | Intel® Xeon® Scalable processor family
22-
| Software | TensorFlow
22+
| Software | JAX
2323

2424
> **Note**: AI and Analytics samples are validated on AI Tools Offline Installer. For the full list of validated platforms refer to [Platform Validation](https://github.com/oneapi-src/oneAPI-samples/tree/master?tab=readme-ov-file#platform-validation).
2525
2626
## Key Implementation Details
2727

28-
The sample includes one python file: TensorFlow_HelloWorld.py. it implements a simple neural network's training and inference
29-
- The training data is generated by `np.random`.
30-
- The neural network with one convolution layer and one ReLU layer is created by `tf.nn.conv2d` and `tf.nn.relu`.
31-
- The TF session is initialized by `tf.global_variables_initializer`.
32-
- The train is implemented via the below for-loop:
33-
```python
34-
for epoch in range(0, EPOCHNUM):
35-
for step in range(0, BS_TRAIN):
36-
x_batch = x_data[step*N:(step+1)*N, :, :, :]
37-
y_batch = y_data[step*N:(step+1)*N, :, :, :]
38-
s.run(train, feed_dict={x: x_batch, y: y_batch})
39-
```
40-
In order to show the harware information, you must export the environment variable `export ONEDNN_VERBOSE=1` to display the deep learning primitives trace during execution.
41-
>**Note**: For convenience, code line os.environ["ONEDNN_VERBOSE"] = "1" has been added in the body of the script as an alternative method to setting this variable.
42-
43-
Runtime settings for `ONEDNN_VERBOSE`, `KMP_AFFINITY`, and `Inter/Intra-op` Threads are set within the script. You can read more about these settings in this dedicated document: *[Maximize TensorFlow* Performance on CPU: Considerations and Recommendations for Inference Workloads](https://software.intel.com/en-us/articles/maximize-tensorflow-performance-on-cpu-considerations-and-recommendations-for-inference)*.
44-
45-
### Run the Sample on Intel® GPUs
46-
The sample code is CPU based, but you can run it using Intel® Extension for TensorFlow* with Intel® Data Center GPU Flex Series. If you are using the Intel GPU, refer to *[Intel GPU Software Installation Guide](https://intel.github.io/intel-extension-for-tensorflow/latest/docs/install/install_for_gpu.html)*. The sample should be able to run on GPU **without any code changes**.
47-
48-
For details, refer to the *[Quick Example on Intel CPU and GPU](https://intel.github.io/intel-extension-for-tensorflow/latest/examples/quick_example.html)* topic of the *Intel® Extension for TensorFlow** documentation.
28+
The getting-started sample uses the Python file `spmd_mnist_classifier_fromscratch.py`, located in the examples directory of the
29+
[jax repository](https://github.com/google/jax/).
30+
It implements training and inference for a simple neural network on MNIST images. The images are downloaded to a temporary directory the first time the example is run.
31+
- **init_random_params** initializes the neural network weights and biases for each layer.
32+
- **predict** computes the forward pass of the network, applying weights, biases, and activations to inputs.
33+
- **loss** calculates the cross-entropy loss between predictions and target labels.
34+
- **spmd_update** performs parallel gradient updates across multiple devices using JAX’s `pmap` and `lax.psum`.
35+
- **accuracy** computes the accuracy of the model by predicting the class of each input in the batch and comparing it to the true target class. It uses the *jnp.argmax* function to find the predicted class and then computes the mean of correct predictions.
36+
- **data_stream** generates batches of shuffled training data. It reshapes the data so that it can be split across multiple cores, ensuring that the batch size is divisible by the number of cores for parallel processing.
37+
- **training loop** trains the model for a set number of epochs, updating parameters and printing training and test accuracy after each epoch. The parameters are replicated across devices and updated in parallel using `spmd_update`. After each epoch, the model’s accuracy is evaluated on both the training and test data using `accuracy`.
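
The functions listed above can be sketched roughly as follows. This is a simplified illustration, not the code from `spmd_mnist_classifier_fromscratch.py`: the layer sizes, `STEP_SIZE`, the helper name `init_params`, and the random stand-in data are all assumptions chosen so the snippet runs without downloading MNIST.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import grad, lax, pmap, random
from jax.scipy.special import logsumexp

STEP_SIZE = 0.001  # assumed learning rate, for illustration only


def init_params(key, sizes, scale=0.1):
    # One (weights, biases) pair per layer.
    keys = random.split(key, len(sizes) - 1)
    return [(scale * random.normal(k, (m, n)), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def predict(params, inputs):
    # Forward pass: tanh hidden layers, log-softmax output.
    activations = inputs
    for w, b in params[:-1]:
        activations = jnp.tanh(jnp.dot(activations, w) + b)
    final_w, final_b = params[-1]
    logits = jnp.dot(activations, final_w) + final_b
    return logits - logsumexp(logits, axis=1, keepdims=True)


def loss(params, batch):
    # Cross-entropy between log-probabilities and one-hot targets.
    inputs, targets = batch
    return -jnp.mean(jnp.sum(predict(params, inputs) * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    predicted = jnp.argmax(predict(params, inputs), axis=1)
    return jnp.mean(predicted == jnp.argmax(targets, axis=1))


@partial(pmap, axis_name="batch")
def spmd_update(params, batch):
    # Each device computes gradients on its shard; psum adds them up
    # so every device applies the same SGD step.
    grads = grad(loss)(params, batch)
    grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
    return jax.tree_util.tree_map(lambda p, g: p - STEP_SIZE * g, params, grads)


num_devices = jax.local_device_count()
params = init_params(random.PRNGKey(0), [8, 16, 4])

# Random stand-in data with a leading device axis: (devices, batch, features).
inputs = random.normal(random.PRNGKey(1), (num_devices, 32, 8))
labels = random.randint(random.PRNGKey(2), (num_devices, 32), 0, 4)
targets = jax.nn.one_hot(labels, 4)

# Replicate the parameters so each device holds its own copy.
replicated = jax.tree_util.tree_map(
    lambda x: jnp.broadcast_to(x, (num_devices,) + x.shape), params)
new_params = spmd_update(replicated, (inputs, targets))
print(accuracy(params, (inputs[0], targets[0])))
```

Every array passed to the `pmap`-wrapped function carries a leading axis equal to the number of devices, which is why the batch size has to be divisible by the device count.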
4938

5039
## Environment Setup
5140

5241
You will need to download and install the following toolkits, tools, and components to use the sample.
5342

5443
**1. Get Intel® AI Tools**
5544

56-
Required AI Tools: 'Intel® Extension for TensorFlow* - CPU'
45+
Required AI Tools: 'JAX'
5746
<br>If you have not already, select and install these Tools via [AI Tools Selector](https://www.intel.com/content/www/us/en/developer/tools/oneapi/ai-tools-selector.html). AI and Analytics samples are validated on AI Tools Offline Installer. It is recommended to select the Offline Installer option in AI Tools Selector.<br>
5847
Please see the [supported versions](https://www.intel.com/content/www/us/en/developer/tools/oneapi/ai-tools-selector.html).
5948

@@ -74,16 +63,14 @@ source <custom_path>/bin/activate
7463

7564
On a system with an Intel CPU:
7665
```
77-
conda activate tensorflow
78-
```
79-
For the system with Intel GPU:
80-
```
81-
conda activate tensorflow-gpu
82-
```
66+
conda activate jax
67+
```
68+
8369
**4. Clone the GitHub repository**
8470
```
85-
git clone https://github.com/oneapi-src/oneAPI-samples.git
86-
cd oneAPI-samples/AI-and-Analytics/Getting-Started-Samples/IntelTensorFlow_GettingStarted
71+
git clone https://github.com/google/jax.git
72+
cd jax
73+
export PYTHONPATH=$PYTHONPATH:$(pwd)
8774
```
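
Before running the sample, it can be useful to check how many devices JAX sees, since the example splits each batch across them. The sketch below is an assumption-laden illustration rather than part of the sample: it uses the XLA option `--xla_force_host_platform_device_count` to ask the CPU backend to expose several host devices, and the value `4` is arbitrary. The flag only takes effect if it is set before JAX initializes its backend.

```python
import os

# Assumption: XLA_FLAGS is not already set in the environment.
# The value 4 is an arbitrary illustrative device count.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax  # imported after setting XLA_FLAGS so the flag can take effect

# List the devices JAX can use and how many are local to this process.
print(jax.devices())
print(jax.local_device_count())
```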
8875
## Run the Sample
8976

@@ -93,55 +80,53 @@ Go to the section which corresponds to the installation method chosen in [AI Too
9380
* [Docker](#docker)
9481
### AI Tools Offline Installer (Validated)/Conda/PIP
9582
```
96-
python TensorFlow_HelloWorld.py
83+
python examples/spmd_mnist_classifier_fromscratch.py
9784
```
9885
### Docker
9986
AI Tools Docker images already have Get Started samples pre-installed. Refer to [Working with Preset Containers](https://github.com/intel/ai-containers/tree/main/preset) to learn how to run the docker and samples.
10087
## Example Output
101-
1. With the initial run, you should see results similar to the following:
102-
103-
```
104-
0 0.4147554
105-
1 0.3561021
106-
2 0.33979267
107-
3 0.33283564
108-
4 0.32920069
109-
[CODE_SAMPLE_COMPLETED_SUCCESSFULLY]
110-
```
111-
2. Export `ONEDNN_VERBOSE` as 1 in the command line. The oneDNN run-time verbose trace should look similar to the following:
112-
```
113-
export ONEDNN_VERBOSE=1
114-
Windows: set ONEDNN_VERBOSE=1
115-
```
116-
>**Note**: The historical environment variables include `DNNL_VERBOSE` and `MKLDNN_VERBOSE`.
117-
118-
3. Run the sample again. You should see verbose results similar to the following:
119-
```
120-
2024-03-12 16:01:59.784340: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type CPU is enabled.
121-
onednn_verbose,info,oneDNN v3.2.0 (commit 8f2a00d86546e44501c61c38817138619febbb10)
122-
onednn_verbose,info,cpu,runtime:OpenMP,nthr:24
123-
onednn_verbose,info,cpu,isa:Intel AVX2 with Intel DL Boost
124-
onednn_verbose,info,gpu,runtime:none
125-
onednn_verbose,info,prim_template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time
126-
onednn_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:cdba::f0 dst_f32:p:blocked:Acdb16a::f0,,,10x4x3x3,0.00195312
127-
onednn_verbose,exec,cpu,convolution,brgconv:avx2,forward_training,src_f32::blocked:acdb::f0 wei_f32:ap:blocked:Acdb16a::f0 bia_f32::blocked:a::f0
128-
dst_f32::blocked:acdb::f0,attr-scratchpad:user attr-post-ops:eltwise_relu ,alg:convolution_direct,mb4_ic4oc10_ih128oh128kh3sh1dh0ph1_iw128ow128kw3sw1dw0pw1,1.19702
129-
onednn_verbose,exec,cpu,eltwise,jit:avx2,backward_data,data_f32::blocked:abcd::f0 diff_f32::blocked:abcd::f0,attr-scratchpad:user ,alg:eltwise_relu alpha:0
130-
beta:0,4x128x128x10,0.112061
131-
onednn_verbose,exec,cpu,convolution,jit:avx2,backward_weights,src_f32::blocked:acdb::f0 wei_f32:ap:blocked:ABcd8b8a::f0 bia_undef::undef:::
132-
dst_f32::blocked:acdb::f0,attr-scratchpad:user ,alg:convolution_direct,mb4_ic4oc10_ih128oh128kh3sh1dh0ph1_iw128ow128kw3sw1dw0pw1,0.358887
133-
...
134-
135-
>**Note**: See the *[oneAPI Deep Neural Network Library Developer Guide and Reference](https://oneapi-src.github.io/oneDNN/dev_guide_verbose.html)* for more details on the verbose log.
88+
1. When the program is run, you should see results similar to the following:
13689

137-
4. Troubleshooting
138-
139-
If you receive an error message, troubleshoot the problem using the **Diagnostics Utility for Intel® oneAPI Toolkits**. The diagnostic utility provides configuration and system checks to help find missing dependencies, permissions errors, and other issues. See the *[Diagnostics Utility for Intel® oneAPI Toolkits User Guide](https://www.intel.com/content/www/us/en/develop/documentation/diagnostic-utility-user-guide/top.html)* for more information on using the utility.
140-
or ask support from https://github.com/intel/intel-extension-for-tensorflow
90+
```
91+
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/
92+
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/
93+
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/
94+
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/
95+
Epoch 0 in 2.71 sec
96+
Training set accuracy 0.7381166815757751
97+
Test set accuracy 0.7516999840736389
98+
Epoch 1 in 2.35 sec
99+
Training set accuracy 0.81454998254776
100+
Test set accuracy 0.8277999758720398
101+
Epoch 2 in 2.33 sec
102+
Training set accuracy 0.8448166847229004
103+
Test set accuracy 0.8568999767303467
104+
Epoch 3 in 2.34 sec
105+
Training set accuracy 0.8626833558082581
106+
Test set accuracy 0.8715999722480774
107+
Epoch 4 in 2.30 sec
108+
Training set accuracy 0.8752999901771545
109+
Test set accuracy 0.8816999793052673
110+
Epoch 5 in 2.33 sec
111+
Training set accuracy 0.8839333653450012
112+
Test set accuracy 0.8899999856948853
113+
Epoch 6 in 2.37 sec
114+
Training set accuracy 0.8908833265304565
115+
Test set accuracy 0.8944999575614929
116+
Epoch 7 in 2.31 sec
117+
Training set accuracy 0.8964999914169312
118+
Test set accuracy 0.8986999988555908
119+
Epoch 8 in 2.28 sec
120+
Training set accuracy 0.9016000032424927
121+
Test set accuracy 0.9034000039100647
122+
Epoch 9 in 2.31 sec
123+
Training set accuracy 0.9060333371162415
124+
Test set accuracy 0.9059999585151672
125+
```
141126

142-
## Related Samples
127+
2. Troubleshooting
143128

144-
* [Intel Extension For TensorFlow Getting Started Sample](https://github.com/oneapi-src/oneAPI-samples/blob/development/AI-and-Analytics/Getting-Started-Samples/Intel_Extension_For_TensorFlow_GettingStarted/README.md)
129+
If you receive an error message, troubleshoot the problem using the **Diagnostics Utility for Intel® oneAPI Toolkits**. The diagnostic utility provides configuration and system checks to help find missing dependencies, permissions errors, and other issues. See the *[Diagnostics Utility for Intel® oneAPI Toolkits User Guide](https://www.intel.com/content/www/us/en/develop/documentation/diagnostic-utility-user-guide/top.html)* for more information on using the utility.
145130

146131
## License
147132

AI-and-Analytics/Getting-Started-Samples/IntelJAX_GettingStarted/TensorFlow_HelloWorld.py

Lines changed: 0 additions & 93 deletions
This file was deleted.
Lines changed: 6 additions & 3 deletions
@@ -1,3 +1,6 @@
1-
source /opt/intel/oneapi/setvars.sh
2-
source activate tensorflow
3-
python TensorFlow_HelloWorld.py
1+
source $HOME/intel/oneapi/intelpython/bin/activate
2+
conda activate jax
3+
git clone https://github.com/google/jax.git
4+
cd jax
5+
export PYTHONPATH=$PYTHONPATH:$(pwd)
6+
python examples/spmd_mnist_classifier_fromscratch.py

AI-and-Analytics/Getting-Started-Samples/IntelJAX_GettingStarted/sample.json

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
{
2-
"guid": "111213A0-C930-45B4-820F-02472BABBF34",
3-
"name": "Intel® Optimization for TensorFlow* Getting Started",
2+
"guid": "9A6A140B-FBD0-4CB2-849A-9CAF15A6F3B1",
3+
"name": "Getting Started example for JAX CPU",
44
"categories": ["Toolkit/oneAPI AI And Analytics/Getting Started"],
5-
"description": "This sample illustrates how to train a TensorFlow model and run inference with oneMKL and oneDNN.",
5+
"description": "This sample illustrates how to train a JAX model and run inference",
66
"builder": ["cli"],
77
"languages": [{
88
"python": {}
@@ -11,11 +11,12 @@
1111
"targetDevice": ["CPU"],
1212
"ciTests": {
1313
"linux": [{
14-
"id": "tensorflow hello world",
14+
"id": "JAX CPU example",
1515
"steps": [
16-
"source /intel/oneapi/intelpython/bin/activate",
17-
"conda activate tensorflow",
18-
"python TensorFlow_HelloWorld.py"
16+
"git clone https://github.com/google/jax.git",
17+
"cd jax",
18+
"conda activate jax",
19+
"python examples/spmd_mnist_classifier_fromscratch.py"
1920
]
2021
}]
2122
},
