Skip to content

Commit 42e41f1

Browse files
first implementation of TensorArray and preprocessor layer
1 parent 0387fd2 commit 42e41f1

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
import { TensorArray } from '../../tfcpatches/TensorArray';
3+
import { whileLayer } from './whileLayer';
4+
5+
export type PreprocessorParams = any
6+
7+
// TODO: hardcoded params
8+
const elementShape = [512, 512, 3]
9+
const weight = tf.scalar(0.007843137718737125)
10+
const bias = tf.scalar(1)
11+
12+
export function preprocessor(imgTensor: tf.Tensor4D, params: PreprocessorParams) {
13+
const batchSize = imgTensor.shape[0]
14+
const batchSizedArray1 = new TensorArray(batchSize, 'float32')
15+
const batchSizedArray2 = new TensorArray(batchSize, 'float32')
16+
17+
let unusedFlow = null
18+
const indices = tf.range(0, batchSize, 1)
19+
20+
// unstack
21+
unusedFlow = batchSizedArray1.scatter(indices, imgTensor, unusedFlow)
22+
23+
unusedFlow = whileLayer(batchSizedArray1, batchSizedArray2, batchSize, unusedFlow)
24+
25+
// stack
26+
const stacked = batchSizedArray2.gather(indices, unusedFlow, 'float32', elementShape)
27+
28+
return tf.add(tf.mul(stacked, weight), bias)
29+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
import { TensorArray } from '../../tfcpatches/TensorArray';
3+
4+
export function whileLayer(arr1: TensorArray, arr2: TensorArray, batchSize: number, unusedFlowIn: tf.Scalar): tf.Scalar {
5+
// TODO
6+
7+
const unusedFlowOut = tf.scalar(0)
8+
return unusedFlowOut
9+
}

src/tfcpatches/TensorArray.ts

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
export class TensorArray {
4+
private _tensors: tf.Tensor[] | undefined
5+
6+
constructor(
7+
private _size: number,
8+
private _dtype: tf.DataType = null,
9+
private _elementShape: number[] = null,
10+
private _dynamicSize: boolean = false,
11+
private _clearAfterRead: boolean = true,
12+
private _identicalElementShapes: boolean = false,
13+
private _tensorArrayName: string = null
14+
) {
15+
if (_size) {
16+
this._tensors = Array(_size).fill(0).map(_ => tf.scalar(0))
17+
}
18+
}
19+
20+
public scatter(indices: tf.Tensor1D, value: tf.Tensor, unusedFlow: tf.Scalar): tf.Scalar {
21+
if (indices.shape.length !== 1) {
22+
throw new Error(`scatter - expected rank of indices (${indices.shape.length}) to be 1`)
23+
}
24+
if (indices.shape[0] > this._size) {
25+
throw new Error(`scatter - expected indices.shape[0] (${indices.shape[0]}) to be >= this._size (${this.size})`)
26+
}
27+
if (indices.shape[0] !== value.shape[0]) {
28+
throw new Error(`scatter - expected indices.shape[0] (${indices.shape[0]}) to equal value.shape[0] (${value.shape[0]})`)
29+
}
30+
31+
const unstacked = tf.unstack(value, 0)
32+
Array.from(indices.dataSync()).forEach((idx, i) => {
33+
this._tensors[idx] = unstacked[i]
34+
})
35+
36+
const unusedFlowOut = tf.scalar(0)
37+
return unusedFlowOut
38+
}
39+
40+
public gather(indices: tf.Tensor1D, unusedFlow: tf.Scalar, dtype?: tf.DataType, elementShape?: number[]) : tf.Tensor {
41+
const tensors = Array.from(indices.dataSync()).map(idx => this._tensors[idx])
42+
return tf.concat(tensors)
43+
}
44+
45+
public size(unusedFlow: tf.Scalar) {
46+
return this._size
47+
}
48+
}

0 commit comments

Comments
 (0)