Skip to content

Commit 160fbcc

Browse files
added NeuralNetwork base class which provides some common functionality, such as making networks trainable
1 parent 57b51bd commit 160fbcc

File tree

8 files changed

+425
-75
lines changed

8 files changed

+425
-75
lines changed

src/commons/NeuralNetwork.ts

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { ParamMapping } from './types';
4+
5+
export class NeuralNetwork<TNetParams> {
6+
7+
protected _params: TNetParams | undefined = undefined
8+
protected _paramMappings: ParamMapping[] = []
9+
10+
public get params(): TNetParams | undefined {
11+
return this._params
12+
}
13+
14+
public get paramMappings(): ParamMapping[] {
15+
return this._paramMappings
16+
}
17+
18+
public getParamFromPath(paramPath: string): tf.Tensor {
19+
const { obj, objProp } = this.traversePropertyPath(paramPath)
20+
return obj[objProp]
21+
}
22+
23+
public reassignParamFromPath(paramPath: string, tensor: tf.Tensor) {
24+
const { obj, objProp } = this.traversePropertyPath(paramPath)
25+
obj[objProp].dispose()
26+
obj[objProp] = tensor
27+
}
28+
29+
public getParamList() {
30+
return this._paramMappings.map(({ paramPath }) => ({
31+
path: paramPath,
32+
tensor: this.getParamFromPath(paramPath)
33+
}))
34+
}
35+
36+
public getTrainableParams() {
37+
return this.getParamList().filter(param => param.tensor instanceof tf.Variable)
38+
}
39+
40+
public getFrozenParams() {
41+
return this.getParamList().filter(param => !(param.tensor instanceof tf.Variable))
42+
}
43+
44+
public variable() {
45+
this.getFrozenParams().forEach(({ path, tensor }) => {
46+
this.reassignParamFromPath(path, tf.variable(tensor))
47+
})
48+
}
49+
50+
public freeze() {
51+
this.getTrainableParams().forEach(({ path, tensor }) => {
52+
this.reassignParamFromPath(path, tf.tensor(tensor as any))
53+
})
54+
}
55+
56+
public dispose() {
57+
this.getParamList().forEach(param => param.tensor.dispose())
58+
this._params = undefined
59+
}
60+
61+
private traversePropertyPath(paramPath: string) {
62+
if (!this.params) {
63+
throw new Error(`traversePropertyPath - model has no loaded params`)
64+
}
65+
66+
const result = paramPath.split('/').reduce((res: { nextObj: any, obj?: any, objProp?: string }, objProp) => {
67+
if (!res.nextObj.hasOwnProperty(objProp)) {
68+
throw new Error(`traversePropertyPath - object does not have property ${objProp}, for path ${paramPath}`)
69+
}
70+
71+
return { obj: res.nextObj, objProp, nextObj: res.nextObj[objProp] }
72+
}, { nextObj: this.params })
73+
74+
const { obj, objProp } = result
75+
if (!obj || !objProp || !(obj[objProp] instanceof tf.Tensor)) {
76+
throw new Error(`traversePropertyPath - parameter is not a tensor, for path ${paramPath}`)
77+
}
78+
79+
return { obj, objProp }
80+
}
81+
}

src/commons/extractConvParamsFactory.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3-
import { ConvParams, ExtractWeightsFunction } from './types';
3+
import { ConvParams, ExtractWeightsFunction, ParamMapping } from './types';
44

