Skip to content

Commit 883f20c

Browse files
init face expression model training script
1 parent 43bf688 commit 883f20c

File tree

14 files changed

+303
-23
lines changed

14 files changed

+303
-23
lines changed

src/faceExpressionNet/FaceExpressionNet.ts

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,57 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
import { NetInput, TNetInput, toNetInput } from 'tfjs-image-recognition-base';
3+
14
import { FaceFeatureExtractor } from '../faceFeatureExtractor/FaceFeatureExtractor';
25
import { FaceFeatureExtractorParams } from '../faceFeatureExtractor/types';
36
import { FaceProcessor } from '../faceProcessor/FaceProcessor';
7+
import { EmotionLabels } from './types';
48

59
export class FaceExpressionNet extends FaceProcessor<FaceFeatureExtractorParams> {
610

7-
constructor(faceFeatureExtractor: FaceFeatureExtractor) {
11+
public static getEmotionLabel(emotion: string) {
12+
const label = EmotionLabels[emotion.toUpperCase()]
13+
14+
if (typeof label !== 'number') {
15+
throw new Error(`getEmotionLabel - no label for emotion: ${emotion}`)
16+
}
17+
18+
return label
19+
}
20+
21+
public static decodeEmotions(probabilities: number[] | Float32Array) {
22+
if (probabilities.length !== 7) {
23+
throw new Error(`decodeEmotions - expected probabilities.length to be 7, have: ${probabilities.length}`)
24+
}
25+
return Object.keys(EmotionLabels).map(label => ({ label, probability: probabilities[EmotionLabels[label]] }))
26+
}
27+
28+
constructor(faceFeatureExtractor: FaceFeatureExtractor = new FaceFeatureExtractor()) {
829
super('FaceExpressionNet', faceFeatureExtractor)
930
}
1031

32+
public runNet(input: NetInput | tf.Tensor4D): tf.Tensor2D {
33+
return tf.tidy(() => {
34+
const out = super.runNet(input)
35+
return tf.softmax(out)
36+
})
37+
}
38+
39+
public forwardInput(input: NetInput | tf.Tensor4D): tf.Tensor2D {
40+
return tf.tidy(() => this.runNet(input))
41+
}
42+
43+
public async forward(input: TNetInput): Promise<tf.Tensor2D> {
44+
return this.forwardInput(await toNetInput(input))
45+
}
46+
47+
public async predictExpressions(input: TNetInput) {
48+
const out = await this.forward(input)
49+
const probabilitesByBatch = await Promise.all(tf.unstack(out).map(t => t.data()))
50+
out.dispose()
51+
52+
return probabilitesByBatch.map(propablities => FaceExpressionNet.decodeEmotions(propablities as Float32Array))
53+
}
54+
1155
public dispose(throwOnRedispose: boolean = true) {
1256
this.faceFeatureExtractor.dispose(throwOnRedispose)
1357
super.dispose(throwOnRedispose)

src/faceExpressionNet/index.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
export * from './FaceExpressionNet';
1+
export * from './FaceExpressionNet';
2+
export * from './types';

src/faceExpressionNet/types.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
export enum EmotionLabels {
2+
NEUTRAL = 0,
3+
HAPPY = 1,
4+
SAD = 2,
5+
ANGRY = 3,
6+
FEARFUL = 4,
7+
DISGUSTED = 5,
8+
SURPRISED = 6
9+
}

src/faceFeatureExtractor/FaceFeatureExtractor.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import * as tf from '@tensorflow/tfjs-core';
2-
import { NetInput, NeuralNetwork, normalize } from 'tfjs-image-recognition-base';
2+
import { NetInput, NeuralNetwork, normalize, TNetInput, toNetInput } from 'tfjs-image-recognition-base';
33
import { ConvParams, SeparableConvParams } from 'tfjs-tiny-yolov2';
44

55
import { depthwiseSeparableConv } from './depthwiseSeparableConv';
66
import { extractParams } from './extractParams';
77
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
8-
import { DenseBlock4Params, IFaceFeatureExtractor, FaceFeatureExtractorParams } from './types';
8+
import { DenseBlock4Params, FaceFeatureExtractorParams, IFaceFeatureExtractor } from './types';
99

1010
function denseBlock(
1111
x: tf.Tensor4D,
@@ -39,7 +39,7 @@ export class FaceFeatureExtractor extends NeuralNetwork<FaceFeatureExtractorPara
3939
super('FaceFeatureExtractor')
4040
}
4141

42-
public forward(input: NetInput): tf.Tensor4D {
42+
public forwardInput(input: NetInput): tf.Tensor4D {
4343

4444
const { params } = this
4545

@@ -62,6 +62,10 @@ export class FaceFeatureExtractor extends NeuralNetwork<FaceFeatureExtractorPara
6262
})
6363
}
6464

65+
public async forward(input: TNetInput): Promise<tf.Tensor4D> {
66+
return this.forwardInput(await toNetInput(input))
67+
}
68+
6569
protected getDefaultModelName(): string {
6670
return 'face_feature_extractor_model'
6771
}

src/faceFeatureExtractor/TinyFaceFeatureExtractor.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import * as tf from '@tensorflow/tfjs-core';
2-
import { NetInput, NeuralNetwork, normalize } from 'tfjs-image-recognition-base';
2+
import { NetInput, NeuralNetwork, normalize, TNetInput, toNetInput } from 'tfjs-image-recognition-base';
33
import { ConvParams, SeparableConvParams } from 'tfjs-tiny-yolov2';
44

55
import { depthwiseSeparableConv } from './depthwiseSeparableConv';
@@ -36,7 +36,7 @@ export class TinyFaceFeatureExtractor extends NeuralNetwork<TinyFaceFeatureExtra
3636
super('TinyFaceFeatureExtractor')
3737
}
3838

39-
public forward(input: NetInput): tf.Tensor4D {
39+
public forwardInput(input: NetInput): tf.Tensor4D {
4040

4141
const { params } = this
4242

@@ -58,6 +58,10 @@ export class TinyFaceFeatureExtractor extends NeuralNetwork<TinyFaceFeatureExtra
5858
})
5959
}
6060

