
Commit 896371a

square images before converting them to tensors to prevent a GPU memory leak
1 parent 5e50b9d commit 896371a

3 files changed: +49, -15 lines


tools/train/tinyYolov2/train.html

Lines changed: 8 additions & 9 deletions
@@ -24,11 +24,13 @@

 const weightsUrl = `/tmp/tmp__224_${startIdx224}__320_${startIdx320}__416_${startIdx416}__608_${startIdx608}.weights`

-//const weightsUrl = '/tmp/tmp_2_count_41000.weights'
 const fromEpoch = 0

 const trainOnlyMultibox = false

+const trainSizes = [160, 224, 320, 416]
+//const trainSizes = [608]
+
 window.debug = false
 window.logTrainSteps = true

@@ -38,8 +40,8 @@
 window.noObjectScale = 1
 window.coordScale = 1

-const rescaleEveryNthBatch = Infinity
-window.saveEveryNthDataIdx = 100
+const rescaleEveryNthBatch = 100
+window.saveEveryNthDataIdx = trainSizes.length * rescaleEveryNthBatch
 window.trainSteps = 4000
 //window.optimizer = tf.train.sgd(0.001)
 window.optimizer = tf.train.adam(0.001, 0.9, 0.999, 1e-8)
@@ -78,26 +80,23 @@
   console.log('ready')
 }

-//const trainSizes = [224, 320, 416]
-const trainSizes = [608]
-
 function logLossChange(lossType) {
   const { currentLoss, prevLoss, detectionFilenames } = window
   log(`${lossType} : ${faceapi.round(currentLoss[lossType])} (avg: ${faceapi.round(currentLoss[lossType] / detectionFilenames.length)}) (delta: ${currentLoss[lossType] - prevLoss[lossType]})`)
 }

 window.count = 0

-function _onBatchProcessed(dataIdx, inputSize) {
+function onBatchProcessed(dataIdx, inputSize) {
   window.count++
   const idx = (dataIdx + 1) + (window.epoch * window.detectionFilenames.length)
   console.log('dataIdx', dataIdx)
   if ((window.count % saveEveryNthDataIdx) === 0) {
-    saveWeights(window.net, `tmp_2_count_${window.count}.weights`)
+    saveWeights(window.net, `tmp_multiscale_count_${window.count}.weights`)
   }
 }

-function onBatchProcessed(dataIdx, inputSize) {
+function _onBatchProcessed(dataIdx, inputSize) {
   const idx = (dataIdx + 1) + (window.epoch * window.detectionFilenames.length)
   console.log('idx', idx)
   if ((idx % saveEveryNthDataIdx) === 0) {
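These constants switch the page to multi-scale training: the name rescaleEveryNthBatch = 100 suggests the input size is rotated every 100 batches, and saveEveryNthDataIdx = trainSizes.length * rescaleEveryNthBatch (400) writes a checkpoint once per full pass over all four sizes. A hypothetical sketch of the rotation these constants imply; the actual size selection happens inside trainStep in train.js, and inputSizeForBatch is not part of the commit:

// Hypothetical helper (not in the commit): map a running batch counter to an input
// size, advancing to the next entry of trainSizes every rescaleEveryNthBatch batches.
function inputSizeForBatch(batchNr) {
  const cycle = Math.floor(batchNr / rescaleEveryNthBatch)
  return trainSizes[cycle % trainSizes.length]
}

// inputSizeForBatch(0)   === 160
// inputSizeForBatch(100) === 224
// inputSizeForBatch(399) === 416, after which the 400th batch triggers a checkpoint save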

tools/train/tinyYolov2/train.js

Lines changed: 39 additions & 5 deletions
@@ -1,6 +1,6 @@
 const batchIdx = 0

-function minimize(groundTruthBoxes, batchInput, inputSize, batch) {
+function minimize(groundTruthBoxes, batchInput, inputSize, batch, { reshapedImgDims, paddings }) {
   const filename = batch.filenames[batchIdx]
   const { dataIdx } = batch

@@ -16,8 +16,8 @@ function minimize(groundTruthBoxes, batchInput, inputSize, batch) {
   } = computeLoss(
     outTensor,
     groundTruthBoxes,
-    batchInput.getReshapedInputDimensions(batchIdx),
-    batchInput.getRelativePaddings(batchIdx)
+    reshapedImgDims,
+    paddings
   )

   const losses = {
@@ -47,6 +47,35 @@ function minimize(groundTruthBoxes, batchInput, inputSize, batch) {
   }, true)
 }

+function imageToSquare(img) {
+  const scale = 608 / Math.max(img.height, img.width)
+  const width = scale * img.width
+  const height = scale * img.height
+
+  const canvas1 = faceapi.createCanvasFromMedia(img)
+  const targetCanvas = faceapi.createCanvas({ width: 608, height: 608 })
+  targetCanvas.getContext('2d').putImageData(canvas1.getContext('2d').getImageData(0, 0, width, height), 0, 0)
+  return targetCanvas
+}
+
+function getPaddingsAndReshapedSize(img, inputSize) {
+  const [h, w] = [img.height, img.width]
+  const maxDim = Math.max(h, w)
+
+  const f = inputSize / maxDim
+  const reshapedImgDims = {
+    height: Math.floor(h * f),
+    width: Math.floor(w * f)
+  }
+
+  const paddings = new faceapi.Point(
+    maxDim / img.width,
+    maxDim / img.height
+  )
+
+  return { paddings, reshapedImgDims }
+}
+
 async function trainStep(batchCreators, inputSizes, rescaleEveryNthBatch, onBatchProcessed = () => {}) {

   async function step(currentBatchCreators) {
@@ -61,7 +90,11 @@ async function trainStep(batchCreators, inputSizes, rescaleEveryNthBatch, onBatc
     const batch = await batchCreator()
     const { imgs, groundTruthBoxes, filenames, dataIdx } = batch

-    const batchInput = await faceapi.toNetInput(imgs)
+    const img = imgs[0]
+    const { reshapedImgDims, paddings } = getPaddingsAndReshapedSize(img, inputSize)
+    const squareImg = imageToSquare(img)
+
+    const batchInput = await faceapi.toNetInput(squareImg)

     const [imgHeight, imgWidth] = batchInput.inputs[batchIdx].shape

@@ -90,7 +123,8 @@ async function trainStep(batchCreators, inputSizes, rescaleEveryNthBatch, onBatc
     }

     let ts = Date.now()
-    const loss = minimize(filteredGroundTruthBoxes, batchInput, inputSize, batch)
+    const loss = minimize(filteredGroundTruthBoxes, batchInput, inputSize, batch, { reshapedImgDims, paddings })
+
     ts = Date.now() - ts
     if (window.logTrainSteps) {
       log(`trainStep time for dataIdx ${dataIdx} (${inputSize}): ${ts} ms`)
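This is the core of the fix: each image is drawn onto a fixed 608 x 608 canvas via imageToSquare before faceapi.toNetInput, so every training tensor has the same shape, while getPaddingsAndReshapedSize still derives the reshaped dimensions and paddings from the original image for computeLoss. A minimal sketch of the per-image flow, using only helpers and calls that appear in this diff; previewBatchInput is a hypothetical wrapper and someImg is assumed to be an already loaded HTMLImageElement:

// Minimal sketch, assuming `someImg` is a loaded HTMLImageElement and
// `inputSize` is one of the configured train sizes.
async function previewBatchInput(someImg, inputSize) {
  const { reshapedImgDims, paddings } = getPaddingsAndReshapedSize(someImg, inputSize)
  const squareImg = imageToSquare(someImg)                // always a 608 x 608 canvas
  const batchInput = await faceapi.toNetInput(squareImg)
  console.log(batchInput.inputs[batchIdx].shape)          // same shape for every source image
  return { batchInput, reshapedImgDims, paddings }
}

Keeping the input shape constant is what the commit message credits with preventing the GPU memory leak, presumably because the backend no longer has to allocate fresh buffers for every distinct image resolution.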

tools/train/tinyYolov2/verify.html

Lines changed: 2 additions & 1 deletion
@@ -144,7 +144,8 @@
 const startIdx416 = 31050
 const startIdx608 = 16520

-const weightsUrl = `/tmp/tmp__224_${startIdx224}__320_${startIdx320}__416_${startIdx416}__608_${startIdx608}.weights`
+//const weightsUrl = `/tmp/tmp__224_${startIdx224}__320_${startIdx320}__416_${startIdx416}__608_${startIdx608}.weights`
+const weightsUrl = `/tmp/overfit_count_1500.weights`

 const weights = await loadNetWeights(weightsUrl)
 window.net = new faceapi.TinyYolov2(true)
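verify.html now points at a locally trained checkpoint instead of the merged multi-size weights file. A sketch of how the loaded weights would typically be applied to the net; loadNetWeights is defined elsewhere in this file, and the net.load(Float32Array) call is assumed from the usual face-api.js pattern rather than shown in this diff:

// Sketch only: applying the checkpoint to the freshly constructed net is assumed
// to follow the usual face-api.js load(Float32Array) pattern.
const weights = await loadNetWeights(weightsUrl)
window.net = new faceapi.TinyYolov2(true)
await window.net.load(weights)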
