You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: _mobile/android.md
+310-5
Original file line number
Diff line number
Diff line change
@@ -10,11 +10,316 @@ published: true
10
10
11
11
# Android
12
12
13
-
{% highlight python %}
13
+
## Quick start with Hello World
14
14
15
-
#!/usr/bin/python3
16
-
print('Hello World!')
15
+
The easiest way to start playing with pytorch android is to checkout on githyb our ['Hello World' application](https://github.com/pytorch/android-demo-app/tree/master/HelloWorldApp)
17
16
18
-
{% endhighlight %}
17
+
This application runs [torchscript serialized](https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/trace_model.py)
18
+
[torch vision pretrained resnet18 model](https://pytorch.org/docs/stable/torchvision/models.html) on [image.jpg](https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/app/src/main/assets/image.jpg) which packaged inside the app as android asset.
If [android sdk]() and [android ndk]() is already installed you can install this application to connected android device or emulator with:
25
+
```
26
+
./gradlew installDebug
27
+
```
28
+
29
+
We recommend you to open this project in [Android Studio](https://developer.android.com/studio),
30
+
in that case you will be able to install android ndk and android sdk using Android Studio UI.
31
+
32
+
The easiest way to add 'pytorch android' to the app is adding [gradle dependencies](https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/app/build.gradle#L28-L29) to your build.gradle:
Where org.pytorch:pytorch_android is the main dependency with android api, including libtorch native library for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64).
45
+
Further in this doc you can find how to rebuild it only for specific list of android abis.
46
+
47
+
org.pytorch:pytorch_android_torchvision - library with several utility functions for converting `android.media.Image` and `android.graphics.Bitmap` to tensors.
48
+
49
+
All logic happens in [org.pytorch.helloworld.MainActivity](https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/app/src/main/java/org/pytorch/helloworld/MainActivity.java#L31-L69):
`org.pytorch.Module` represents `torch::jit::script::Module` that can be loaded with `load` method specifying file path to the serialized to file model.
`org.pytorch.torchvision.TensorImageUtils` is part of 'org.pytorch:pytorch_android_torchvision' library.
65
+
`TensorImageUtils#bitmapToFloat32Tensor` method creates tensor in [torch vision format](https://pytorch.org/docs/stable/torchvision/models.html) using `android.graphics.Bitmap` as a source.
66
+
67
+
> All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224.
68
+
> The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]
69
+
70
+
`inputTensor`'s shape is 1x3xHxW, where H and W are bitmap height and width appropriately.
`org.pytorch.Module#forward` method runs loaded module's `forward` method and gets result as `org.pytorch.Tensor` outputTensor with shape 1x1000.
78
+
It's content is retrieved using `org.pytorch.Tensor#getDataAsFloatArray()` method that returns java array of floats with scores for every image net class.
79
+
80
+
After that we just find index with maximum score and retrieve predicted class name from `ImageNetClasses.IMAGENET_CLASSES` array that contains all ImageNet classes.
Now you are ready to start using other torchsript models on android, changing [`model.pt`](https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/app/src/main/assets/model.pt) and rebuilding the project :)
95
+
96
+
To prepare serialize the model you can use python [script](https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/trace_model.py):
97
+
```
98
+
import torch
99
+
import torchvision
100
+
101
+
model = torchvision.models.resnet18(pretrained=True)
More details about torchscript you can find in [tutorials and docs on pytorch.org](https://pytorch.org/docs/stable/jit.html)
109
+
110
+
In the following sections you can find detailed explanation of pytorch android api, code walk through for bigger [demo application](https://github.com/pytorch/android-demo-app/tree/master/PyTorchDemoApp), implementation details of api and how to customize and build it from the source.
111
+
112
+
## Pytorch demo app
113
+
114
+
Bigger example of application that does image classification from android camera output and text classification you can find in the [same github repo](https://github.com/pytorch/android-demo-app/tree/master/PyTorchDemoApp).
115
+
116
+
To get device camera output in it uses [android cameraX api](https://developer.android.com/training/camerax
void analyzeImage(android.media.Image, int rotationDegrees)
142
+
```
143
+
144
+
Where [analyzeImage](https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/java/org/pytorch/demo/vision/ImageClassificationActivity.java#L128) method analyzes camera output, `android.media.Image`.
145
+
146
+
It uses aforementioned [`TensorImageUtils#imageYUV420CenterCropToFloat32Tensor`](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java#L90) method to convert `android.media.Image` in `YUV420` format to input tensor.
147
+
148
+
After getting predicted scores from the model it [finds top K classes](https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/java/org/pytorch/demo/vision/ImageClassificationActivity.java#L153-L161) with the highest scores and shows on the UI.
149
+
150
+
## Building pytorch android from source
151
+
152
+
In some cases you might want to use a local build of pytorch android, for example you may build custom libtorch binary with another set of operators or to make local changes.
153
+
154
+
For this you can use `./scripts/build_pytorch_android.sh` script.
155
+
```
156
+
git clone https://github.com/pytorch/pytorch.git
157
+
cd pytorch
158
+
sh ./scripts/build_pytorch_android.sh
159
+
```
160
+
161
+
Its workflow contains several steps:
162
+
1. Builds libtorch for android for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64)
163
+
2. Creates symbolic links to the results of those builds:
164
+
`android/pytorch_android/src/main/jniLibs/${abi}` to the directory with output libraries
165
+
`android/pytorch_android/src/main/cpp/libtorch_include/${abi}` to the directory with headers. These directories are used to build `libpytorch.so` library that will be loaded on android device.
166
+
3. And finally runs `gradle` in `android/pytorch_android` directory with task `assembleRelease`
167
+
168
+
Script requires that android sdk, android ndk and gradle are installed.
169
+
They are specified as environment variables:
170
+
171
+
`ANDROID_HOME` - path to [android sdk](https://developer.android.com/studio/command-line/sdkmanager.html)
172
+
173
+
`ANDROID_NDK` - path to [android ndk](https://developer.android.com/studio/projects/install-ndk)
174
+
175
+
`GRADLE_HOME` - path to [gradle](https://gradle.org/releases/)
176
+
177
+
178
+
After successful build you should see the result as aar file:
At the moment for the case of using aar files directly we need additional configuration due to packaging specific (libfbjni.so is packaged in both pytorch_android_fbjni.aar and pytorch_android.aar).
213
+
```
214
+
packagingOptions {
215
+
pickFirst "**/libfbjni.so"
216
+
}
217
+
```
218
+
219
+
## API Details
220
+
221
+
Main part of java api includes 3 classes:
222
+
```
223
+
org.pytorch.Module
224
+
org.pytorch.IValue
225
+
org.pytorch.Tensor
226
+
```
227
+
228
+
If the reader is familiar with pytorch python api, we can think that org.pytorch.Tensor represents torch.tensor, org.pytorch.Module torch.Module<?>, while org.pytorch.IValue represents value of torchscript variable, supporting all its types. ( https://pytorch.org/docs/stable/jit.html#types )
Where the first parameter `long[] shape` is shape of the Tensor as array of longs.
259
+
260
+
Content of the Tensor can be provided either as (a) java array or (b) as java.nio.DirectByteBuffer of proper type with native bit order.
261
+
262
+
In case of (a) proper DirectByteBuffer will be created internally. (b) case has an advantage that user can keep the reference to DirectByteBuffer and change its content in future for the next run, avoiding allocation of DirectByteBuffer for repeated runs.
263
+
264
+
Java’s primitive type byte is signed and java does not have unsigned 8 bit type. For dtype=uint8 api uses byte that will be reinterpretted as uint8 on native side. On java side unsigned value of byte can be read as (byte & 0xFF).
265
+
266
+
#### Tensor content layout
267
+
268
+
Tensor content is represented as a one dimensional array (buffer),
269
+
where the first element has all zero indexes T\[0, ... 0\].
270
+
271
+
Lets assume tensor shape is {d<sub>0</sub>, ... d<sub>n-1</sub>} and d<sub>n-1</sub> > 0.
272
+
273
+
The second element will be T\[0, ... 1\] and the last one T\[d<sub>0</sub>-1, ... d<sub>n-1</sub> - 1\]
274
+
275
+
Tensor has methods to check its dtype:
276
+
```
277
+
int dtype()
278
+
```
279
+
That returns one of the dtype codes:
280
+
```
281
+
Tensor.DTYPE_UINT8
282
+
Tensor.DTYPE_INT8
283
+
Tensor.DTYPE_INT32
284
+
Tensor.DTYPE_FLOAT32
285
+
Tensor.DTYPE_INT64
286
+
Tensor.DTYPE_FLOAT64
287
+
```
288
+
289
+
The data of Tensor can be read as java array:
290
+
```
291
+
byte[] getDataAsUnsignedByteArray()
292
+
byte[] getDataAsByteArray()
293
+
int[] getDataAsIntArray()
294
+
long[] getDataAsLongArray()
295
+
float[] getDataAsFloatArray()
296
+
double[] getDataAsDoubleArray()
297
+
```
298
+
These methods throw IllegalStateException if called for inappropriate dtype.
IValue represents a torchscript variable that can be one of the supported (by torchscript) types ( https://pytorch.org/docs/stable/jit.html#types ). IValue is a tagged union. For every supported type it has a factory method, method to check the type and a getter method to retrieve a value.
304
+
Getters throw IllegalStateException if called for inappropriate type.
Module is a wrapper of torch.jit.ScriptModule (`torch::jit::script::Module` in pytorch c++ api) which can be constructed with factory method load providing absolute path to the file with serialized torchscript.
for running a particular method of the script module.
314
+
```
315
+
IValue IValue.forward(IValue... inputs)
316
+
```
317
+
Shortcut to run 'forward' method.
318
+
319
+
```
320
+
IValue IValue.destroy()
321
+
```
322
+
Explicitly destructs native (C++) part of the Module, `torch::jit::script::Module`.
323
+
324
+
As fbjni library destructs native part automatically when current `org.pytorch.Module` instance will be collected by Java GC, the instance will not leak if this method is not called, but timing of deletion and the thread will be at the whim of the Java GC. If you want to control the thread and timing of the destructor, you should call this method explicitly.
0 commit comments