Skip to content

Commit 7cbd750

Browse files
init AgeGenderNet
1 parent 2d885bb commit 7cbd750

14 files changed

+395
-20
lines changed

src/ageGenderNet/AgeGenderNet.ts

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
import { NetInput, NeuralNetwork, TNetInput, toNetInput } from 'tfjs-image-recognition-base';
3+
4+
import { fullyConnectedLayer } from '../common/fullyConnectedLayer';
5+
import { seperateWeightMaps } from '../faceProcessor/util';
6+
import { TinyXception } from '../xception/TinyXception';
7+
import { extractParams } from './extractParams';
8+
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
9+
import { NetOutput, NetParams } from './types';
10+
11+
export class AgeGenderNet extends NeuralNetwork<NetParams> {
12+
13+
private _faceFeatureExtractor: TinyXception
14+
15+
constructor(faceFeatureExtractor: TinyXception = new TinyXception(2)) {
16+
super('AgeGenderNet')
17+
this._faceFeatureExtractor = faceFeatureExtractor
18+
}
19+
20+
public get faceFeatureExtractor(): TinyXception {
21+
return this._faceFeatureExtractor
22+
}
23+
24+
public runNet(input: NetInput | tf.Tensor4D): NetOutput {
25+
26+
const { params } = this
27+
28+
if (!params) {
29+
throw new Error(`${this._name} - load model before inference`)
30+
}
31+
32+
return tf.tidy(() => {
33+
const bottleneckFeatures = input instanceof NetInput
34+
? this.faceFeatureExtractor.forwardInput(input)
35+
: input
36+
37+
const bottleneckFeatures2d = bottleneckFeatures.as2D(bottleneckFeatures.shape[0], -1)
38+
const age = fullyConnectedLayer(bottleneckFeatures2d, params.fc.age).as1D()
39+
const gender = fullyConnectedLayer(bottleneckFeatures2d, params.fc.gender)
40+
return { age, gender }
41+
})
42+
}
43+
44+
public forwardInput(input: NetInput | tf.Tensor4D): NetOutput {
45+
const { age, gender } = this.runNet(input)
46+
return tf.tidy(() => ({ age, gender: tf.softmax(gender) }))
47+
}
48+
49+
public async forward(input: TNetInput): Promise<NetOutput> {
50+
return this.forwardInput(await toNetInput(input))
51+
}
52+
53+
public async predictAgeAndGender(input: TNetInput): Promise<{ age: number, gender: string, genderProbability: number }> {
54+
const netInput = await toNetInput(input)
55+
const out = await this.forwardInput(netInput)
56+
const age = await out.age.data()[0] as number
57+
const probMale = await out.gender.data()[0] as number
58+
59+
const isMale = probMale > 0.5
60+
const gender = isMale ? 'male' : 'female'
61+
const genderProbability = isMale ? probMale : (1 - probMale)
62+
63+
return { age, gender, genderProbability }
64+
}
65+
66+
protected getDefaultModelName(): string {
67+
return 'age_gender_model'
68+
}
69+
70+
public dispose(throwOnRedispose: boolean = true) {
71+
this.faceFeatureExtractor.dispose(throwOnRedispose)
72+
super.dispose(throwOnRedispose)
73+
}
74+
75+
public loadClassifierParams(weights: Float32Array) {
76+
const { params, paramMappings } = this.extractClassifierParams(weights)
77+
this._params = params
78+
this._paramMappings = paramMappings
79+
}
80+
81+
public extractClassifierParams(weights: Float32Array) {
82+
return extractParams(weights)
83+
}
84+
85+
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
86+
87+
const { featureExtractorMap, classifierMap } = seperateWeightMaps(weightMap)
88+
89+
this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap)
90+
91+
return extractParamsFromWeigthMap(classifierMap)
92+
}
93+
94+
protected extractParams(weights: Float32Array) {
95+
96+
const classifierWeightSize = (512 * 1) + (512 * 2)
97+
98+
const featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize)
99+
const classifierWeights = weights.slice(weights.length - classifierWeightSize)
100+
101+
this.faceFeatureExtractor.extractWeights(featureExtractorWeights)
102+
return this.extractClassifierParams(classifierWeights)
103+
}
104+
}