5-
export function extractConvParamsFactory(extractWeights: ExtractWeightsFunction) {
5+
export function extractConvParamsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
66
return function (
77
channelsIn: number,
88
channelsOut: number,
9-
filterSize: number
9+
filterSize: number,
10+
mappedPrefix: string
1011
): ConvParams {
1112
const filters = tf.tensor4d(
1213
extractWeights(channelsIn * channelsOut * filterSize * filterSize),
1314
[filterSize, filterSize, channelsIn, channelsOut]
1415
)
1516
const bias = tf.tensor1d(extractWeights(channelsOut))
1617

18+
paramMappings.push(
19+
{ paramPath: `${mappedPrefix}/filters` },
20+
{ paramPath: `${mappedPrefix}/bias` }
21+
)
22+
1723
return {
1824
filters,
1925
bias

src/commons/extractWeightEntry.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import { isTensor } from './isTensor';
2+
3+
export function extractWeightEntry(weightMap: any, path: string, paramRank: number) {
4+
const tensor = weightMap[path]
5+
6+
if (!isTensor(tensor, paramRank)) {
7+
throw new Error(`expected weightMap[${path}] to be a Tensor${paramRank}D, instead have ${tensor}`)
8+
}
9+
10+
return { path, tensor }
11+
}

src/commons/types.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,8 @@ export type BatchReshapeInfo = {
1313
paddingX: number
1414
paddingY: number
1515
}
16+
17+
export type ParamMapping = {
18+
originalPath?: string
19+
paramPath: string
20+
}

src/faceLandmarkNet/FaceLandmarkNet.ts

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

33
import { convLayer } from '../commons/convLayer';
4+
import { NeuralNetwork } from '../commons/NeuralNetwork';
45
import { ConvParams } from '../commons/types';
56
import { NetInput } from '../NetInput';
67
import { Point } from '../Point';
@@ -21,9 +22,7 @@ function maxPool(x: tf.Tensor4D, strides: [number, number] = [2, 2]): tf.Tensor4
2122
return tf.maxPool(x, [2, 2], strides, 'valid')
2223
}
2324

24-
export class FaceLandmarkNet {
25-
26-
private _params: NetParams
25+
export class FaceLandmarkNet extends NeuralNetwork<NetParams> {
2726

2827
public async load(weightsOrUrl: Float32Array | string | undefined): Promise<void> {
2928
if (weightsOrUrl instanceof Float32Array) {
@@ -34,11 +33,23 @@ export class FaceLandmarkNet {
3433
if (weightsOrUrl && typeof weightsOrUrl !== 'string') {
3534
throw new Error('FaceLandmarkNet.load - expected model uri, or weights as Float32Array')
3635
}
37-
this._params = await loadQuantizedParams(weightsOrUrl)
36+
const {
37+
paramMappings,
38+
params
39+
} = await loadQuantizedParams(weightsOrUrl)
40+
41+
this._paramMappings = paramMappings
42+
this._params = params
3843
}
3944

4045
public extractWeights(weights: Float32Array) {
41-
this._params = extractParams(weights)
46+
const {
47+
paramMappings,
48+
params
49+
} = extractParams(weights)
50+
51+
this._paramMappings = paramMappings
52+
this._params = params
4253
}
4354

4455
public forwardInput(input: NetInput): tf.Tensor2D {

src/faceLandmarkNet/extractParams.ts

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,50 +2,62 @@ import * as tf from '@tensorflow/tfjs-core';
22

33
import { extractConvParamsFactory } from '../commons/extractConvParamsFactory';
44
import { extractWeightsFactory } from '../commons/extractWeightsFactory';
5+
import { ParamMapping } from '../commons/types';
56
import { FCParams, NetParams } from './types';
67

7-
export function extractParams(weights: Float32Array): NetParams {
8+
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
9+
const paramMappings: ParamMapping[] = []
10+
811
const {
912
extractWeights,
1013
getRemainingWeights
1114
} = extractWeightsFactory(weights)
1215

13-
const extractConvParams = extractConvParamsFactory(extractWeights)
16+
const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings)
1417

15-
function extractFcParams(channelsIn: number, channelsOut: number,): FCParams {
18+
function extractFcParams(channelsIn: number, channelsOut: number, mappedPrefix: string): FCParams {
1619
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut])
1720
const fc_bias = tf.tensor1d(extractWeights(channelsOut))
21+
22+
paramMappings.push(
23+
{ paramPath: `${mappedPrefix}/weights` },
24+
{ paramPath: `${mappedPrefix}/bias` }
25+
)
26+
1827
return {
1928
weights: fc_weights,
2029
bias: fc_bias
2130
}
2231
}
2332

24-
const conv0_params = extractConvParams(3, 32, 3)
25-
const conv1_params = extractConvParams(32, 64, 3)
26-
const conv2_params = extractConvParams(64, 64, 3)
27-
const conv3_params = extractConvParams(64, 64, 3)
28-
const conv4_params = extractConvParams(64, 64, 3)
29-
const conv5_params = extractConvParams(64, 128, 3)
30-
const conv6_params = extractConvParams(128, 128, 3)
31-
const conv7_params = extractConvParams(128, 256, 3)
32-
const fc0_params = extractFcParams(6400, 1024)
33-
const fc1_params = extractFcParams(1024, 136)
33+
const conv0_params = extractConvParams(3, 32, 3, 'conv0_params')
34+
const conv1_params = extractConvParams(32, 64, 3, 'conv1_params')
35+
const conv2_params = extractConvParams(64, 64, 3, 'conv2_params')
36+
const conv3_params = extractConvParams(64, 64, 3, 'conv3_params')
37+
const conv4_params = extractConvParams(64, 64, 3, 'conv4_params')
38+
const conv5_params = extractConvParams(64, 128, 3, 'conv5_params')
39+
const conv6_params = extractConvParams(128, 128, 3, 'conv6_params')
40+
const conv7_params = extractConvParams(128, 256, 3, 'conv7_params')
41+
const fc0_params = extractFcParams(6400, 1024, 'fc0_params')
42+
const fc1_params = extractFcParams(1024, 136, 'fc1_params')
3443

3544
if (getRemainingWeights().length !== 0) {
3645
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
3746
}
3847

3948
return {
40-
conv0_params,
41-
conv1_params,
42-
conv2_params,
43-
conv3_params,
44-
conv4_params,
45-
conv5_params,
46-
conv6_params,
47-
conv7_params,
48-
fc0_params,
49-
fc1_params
49+
paramMappings,
50+
params: {
51+
conv0_params,
52+
conv1_params,
53+
conv2_params,
54+
conv3_params,
55+
conv4_params,
56+
conv5_params,
57+
conv6_params,
58+
conv7_params,
59+
fc0_params,
60+
fc1_params
61+
}
5062
}
5163
}
Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,38 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3+
import { extractWeightEntry } from '../commons/extractWeightEntry';
34
import { loadWeightMap } from '../commons/loadWeightMap';
4-
import { ConvParams } from '../commons/types';
5+
import { ConvParams, ParamMapping } from '../commons/types';
56
import { FCParams, NetParams } from './types';
6-
import { isTensor4D, isTensor1D, isTensor2D } from '../commons/isTensor';
77

88
const DEFAULT_MODEL_NAME = 'face_landmark_68_model'
99

10-
function extractorsFactory(weightMap: any) {
10+
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
1111

12-
function extractConvParams(prefix: string): ConvParams {
13-
const params = {
14-
filters: weightMap[`${prefix}/kernel`] as tf.Tensor4D,
15-
bias: weightMap[`${prefix}/bias`] as tf.Tensor1D
12+
function extractConvParams(prefix: string, mappedPrefix: string): ConvParams {
13+
const filtersEntry = extractWeightEntry(weightMap, `${prefix}/kernel`, 4)
14+
const biasEntry = extractWeightEntry(weightMap, `${prefix}/bias`, 1)
15+
paramMappings.push(
16+
{ originalPath: filtersEntry.path, paramPath: `${mappedPrefix}/filters` },
17+
{ originalPath: biasEntry.path, paramPath: `${mappedPrefix}/bias` }
18+
)
19+
return {
20+
filters: filtersEntry.tensor as tf.Tensor4D,
21+
bias: biasEntry.tensor as tf.Tensor1D
1622
}
17-
18-
if (!isTensor4D(params.filters)) {
19-
throw new Error(`expected weightMap[${prefix}/kernel] to be a Tensor4D, instead have ${params.filters}`)
20-
}
21-
22-
if (!isTensor1D(params.bias)) {
23-
throw new Error(`expected weightMap[${prefix}/bias] to be a Tensor1D, instead have ${params.bias}`)
24-
}
25-
26-
return params
2723
}
2824

29-
function extractFcParams(prefix: string): FCParams {
30-
const params = {
31-
weights: weightMap[`${prefix}/kernel`] as tf.Tensor2D,
32-
bias: weightMap[`${prefix}/bias`] as tf.Tensor1D
33-
}
34-
35-
if (!isTensor2D(params.weights)) {
36-
throw new Error(`expected weightMap[${prefix}/kernel] to be a Tensor2D, instead have ${params.weights}`)
25+
function extractFcParams(prefix: string, mappedPrefix: string): FCParams {
26+
const weightsEntry = extractWeightEntry(weightMap, `${prefix}/kernel`, 2)
27+
const biasEntry = extractWeightEntry(weightMap, `${prefix}/bias`, 1)
28+
paramMappings.push(
29+
{ originalPath: weightsEntry.path, paramPath: `${mappedPrefix}/weights` },
30+
{ originalPath: biasEntry.path, paramPath: `${mappedPrefix}/bias` }
31+
)
32+
return {
33+
weights: weightsEntry.tensor as tf.Tensor2D,
34+
bias: biasEntry.tensor as tf.Tensor1D
3735
}
38-
39-
if (!isTensor1D(params.bias)) {
40-
throw new Error(`expected weightMap[${prefix}/bias] to be a Tensor1D, instead have ${params.bias}`)
41-
}
42-
43-
return params
4436
}
4537

4638
return {
@@ -49,24 +41,30 @@ function extractorsFactory(weightMap: any) {
4941
}
5042
}
5143

52-
export async function loadQuantizedParams(uri: string | undefined): Promise<NetParams> {
44+
export async function loadQuantizedParams(
45+
uri: string | undefined
46+
): Promise<{ params: NetParams, paramMappings: ParamMapping[] }> {
47+
5348
const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
49+
const paramMappings: ParamMapping[] = []
5450

5551
const {
5652
extractConvParams,
5753
extractFcParams
58-
} = extractorsFactory(weightMap)
54+
} = extractorsFactory(weightMap, paramMappings)
5955

60-
return {
61-
conv0_params: extractConvParams('conv2d_0'),
62-
conv1_params: extractConvParams('conv2d_1'),
63-
conv2_params: extractConvParams('conv2d_2'),
64-
conv3_params: extractConvParams('conv2d_3'),
65-
conv4_params: extractConvParams('conv2d_4'),
66-
conv5_params: extractConvParams('conv2d_5'),
67-
conv6_params: extractConvParams('conv2d_6'),
68-
conv7_params: extractConvParams('conv2d_7'),
69-
fc0_params: extractFcParams('dense'),
70-
fc1_params: extractFcParams('logits')
56+
const params = {
57+
conv0_params: extractConvParams('conv2d_0', 'conv0_params'),
58+
conv1_params: extractConvParams('conv2d_1', 'conv1_params'),
59+
conv2_params: extractConvParams('conv2d_2', 'conv2_params'),
60+
conv3_params: extractConvParams('conv2d_3', 'conv3_params'),
61+
conv4_params: extractConvParams('conv2d_4', 'conv4_params'),
62+
conv5_params: extractConvParams('conv2d_5', 'conv5_params'),
63+
conv6_params: extractConvParams('conv2d_6', 'conv6_params'),
64+
conv7_params: extractConvParams('conv2d_7', 'conv7_params'),
65+
fc0_params: extractFcParams('dense', 'fc0_params'),
66+
fc1_params: extractFcParams('logits', 'fc1_params')
7167
}
68+
69+
return { params, paramMappings }
7270
}

0 commit comments

Comments
 (0)