61+
public async forward(input: TNetInput): Promise<tf.Tensor4D> {
62+
return this.forwardInput(await toNetInput(input))
63+
}
64+
6165
protected getDefaultModelName(): string {
6266
return 'face_landmark_68_tiny_model'
6367
}

src/faceFeatureExtractor/types.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import * as tf from '@tensorflow/tfjs-core';
2-
import { NetInput, NeuralNetwork } from 'tfjs-image-recognition-base';
2+
import { NetInput, NeuralNetwork, TNetInput } from 'tfjs-image-recognition-base';
33
import { ConvParams, SeparableConvParams } from 'tfjs-tiny-yolov2';
44

55
export type ConvWithBatchNormParams = BatchNormParams & {
@@ -42,5 +42,6 @@ export type FaceFeatureExtractorParams = {
4242
}
4343

4444
export interface IFaceFeatureExtractor<TNetParams extends TinyFaceFeatureExtractorParams | FaceFeatureExtractorParams> extends NeuralNetwork<TNetParams> {
45-
forward(input: NetInput): tf.Tensor4D
45+
forwardInput(input: NetInput): tf.Tensor4D
46+
forward(input: TNetInput): Promise<tf.Tensor4D>
4647
}

src/faceProcessor/FaceProcessor.ts

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ export abstract class FaceProcessor<
4242

4343
return tf.tidy(() => {
4444
const bottleneckFeatures = input instanceof NetInput
45-
? this.faceFeatureExtractor.forward(input)
45+
? this.faceFeatureExtractor.forwardInput(input)
4646
: input
4747
return fullyConnectedLayer(bottleneckFeatures.as2D(bottleneckFeatures.shape[0], -1), params.fc)
4848
})
@@ -53,6 +53,16 @@ export abstract class FaceProcessor<
5353
super.dispose(throwOnRedispose)
5454
}
5555

56+
public loadClassifierParams(weights: Float32Array) {
57+
const { params, paramMappings } = this.extractClassifierParams(weights)
58+
this._params = params
59+
this._paramMappings = paramMappings
60+
}
61+
62+
public extractClassifierParams(weights: Float32Array) {
63+
return extractParams(weights, this.getClassifierChannelsIn(), this.getClassifierChannelsOut())
64+
}
65+
5666
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
5767

5868
const { featureExtractorMap, classifierMap } = seperateWeightMaps(weightMap)
@@ -72,6 +82,6 @@ export abstract class FaceProcessor<
7282
const classifierWeights = weights.slice(weights.length - classifierWeightSize)
7383

7484
this.faceFeatureExtractor.extractWeights(featureExtractorWeights)
75-
return extractParams(classifierWeights, cIn, cOut)
85+
return this.extractClassifierParams(classifierWeights)
7686
}
7787
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
<!DOCTYPE html>
2+
<html>
3+
<head>
4+
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.12.0"> </script>
5+
<script src="../../node_modules/file-saver/FileSaver.js"></script>
6+
</head>
7+
<body>
8+
9+
<div class="row side-by-side">
10+
<button
11+
class="waves-effect waves-light btn"
12+
onclick="save()"
13+
>
14+
Save
15+
</button>
16+
</div>
17+
18+
<script>
19+
function toDataArray(tensor) {
20+
return Array.from(tensor.dataSync())
21+
}
22+
23+
function flatten(arrs) {
24+
return arrs.reduce((flat, arr) => flat.concat(arr))
25+
}
26+
27+
function initWeights(initializer) {
28+
const numOutputFilters = 256
29+
const outputSize = 7
30+
31+
const weights = toDataArray(initializer.apply([1, 1, numOutputFilters, 7]))
32+
const bias = toDataArray(tf.zeros([outputSize]))
33+
34+
return new Float32Array(weights.concat(bias))
35+
}
36+
37+
function save() {
38+
const initialWeights = initWeights(
39+
tf.initializers.glorotNormal()
40+
)
41+
saveAs(new Blob([initialWeights]), `initial_glorot.weights`)
42+
}
43+
44+
</script>
45+
</body>
46+
</html>
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
function getImageUrl({ db, label, img }) {
2+
if (db === 'kaggle') {
3+
return `kaggle-face-expressions-db/${label}/${img}`
4+
}
5+
6+
const id = parseInt(img.replace('.png'))
7+
const dirNr = Math.floor(id / 5000)
8+
return `cropped-faces/jpgs${dirNr + 1}/${img}`
9+
}
10+

tools/train/faceExpressions/public/js/fetchData.js

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)