src/ageGenderNet/extractParams.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
2+
3+
import { NetParams } from './types';
4+
5+
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: TfjsImageRecognitionBase.ParamMapping[] } {
6+
7+
const paramMappings: TfjsImageRecognitionBase.ParamMapping[] = []
8+
9+
const {
10+
extractWeights,
11+
getRemainingWeights
12+
} = TfjsImageRecognitionBase.extractWeightsFactory(weights)
13+
14+
const extractFCParams = TfjsImageRecognitionBase.extractFCParamsFactory(extractWeights, paramMappings)
15+
16+
const age = extractFCParams(512, 1, 'fc_age')
17+
const gender = extractFCParams(512, 2, 'fc_gender')
18+
19+
if (getRemainingWeights().length !== 0) {
20+
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
21+
}
22+
23+
return {
24+
paramMappings,
25+
params: { fc: { age, gender } }
26+
}
27+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
3+
4+
import { NetParams } from './types';
5+
6+
export function extractParamsFromWeigthMap(
7+
weightMap: tf.NamedTensorMap
8+
): { params: NetParams, paramMappings: TfjsImageRecognitionBase.ParamMapping[] } {
9+
10+
const paramMappings: TfjsImageRecognitionBase.ParamMapping[] = []
11+
12+
const extractWeightEntry = TfjsImageRecognitionBase.extractWeightEntryFactory(weightMap, paramMappings)
13+
14+
function extractFcParams(prefix: string): TfjsImageRecognitionBase.FCParams {
15+
const weights = extractWeightEntry<tf.Tensor2D>(`${prefix}/weights`, 2)
16+
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
17+
return { weights, bias }
18+
}
19+
20+
const params = {
21+
fc: {
22+
age: extractFcParams('fc_age'),
23+
gender: extractFcParams('fc_gender')
24+
}
25+
}
26+
27+
TfjsImageRecognitionBase.disposeUnusedWeightTensors(weightMap, paramMappings)
28+
29+
return { params, paramMappings }
30+
}

src/ageGenderNet/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
export * from './AgeGenderNet';
2+
export * from './types';

src/ageGenderNet/types.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
3+
4+
export type NetOutput = { age: tf.Tensor1D, gender: tf.Tensor2D }
5+
6+
export type NetParams = {
7+
fc: {
8+
age: TfjsImageRecognitionBase.FCParams
9+
gender: TfjsImageRecognitionBase.FCParams
10+
}
11+
}

src/faceFeatureExtractor/denseBlock.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import * as tf from '@tensorflow/tfjs-core';
22
import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
33

4-
import { depthwiseSeparableConv } from './depthwiseSeparableConv';
4+
import { depthwiseSeparableConv } from '../common/depthwiseSeparableConv';
55
import { DenseBlock3Params, DenseBlock4Params } from './types';
66

77
export function denseBlock3(

src/faceFeatureExtractor/extractorsFactory.ts

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,11 @@
1-
import * as tf from '@tensorflow/tfjs-core';
21
import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
32

43
import { DenseBlock3Params, DenseBlock4Params } from './types';
54

65
export function extractorsFactory(extractWeights: TfjsImageRecognitionBase.ExtractWeightsFunction, paramMappings: TfjsImageRecognitionBase.ParamMapping[]) {
76

8-
function extractSeparableConvParams(channelsIn: number, channelsOut: number, mappedPrefix: string): TfjsImageRecognitionBase.SeparableConvParams {
9-
const depthwise_filter = tf.tensor4d(extractWeights(3 * 3 * channelsIn), [3, 3, channelsIn, 1])
10-
const pointwise_filter = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut])
11-
const bias = tf.tensor1d(extractWeights(channelsOut))
12-
13-
paramMappings.push(
14-
{ paramPath: `${mappedPrefix}/depthwise_filter` },
15-
{ paramPath: `${mappedPrefix}/pointwise_filter` },
16-
{ paramPath: `${mappedPrefix}/bias` }
17-
)
18-
19-
return new TfjsImageRecognitionBase.SeparableConvParams(
20-
depthwise_filter,
21-
pointwise_filter,
22-
bias
23-
)
24-
}
25-
267
const extractConvParams = TfjsImageRecognitionBase.extractConvParamsFactory(extractWeights, paramMappings)
8+
const extractSeparableConvParams = TfjsImageRecognitionBase.extractSeparableConvParamsFactory(extractWeights, paramMappings)
279

2810
function extractDenseBlock3Params(channelsIn: number, channelsOut: number, mappedPrefix: string, isFirstLayer: boolean = false): DenseBlock3Params {
2911

src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ export {
66

77
export * from 'tfjs-image-recognition-base';
88

9+
export * from './ageGenderNet/index';
910
export * from './classes/index';
1011
export * from './dom/index'
1112
export * from './faceExpressionNet/index';

src/xception/TinyXception.ts

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
import {
3+
NetInput,
4+
NeuralNetwork,
5+
normalize,
6+
range,
7+
TfjsImageRecognitionBase,
8+
TNetInput,
9+
toNetInput,
10+
} from 'tfjs-image-recognition-base';
11+
12+
import { depthwiseSeparableConv } from '../common/depthwiseSeparableConv';
13+
import { extractParams } from './extractParams';
14+
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
15+
import { MainBlockParams, ReductionBlockParams, TinyXceptionParams } from './types';
16+
17+
function conv(x: tf.Tensor4D, params: TfjsImageRecognitionBase.ConvParams, stride: [number, number]): tf.Tensor4D {
18+
return tf.add(tf.conv2d(x, params.filters, stride, 'same'), params.bias)
19+
}
20+
21+
function reductionBlock(x: tf.Tensor4D, params: ReductionBlockParams, isActivateInput: boolean = true): tf.Tensor4D {
22+
let out = isActivateInput ? tf.relu(x) : x
23+
out = depthwiseSeparableConv(out, params.separable_conv0, [1, 1])
24+
out = depthwiseSeparableConv(tf.relu(out), params.separable_conv1, [1, 1])
25+
out = tf.maxPool(out, [3, 3], [2, 2], 'same')
26+
out = tf.add(out, conv(x, params.expansion_conv, [2, 2]))
27+
return out
28+
}
29+
30+
function mainBlock(x: tf.Tensor4D, params: MainBlockParams): tf.Tensor4D {
31+
let out = depthwiseSeparableConv(tf.relu(x), params.separable_conv0, [1, 1])
32+
out = depthwiseSeparableConv(tf.relu(out), params.separable_conv1, [1, 1])
33+
out = depthwiseSeparableConv(tf.relu(out), params.separable_conv2, [1, 1])
34+
out = tf.add(out, x)
35+
return out
36+
}
37+
38+
export class TinyXception extends NeuralNetwork<TinyXceptionParams> {
39+
40+
private _numMainBlocks: number
41+
42+
constructor(numMainBlocks: number) {
43+
super('TinyXception')
44+
this._numMainBlocks = numMainBlocks
45+
}
46+
47+
public forwardInput(input: NetInput): tf.Tensor4D {
48+
49+
const { params } = this
50+
51+
if (!params) {
52+
throw new Error('TinyXception - load model before inference')
53+
}
54+
55+
return tf.tidy(() => {
56+
const batchTensor = input.toBatchTensor(112, true)
57+
const meanRgb = [122.782, 117.001, 104.298]
58+
const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D
59+
60+
let out = tf.relu(conv(normalized, params.entry_flow.conv_in, [2, 2]))
61+
out = reductionBlock(out, params.entry_flow.reduction_block_0, false)
62+
out = reductionBlock(out, params.entry_flow.reduction_block_1)
63+
64+
range(this._numMainBlocks, 0, 1).forEach((idx) => {
65+
out = mainBlock(out, params.middle_flow[`main_block_${idx}`])
66+
})
67+
68+
out = reductionBlock(out, params.exit_flow.reduction_block)
69+
out = tf.relu(depthwiseSeparableConv(out, params.exit_flow.separable_conv, [1, 1]))
70+
return out
71+
})
72+
}
73+
74+
public async forward(input: TNetInput): Promise<tf.Tensor4D> {
75+
return this.forwardInput(await toNetInput(input))
76+
}
77+
78+
protected getDefaultModelName(): string {
79+
return 'tiny_xception_model'
80+
}
81+
82+
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
83+
return extractParamsFromWeigthMap(weightMap)
84+
}
85+
86+
protected extractParams(weights: Float32Array) {
87+
return extractParams(weights, this._numMainBlocks)
88+
}
89+
}

0 commit comments

Comments
 (0)