Skip to content

Commit da8b102

Browse files
refactored some stuff + train and test script for face expressions net with trained feature extractor
1 parent ebaf7f2 commit da8b102

File tree

16 files changed

+452
-165
lines changed

16 files changed

+452
-165
lines changed

src/faceExpressionNet/FaceExpressionNet.ts

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,41 +4,35 @@ import { NetInput, TNetInput, toNetInput } from 'tfjs-image-recognition-base';
44
import { FaceFeatureExtractor } from '../faceFeatureExtractor/FaceFeatureExtractor';
55
import { FaceFeatureExtractorParams } from '../faceFeatureExtractor/types';
66
import { FaceProcessor } from '../faceProcessor/FaceProcessor';
7-
import { emotionLabels } from './types';
7+
import { faceExpressionLabels } from './types';
88

99
export class FaceExpressionNet extends FaceProcessor<FaceFeatureExtractorParams> {
1010

11-
public static getEmotionLabel(emotion: string) {
12-
const label = emotionLabels[emotion]
11+
public static getFaceExpressionLabel(faceExpression: string) {
12+
const label = faceExpressionLabels[faceExpression]
1313

1414
if (typeof label !== 'number') {
15-
throw new Error(`getEmotionLabel - no label for emotion: ${emotion}`)
15+
throw new Error(`getFaceExpressionLabel - no label for faceExpression: ${faceExpression}`)
1616
}
1717

1818
return label
1919
}
2020

21-
public static decodeEmotions(probabilities: number[] | Float32Array) {
21+
public static decodeProbabilites(probabilities: number[] | Float32Array) {
2222
if (probabilities.length !== 7) {
23-
throw new Error(`decodeEmotions - expected probabilities.length to be 7, have: ${probabilities.length}`)
23+
throw new Error(`decodeProbabilites - expected probabilities.length to be 7, have: ${probabilities.length}`)
2424
}
2525

26-
return Object.keys(emotionLabels).map(label => ({ label, probability: probabilities[emotionLabels[label]] }))
26+
return Object.keys(faceExpressionLabels)
27+
.map(expression => ({ expression, probability: probabilities[faceExpressionLabels[expression]] }))
2728
}
2829

2930
constructor(faceFeatureExtractor: FaceFeatureExtractor = new FaceFeatureExtractor()) {
3031
super('FaceExpressionNet', faceFeatureExtractor)
3132
}
3233

33-
public runNet(input: NetInput | tf.Tensor4D): tf.Tensor2D {
34-
return tf.tidy(() => {
35-
const out = super.runNet(input)
36-
return tf.softmax(out)
37-
})
38-
}
39-
4034
public forwardInput(input: NetInput | tf.Tensor4D): tf.Tensor2D {
41-
return tf.tidy(() => this.runNet(input))
35+
return tf.tidy(() => tf.softmax(this.runNet(input)))
4236
}
4337

4438
public async forward(input: TNetInput): Promise<tf.Tensor2D> {
@@ -52,14 +46,7 @@ export class FaceExpressionNet extends FaceProcessor<FaceFeatureExtractorParams>
5246
out.dispose()
5347

5448
const predictionsByBatch = probabilitesByBatch
55-
.map(propablities => {
56-
const predictions = {}
57-
FaceExpressionNet.decodeEmotions(propablities as Float32Array)
58-
.forEach(({ label, probability }) => {
59-
predictions[label] = probability
60-
})
61-
return predictions
62-
})
49+
.map(propablities => FaceExpressionNet.decodeProbabilites(propablities as Float32Array))
6350

6451
return netInput.isBatchInput
6552
? predictionsByBatch

src/faceExpressionNet/types.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export const emotionLabels = {
1+
export const faceExpressionLabels = {
22
neutral: 0,
33
happy: 1,
44
sad: 2,

src/faceProcessor/FaceProcessor.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ export abstract class FaceProcessor<
7676

7777
const cIn = this.getClassifierChannelsIn()
7878
const cOut = this.getClassifierChannelsOut()
79-
const classifierWeightSize = (cOut * cIn )+ cOut
79+
const classifierWeightSize = (cOut * cIn ) + cOut
8080

8181
const featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize)
8282
const classifierWeights = weights.slice(weights.length - classifierWeightSize)

tools/train/faceExpressions/public/js/commons.js

Lines changed: 0 additions & 10 deletions
This file was deleted.
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Resolves the relative URL of a training/test image.
// Kaggle samples live under kaggle-face-expressions-db/<label>/<img>;
// all other samples are stored in cropped-faces/jpgsN folders of 5000
// images each, addressed by the numeric image id parsed from the filename.
function getImageUrl({ db, label, img }) {
  if (db === 'kaggle') {
    return `kaggle-face-expressions-db/${label}/${img}`
  }

  // fix: replace('.png') with a single argument substituted the string
  // "undefined" instead of stripping the extension; parseInt only worked
  // by accident because it ignores trailing non-numeric characters
  const id = parseInt(img.replace('.png', ''))
  const dirNr = Math.floor(id / 5000)
  return `cropped-faces/jpgs${dirNr + 1}/${img}`
}
10+
11+
// Builds a shuffled, label-annotated list of all training samples for one
// epoch. Samples of the under-represented "disgusted" class are tripled and
// capped at 2000 entries so every class contributes an even share.
function prepareDataForEpoch(data) {
  const samplesPerLabel = Object.keys(data).map(label => {
    const annotated = data[label].map(sample => ({ ...sample, label }))
    if (label !== 'disgusted') {
      return annotated
    }
    // since train data for "disgusted" have less than 2000 samples
    // use some data twice to ensure an even distribution
    return faceapi.shuffleArray(annotated.concat(annotated, annotated)).slice(0, 2000)
  })
  return faceapi.shuffleArray(samplesPerLabel.reduce((flat, arr) => arr.concat(flat)))
}
24+
25+
// Converts a face expression name into a one-hot vector over the 7 classes,
// using the label index assigned by FaceExpressionNet.
function getLabelOneHotVector(faceExpression) {
  const labelIdx = faceapi.FaceExpressionNet.getFaceExpressionLabel(faceExpression)
  const oneHot = Array(7).fill(0)
  oneHot[labelIdx] = 1
  return oneHot
}
29+
30+
// Saves a checkpoint of the model weights and the accumulated loss for the
// given epoch. `params` may carry a pre-serialized Float32Array of weights;
// when omitted, the global net instance is saved instead.
async function onEpochDone(epoch, params) {
  saveWeights(params || window.net, `face_expression_model_${epoch}.weights`)

  const loss = window.lossValues[epoch]
  // NOTE(review): avgLoss assumes 2000 samples per each of the 7 classes per
  // epoch — confirm this matches prepareDataForEpoch's per-class sampling
  saveAs(new Blob([JSON.stringify({ loss, avgLoss: loss / (2000 * 7) })]), `face_expression_model_${epoch}.json`)

}
37+
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Evaluates window.net on the test samples of the given db: per expression it
// accumulates the summed error (1 - p_true), the top-1 hit count, and counts
// of samples whose true-class probability exceeds 0.3 / 0.5 / 0.8, then logs
// the per-class relative stats and the total error to the console.
async function test(db) {
  const faceExpressions = Object.keys(window.testData)
  let errors = {}
  let preds = {}
  let thresh03 = {}
  let thresh05 = {}
  let thresh08 = {}
  let sizes = {}

  for (let faceExpression of faceExpressions) {

    const container = document.getElementById('container')
    const span = document.createElement('div')
    container.appendChild(span)

    console.log(faceExpression)

    const dataForLabel = window.testData[faceExpression]
      .filter(data => data.db === db)
      .slice(0, window.numDataPerClass)

    errors[faceExpression] = 0
    preds[faceExpression] = 0
    thresh03[faceExpression] = 0
    thresh05[faceExpression] = 0
    thresh08[faceExpression] = 0
    sizes[faceExpression] = dataForLabel.length


    for (let [idx, data] of dataForLabel.entries()) {
      // fix: compute the percentage inside round() — the original rounded the
      // fraction first and then multiplied by 100, producing noisy progress
      // values (e.g. 12.000000000000002)
      span.innerHTML = faceExpression + ': ' + faceapi.round(idx / dataForLabel.length * 100) + '%'

      const img = await faceapi.fetchImage(getImageUrl({ ...data, label: faceExpression }))
      const pred = await window.net.predictExpressions(img)
      // prediction with the highest probability
      const bestPred = pred
        .reduce((best, curr) => curr.probability < best.probability ? best : curr)

      const { probability } = pred.find(p => p.expression === faceExpression)
      thresh03[faceExpression] += (probability > 0.3 ? 1 : 0)
      thresh05[faceExpression] += (probability > 0.5 ? 1 : 0)
      thresh08[faceExpression] += (probability > 0.8 ? 1 : 0)
      errors[faceExpression] += 1 - probability
      preds[faceExpression] += (bestPred.expression === faceExpression ? 1 : 0)
    }

    span.innerHTML = faceExpression + ': 100%'

  }

  const totalError = faceExpressions.reduce((err, faceExpression) => err + errors[faceExpression], 0)

  // normalizes per-class counts by the test set size of that class
  // fix: declare res locally — it leaked as an implicit global before
  // (and would throw a ReferenceError in strict mode)
  const relative = (obj) => {
    const res = {}
    Object.keys(sizes).forEach((faceExpression) => {
      res[faceExpression] = faceapi.round(
        obj[faceExpression] / sizes[faceExpression]
      )
    })
    return res
  }

  console.log('done...')
  console.log('test set size:', sizes)
  console.log('preds:', relative(preds))
  console.log('thresh03:', relative(thresh03))
  console.log('thresh05:', relative(thresh05))
  console.log('thresh08:', relative(thresh08))
  console.log('errors:', errors)
  console.log('total error:', totalError)
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Trains the face expression classifier: per epoch, shuffles the training
// data, runs mini-batches through the net, minimizes softmax cross-entropy
// and accumulates the epoch loss in window.lossValues. A checkpoint of the
// previous epoch is saved asynchronously via onEpochDone.
async function train() {
  await load()

  for (let epoch = startEpoch; epoch < Infinity; epoch++) {

    if (epoch !== startEpoch) {
      // ugly hack to wait for loss datas for that epoch to be resolved
      setTimeout(
        () => onEpochDone(
          epoch - 1,
          new Float32Array(Array.from(window.net.faceFeatureExtractor.serializeParams())
            .concat(Array.from(window.net.serializeParams()))
          )
        ),
        10000
      )
    }
    window.lossValues[epoch] = 0

    const shuffledInputs = prepareDataForEpoch(window.trainData)
    console.log(shuffledInputs)

    for (let dataIdx = 0; dataIdx < shuffledInputs.length; dataIdx += window.batchSize) {
      const tsIter = Date.now()

      const batchData = shuffledInputs.slice(dataIdx, dataIdx + window.batchSize)
      const bImages = await Promise.all(
        batchData
          .map(data => getImageUrl(data))
          .map(imgUrl => faceapi.fetchImage(imgUrl))
      )
      const bOneHotVectors = batchData
        .map(data => getLabelOneHotVector(data.label))

      let tsBackward = Date.now()
      let tsForward = Date.now()
      // NOTE(review): netInput is never disposed here — verify toNetInput's
      // tensors are cleaned up elsewhere, or this leaks memory per batch
      const netInput = await faceapi.toNetInput(bImages)
      tsForward = Date.now() - tsForward

      const loss = optimizer.minimize(() => {
        // tsBackward is reset here so it measures only the backward pass
        tsBackward = Date.now()
        const labels = tf.tensor2d(bOneHotVectors)
        const out = window.net.runNet(netInput)

        const loss = tf.losses.softmaxCrossEntropy(
          labels,
          out,
          tf.Reduction.MEAN
        )

        return loss
      }, true)
      tsBackward = Date.now() - tsBackward

      // start next iteration without waiting for loss data

      loss.data().then(data => {
        const lossValue = data[0]
        window.lossValues[epoch] += lossValue
        window.withLogging && log(`epoch ${epoch}, dataIdx ${dataIdx} - loss: ${lossValue}, ( ${window.lossValues[epoch]})`)
        loss.dispose()
      })

      window.withLogging && log(`epoch ${epoch}, dataIdx ${dataIdx} - forward: ${tsForward} ms, backprop: ${tsBackward} ms, iter: ${Date.now() - tsIter} ms`)
      // fix: removed the window.logsilly fetch-timing log — it referenced
      // tsFetch / tsFetchPts / tsFetchJpgs / tsBufferToImage, none of which
      // are defined in this script, so enabling the flag threw a ReferenceError
      if (window.iterDelay) {
        await delay(window.iterDelay)
      } else {
        await tf.nextFrame()
      }
    }
  }
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
<!DOCTYPE html>
<html>
<head>
  <script src="face-api.js"></script>
  <script src="FileSaver.js"></script>
  <script src="commons.js"></script>
  <script src="js/faceExpressionsCommons.js"></script>
  <script src="js/test.js"></script>
</head>
<body>
  <div id="container"></div>

  <script>
    tf = faceapi.tf

    // evaluate all available test samples per class
    window.numDataPerClass = Infinity

    // NOTE(review): a previous comment here claimed the FaceLandmark68Net
    // feature extractor is loaded, but this page instantiates a
    // FaceExpressionNet and loads a full FaceExpressionNet checkpoint below
    window.net = new faceapi.FaceExpressionNet()

    // uri to weights file of last checkpoint
    const modelCheckpoint = 'tmp/full/face_expression_model_120.weights'

    // fetches the test data manifest and the checkpoint weights, then
    // initializes window.net
    async function load() {
      window.testData = await faceapi.fetchJson('testData.json')

      // fetch the actual output layer weights
      const weights = await faceapi.fetchNetWeights(modelCheckpoint)
      await window.net.load(weights)

      console.log('loaded')
    }

    load()

  </script>

</body>
</html>
41+

0 commit comments

Comments
 (0)