Skip to content

Commit b3bcb3e

Browse files
init server + train and test data split
1 parent 0c62e7b commit b3bcb3e

File tree

5 files changed

+100
-0
lines changed

5 files changed

+100
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
export function getImageUrl({ db, img }) {
2+
if (db === 'kaggle') {
3+
return `kaggle-face-expressions-db/${label}/${id}.png`
4+
}
5+
6+
const dirNr = Math.floor(id / NUM_IMAGES_PER_DIR)
7+
return `cropped-faces/jpgs${dirNr + 1}/${id}.jpg`
8+
}

tools/train/faceExpressions/public/testData.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

tools/train/faceExpressions/public/trainData.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

tools/train/faceExpressions/server.js

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
require('./.env')
2+
3+
const express = require('express')
4+
const path = require('path')
5+
const fs = require('fs')
6+
7+
const app = express()
8+
9+
const publicDir = path.join(__dirname, './public')
10+
app.use(express.static(publicDir))
11+
app.use(express.static(path.join(__dirname, '../shared')))
12+
app.use(express.static(path.join(__dirname, '../node_modules/file-saver')))
13+
app.use(express.static(path.join(__dirname, '../../../examples/public')))
14+
app.use(express.static(path.join(__dirname, '../../../weights')))
15+
app.use(express.static(path.join(__dirname, '../../../dist')))
16+
app.use(express.static(path.resolve(process.env.DATA_PATH)))
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
const dataDistribution = {
2+
angry: {
3+
db: 1147,
4+
kaggle: 4953
5+
},
6+
disgusted: {
7+
db: 690,
8+
kaggle: 547
9+
},
10+
fearful: {
11+
db: 844,
12+
kaggle: 5121
13+
},
14+
happy: {
15+
db: 8634,
16+
kaggle: 8989
17+
},
18+
neutral: {
19+
db: 1262,
20+
kaggle: 6198
21+
},
22+
sad: {
23+
db: 929,
24+
kaggle: 6077
25+
},
26+
surprised: {
27+
db: 1264,
28+
kaggle: 4002
29+
}
30+
}
31+
32+
const MAX_TRAIN_SAMPLES_PER_CLASS = 2000
33+
34+
require('./.env')
35+
const { shuffleArray } = require('../../../')
36+
const fs = require('fs')
37+
38+
const createImageNameArray = (db, num, ext) =>
39+
Array(num).fill(0)
40+
.map((_, i) => `${i}${ext}`)
41+
.map(img => ({ db, img }))
42+
43+
const splitArray = (arr, idx) => [arr.slice(0, idx), arr.slice(idx)]
44+
45+
const trainData = {}
46+
const testData = {}
47+
48+
Object.keys(dataDistribution)
49+
.forEach(label => {
50+
const { db, kaggle } = dataDistribution[label]
51+
52+
// take max 0.7 percent of db, take rest from kaggle db
53+
const numDb = Math.floor(Math.min(0.7 * MAX_TRAIN_SAMPLES_PER_CLASS, 0.7 * db))
54+
const numKaggle = Math.floor(Math.min(MAX_TRAIN_SAMPLES_PER_CLASS - numDb, 0.7 * kaggle))
55+
56+
const dbImages = shuffleArray(createImageNameArray('db', db, '.jpg'))
57+
const kaggleImages = shuffleArray(createImageNameArray('kaggle', kaggle, '.png'))
58+
59+
const [dbTrain, dbTest] = splitArray(dbImages, numDb)
60+
const [kaggleTrain, kaggleTest] = splitArray(kaggleImages, numKaggle)
61+
62+
console.log()
63+
console.log('%s:', label)
64+
console.log('train data - db: %s, kaggle: %s', dbTrain.length, kaggleTrain.length)
65+
console.log('test data - db: %s, kaggle: %s', dbTest.length, kaggleTest.length)
66+
67+
trainData[label] = dbTrain.concat(kaggleTrain)
68+
testData[label] = dbTest.concat(kaggleTest)
69+
})
70+
71+
fs.writeFileSync('trainData.json', JSON.stringify(trainData))
72+
fs.writeFileSync('testData.json', JSON.stringify(testData))
73+
74+

0 commit comments

Comments
 (0)