Skip to content

Commit aa7dd87

Browse files
init train data fetching
1 parent ca47cbb commit aa7dd87

File tree

6 files changed

+82
-17
lines changed

6 files changed

+82
-17
lines changed

tools/train/faceLandmarks/faceLandmarksTrain.js

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,3 @@
1-
async function promiseSequential(promises) {
2-
const curr = promises[0]
3-
if (!curr) {
4-
return
5-
}
6-
7-
await curr()
8-
return promiseSequential(promises.slice(1))
9-
}
10-
111
async function trainStep(batchCreators) {
122
await promiseSequential(batchCreators.map((batchCreator, dataIdx) => async () => {
133

tools/train/serveFaceLandmarks.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ const app = express()
77

88
const publicDir = path.join(__dirname, './faceLandmarks')
99
app.use(express.static(publicDir))
10+
app.use(express.static(path.join(__dirname, './shared')))
1011
app.use(express.static(path.join(__dirname, './node_modules/file-saver')))
1112
app.use(express.static(path.join(__dirname, '../../examples/public')))
1213
app.use(express.static(path.join(__dirname, '../../weights')))

tools/train/serveTinyYolov2.js

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,26 @@ require('./tinyYolov2/.env')
22

33
const express = require('express')
44
const path = require('path')
5+
const fs = require('fs')
56

67
const app = express()
78

89
const publicDir = path.join(__dirname, './tinyYolov2')
910
app.use(express.static(publicDir))
11+
app.use(express.static(path.join(__dirname, './shared')))
1012
app.use(express.static(path.join(__dirname, './node_modules/file-saver')))
1113
app.use(express.static(path.join(__dirname, '../../examples/public')))
1214
app.use(express.static(path.join(__dirname, '../../weights')))
1315
app.use(express.static(path.join(__dirname, '../../dist')))
1416

1517
const trainDataPath = path.resolve(process.env.TRAIN_DATA_PATH)
18+
const imagesPath = path.join(trainDataPath, './final_images')
19+
const detectionsPath = path.join(trainDataPath, './final_detections')
20+
const detectionFilenames = fs.readdirSync(detectionsPath)
21+
1622
app.use(express.static(trainDataPath))
1723

18-
//app.get('/', (req, res) => res.sendFile(path.join(publicDir, 'train.html')))
19-
//app.get('/', (req, res) => res.sendFile(path.join(publicDir, 'tinyYolov2FaceDetectionVideo.html')))
20-
app.get('/', (req, res) => res.sendFile(path.join(publicDir, 'testLoss.html')))
24+
app.get('/detection_filenames', (req, res) => res.status(202).send(detectionFilenames))
25+
app.get('/', (req, res) => res.sendFile(path.join(publicDir, 'train.html')))
2126

2227
app.listen(3000, () => console.log('Listening on port 3000!'))

tools/train/faceLandmarks/trainUtils.js renamed to tools/train/shared/trainUtils.js

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1+
async function promiseSequential(promises) {
2+
const curr = promises[0]
3+
if (!curr) {
4+
return
5+
}
16

7+
await curr()
8+
return promiseSequential(promises.slice(1))
9+
}
210

311
// https://stackoverflow.com/questions/6274339/how-can-i-shuffle-an-array
412
function shuffle(a) {

tools/train/tinyYolov2/train.html

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@
2929
return new Float32Array(await (await fetch(uri)).arrayBuffer())
3030
}
3131

32-
async function getTrainData() {
33-
// TBD
32+
async function fetchDetectionFilenames() {
33+
return fetch('/detection_filenames').then(res => res.json())
3434
}
3535

3636
async function run() {
3737
const weights = await loadNetWeights(weightsUrl)
3838
window.net = new faceapi.TinyYolov2(true)
3939
window.net.load(weights)
40-
window.trainData = await getTrainData()
4140
window.net.variable()
41+
window.detectionFilenames = await fetchDetectionFilenames()
4242
}
4343

4444
/*
@@ -68,7 +68,7 @@
6868
async function train(batchSize = 1) {
6969
for (let i = 0; i < trainSteps; i++) {
7070
console.log('step', i)
71-
const batchCreators = createBatchCreators(shuffle(window.trainData), batchSize)
71+
const batchCreators = createBatchCreators(shuffle(detectionFilenames), batchSize)
7272
let ts = Date.now()
7373
await trainStep(batchCreators)
7474
ts = Date.now() - ts

tools/train/tinyYolov2/train.js

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
async function trainStep(batchCreators) {
2+
await promiseSequential(batchCreators.map((batchCreator, dataIdx) => async () => {
3+
4+
const { batchInput, groundTruthBoxes } = await batchCreator()
5+
/*
6+
let ts = Date.now()
7+
const cost = optimizer.minimize(() => {
8+
const out = window.trainNet.forwardInput(batchInput.managed())
9+
const loss = lossFunction(
10+
landmarksBatchTensor,
11+
out
12+
)
13+
return loss
14+
}, true)
15+
16+
ts = Date.now() - ts
17+
console.log(`loss[${dataIdx}]: ${await cost.data()}, ${ts} ms (${ts / batchInput.batchSize} ms / batch element)`)
18+
19+
landmarksBatchTensor.dispose()
20+
cost.dispose()
21+
22+
await tf.nextFrame()
23+
}))
24+
*/
25+
}
26+
27+
function createBatchCreators(batchSize) {
28+
if (batchSize < 1) {
29+
throw new Error('invalid batch size: ' + batchSize)
30+
}
31+
32+
const batches = []
33+
const pushToBatch = (remaining) => {
34+
if (remaining.length) {
35+
batches.push(remaining.slice(0, batchSize))
36+
pushToBatch(remaining.
37+
slice(batchSize))
38+
}
39+
return batches
40+
}
41+
42+
pushToBatch(window.detectionFilenames)
43+
44+
const batchCreators = batches.map(detectionFilenames => async () => {
45+
const imgs = detectionFilenames.map(
46+
detectionFilenames.map(file => fetch(file).then(res => res.json()))
47+
)
48+
const groundTruthBoxes = await Promise.all(
49+
detectionFilenames.map(async file => await faceapi.bufferToImage(await fetchImage(file.replace('.json', ''))))
50+
)
51+
52+
const batchInput = await faceapi.toNetInput(imgs)
53+
54+
return {
55+
batchInput,
56+
groundTruthBoxes
57+
}
58+
})
59+
60+
return batchCreators
61+
}

0 commit comments

Comments
 (0)