Skip to content

Commit 2900e67

Browse files
authored
Merge pull request #67219 from hning86/patch-57
Update how-to-train-tensorflow.md
2 parents 167cf31 + e392e55 commit 2900e67

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

articles/machine-learning/service/how-to-train-tensorflow.md

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ ms.date: 12/04/2018
1313
ms.custom: seodec18
1414
---
1515

16-
# Train TensorFlow models with Azure Machine Learning service
16+
# Train TensorFlow and Keras models with Azure Machine Learning service
1717

1818
For deep neural network (DNN) training using TensorFlow, Azure Machine Learning provides a custom `TensorFlow` class of the `Estimator`. The Azure SDK's `TensorFlow` estimator (not to be conflated with the [`tf.estimator.Estimator`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator) class) enables you to easily submit TensorFlow training jobs for both single-node and distributed runs on Azure compute.
1919

@@ -34,7 +34,7 @@ tf_est = TensorFlow(source_directory='./my-tf-proj',
3434
script_params=script_params,
3535
compute_target=compute_target,
3636
entry_script='train.py',
37-
conda_packages=['scikit-learn'],
37+
conda_packages=['scikit-learn'], # in case you need scikit-learn in train.py
3838
use_gpu=True)
3939
```
4040

@@ -56,6 +56,21 @@ Then, submit the TensorFlow job:
5656
run = exp.submit(tf_est)
5757
```
5858

59+
## Keras support
60+
[Keras](https://keras.io/) is a popular high-level DNN Python API that supports TensorFlow, CNTK or Theano as backends. If you use TensorFlow as backend, you can easily use the TensFlow estimator to train a Keras model. Here is an example of a TensorFlow estimator with Keras added to it:
61+
62+
```Python
63+
from azureml.train.dnn import TensorFlow
64+
65+
keras_est = TensorFlow(source_directory='./my-keras-proj',
66+
script_params=script_params,
67+
compute_target=compute_target,
68+
entry_script='keras_train.py',
69+
conda_packages=['keras'], # just add keras through conda
70+
use_gpu=True)
71+
```
72+
The above TensorFlow estimator constructor simply instructs Azure Machine Learning service to install Keras through Conda to the execution environment. And your `keras_train.py` can then import Keras API to train a Keras model. For a complete example, please see [this Jupyter notebook](https://github.com/Azure/MachineLearningNotebooks/blob/master/how-to-use-azureml/training-with-deep-learning/train-hyperparameter-tune-deploy-with-keras/train-hyperparameter-tune-deploy-with-keras.ipynb).
73+
5974
## Distributed training
6075
The TensorFlow Estimator also enables you to train your models at scale across CPU and GPU clusters of Azure VMs. You can easily run distributed TensorFlow training with a few API calls, while Azure Machine Learning will manage behind the scenes all the infrastructure and orchestration needed to carry out these workloads.
6176

articles/machine-learning/service/toc.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898
href: how-to-train-ml-models.md
9999
- name: Use PyTorch
100100
href: how-to-train-pytorch.md
101-
- name: Use TensorFlow
101+
- name: Use TensorFlow and Keras
102102
href: how-to-train-tensorflow.md
103103
- name: Tune hyperparameters
104104
displayName: parameter

0 commit comments

Comments
 (0)