Skip to content

Commit 9af582a

Browse files
use env wrapper for browser related code
1 parent 3ed6b10 commit 9af582a

16 files changed

+84
-77
lines changed

src/dom/drawLandmarks.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { getContext2dOrThrow, getDefaultDrawOptions, resolveInput } from 'tfjs-image-recognition-base';
1+
import { env, getContext2dOrThrow, getDefaultDrawOptions, resolveInput } from 'tfjs-image-recognition-base';
22

33
import { FaceLandmarks } from '../classes/FaceLandmarks';
44
import { FaceLandmarks68 } from '../classes/FaceLandmarks68';
@@ -11,7 +11,7 @@ export function drawLandmarks(
1111
options?: DrawLandmarksOptions
1212
) {
1313
const canvas = resolveInput(canvasArg)
14-
if (!(canvas instanceof HTMLCanvasElement)) {
14+
if (!(canvas instanceof env.getEnv().Canvas)) {
1515
throw new Error('drawLandmarks - expected canvas to be of type: HTMLCanvasElement')
1616
}
1717

src/dom/extractFaces.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import {
88
} from 'tfjs-image-recognition-base';
99

1010
import { FaceDetection } from '../classes/FaceDetection';
11+
import { env } from 'tfjs-image-recognition-base';
1112

1213
/**
1314
* Extracts the image regions containing the detected faces.
@@ -21,17 +22,19 @@ export async function extractFaces(
2122
detections: Array<FaceDetection | Rect>
2223
): Promise<HTMLCanvasElement[]> {
2324

25+
const { Canvas } = env.getEnv()
26+
2427
let canvas = input as HTMLCanvasElement
2528

26-
if (!(input instanceof HTMLCanvasElement)) {
29+
if (!(input instanceof Canvas)) {
2730
const netInput = await toNetInput(input)
2831

2932
if (netInput.batchSize > 1) {
3033
throw new Error('extractFaces - batchSize > 1 not supported')
3134
}
3235

3336
const tensorOrCanvas = netInput.getInput(0)
34-
canvas = tensorOrCanvas instanceof HTMLCanvasElement
37+
canvas = tensorOrCanvas instanceof Canvas
3538
? tensorOrCanvas
3639
: await imageTensorToCanvas(tensorOrCanvas)
3740
}

src/faceLandmarkNet/FaceLandmark68Net.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ import { ConvParams, SeparableConvParams } from 'tfjs-tiny-yolov2';
44

55
import { depthwiseSeparableConv } from './depthwiseSeparableConv';
66
import { extractParams } from './extractParams';
7+
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
78
import { FaceLandmark68NetBase } from './FaceLandmark68NetBase';
89
import { fullyConnectedLayer } from './fullyConnectedLayer';
9-
import { loadQuantizedParams } from './loadQuantizedParams';
1010
import { DenseBlock4Params, NetParams } from './types';
1111

1212
function denseBlock(
@@ -64,10 +64,13 @@ export class FaceLandmark68Net extends FaceLandmark68NetBase<NetParams> {
6464
})
6565
}
6666

67-
protected loadQuantizedParams(uri: string | undefined) {
68-
return loadQuantizedParams(uri)
67+
protected getDefaultModelName(): string {
68+
return 'face_landmark_68_model'
6969
}
7070

71+
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
72+
return extractParamsFromWeigthMap(weightMap)
73+
}
7174

7275
protected extractParams(weights: Float32Array) {
7376
return extractParams(weights)

src/faceLandmarkNet/FaceLandmark68NetBase.ts

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { IDimensions, isEven, NetInput, NeuralNetwork, Point, TNetInput, toNetIn
33

44
import { FaceLandmarks68 } from '../classes/FaceLandmarks68';
55

6-
export class FaceLandmark68NetBase<NetParams> extends NeuralNetwork<NetParams> {
6+
export abstract class FaceLandmark68NetBase<NetParams> extends NeuralNetwork<NetParams> {
77

88
// TODO: make super.name protected
99
private __name: string
@@ -13,9 +13,7 @@ export class FaceLandmark68NetBase<NetParams> extends NeuralNetwork<NetParams> {
1313
this.__name = _name
1414
}
1515

16-
public runNet(_: NetInput): tf.Tensor2D {
17-
throw new Error(`${this.__name} - runNet not implemented`)
18-
}
16+
public abstract runNet(netInput: NetInput): tf.Tensor2D
1917

2018
public postProcess(output: tf.Tensor2D, inputSize: number, originalDimensions: IDimensions[]): tf.Tensor2D {
2119

src/faceLandmarkNet/FaceLandmark68TinyNet.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import { depthwiseSeparableConv } from './depthwiseSeparableConv';
66
import { extractParamsTiny } from './extractParamsTiny';
77
import { FaceLandmark68NetBase } from './FaceLandmark68NetBase';
88
import { fullyConnectedLayer } from './fullyConnectedLayer';
9-
import { loadQuantizedParamsTiny } from './loadQuantizedParamsTiny';
9+
import { extractParamsFromWeigthMapTiny } from './extractParamsFromWeigthMapTiny';
1010
import { DenseBlock3Params, TinyNetParams } from './types';
1111

1212
function denseBlock(
@@ -60,8 +60,12 @@ export class FaceLandmark68TinyNet extends FaceLandmark68NetBase<TinyNetParams>
6060
})
6161
}
6262

63-
protected loadQuantizedParams(uri: string | undefined) {
64-
return loadQuantizedParamsTiny(uri)
63+
protected getDefaultModelName(): string {
64+
return 'face_landmark_68_tiny_model'
65+
}
66+
67+
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
68+
return extractParamsFromWeigthMapTiny(weightMap)
6569
}
6670

6771
protected extractParams(weights: Float32Array) {

src/faceLandmarkNet/loadQuantizedParams.ts renamed to src/faceLandmarkNet/extractParamsFromWeigthMap.ts

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
import { disposeUnusedWeightTensors, loadWeightMap, ParamMapping } from 'tfjs-image-recognition-base';
1+
import * as tf from '@tensorflow/tfjs-core';
2+
import { disposeUnusedWeightTensors, ParamMapping } from 'tfjs-image-recognition-base';
23

34
import { loadParamsFactory } from './loadParamsFactory';
45
import { NetParams } from './types';
56

6-
const DEFAULT_MODEL_NAME = 'face_landmark_68_model'
7+
export function extractParamsFromWeigthMap(
8+
weightMap: tf.NamedTensorMap
9+
): { params: NetParams, paramMappings: ParamMapping[] } {
710

8-
export async function loadQuantizedParams(
9-
uri: string | undefined
10-
): Promise<{ params: NetParams, paramMappings: ParamMapping[] }> {
11-
12-
const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
1311
const paramMappings: ParamMapping[] = []
1412

1513
const {

src/faceLandmarkNet/loadQuantizedParamsTiny.ts renamed to src/faceLandmarkNet/extractParamsFromWeigthMapTiny.ts

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
import { disposeUnusedWeightTensors, loadWeightMap, ParamMapping } from 'tfjs-image-recognition-base';
1+
import * as tf from '@tensorflow/tfjs-core';
2+
import { disposeUnusedWeightTensors, ParamMapping } from 'tfjs-image-recognition-base';
23

34
import { loadParamsFactory } from './loadParamsFactory';
45
import { TinyNetParams } from './types';
56

6-
const DEFAULT_MODEL_NAME = 'face_landmark_68_tiny_model'
7+
export function extractParamsFromWeigthMapTiny(
8+
weightMap: tf.NamedTensorMap
9+
): { params: TinyNetParams, paramMappings: ParamMapping[] } {
710

8-
export async function loadQuantizedParamsTiny(
9-
uri: string | undefined
10-
): Promise<{ params: TinyNetParams, paramMappings: ParamMapping[] }> {
11-
12-
const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
1311
const paramMappings: ParamMapping[] = []
1412

1513
const {

src/faceRecognitionNet/FaceRecognitionNet.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { NetInput, NeuralNetwork, normalize, TNetInput, toNetInput } from 'tfjs-
33

44
import { convDown } from './convLayer';
55
import { extractParams } from './extractParams';
6-
import { loadQuantizedParams } from './loadQuantizedParams';
6+
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
77
import { residual, residualDown } from './residualLayer';
88
import { NetParams } from './types';
99

@@ -78,8 +78,12 @@ export class FaceRecognitionNet extends NeuralNetwork<NetParams> {
7878
: faceDescriptorsForBatch[0]
7979
}
8080

81-
protected loadQuantizedParams(uri: string | undefined) {
82-
return loadQuantizedParams(uri)
81+
protected getDefaultModelName(): string {
82+
return 'face_recognition_model'
83+
}
84+
85+
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
86+
return extractParamsFromWeigthMap(weightMap)
8387
}
8488

8589
protected extractParams(weights: Float32Array) {

src/faceRecognitionNet/loadQuantizedParams.ts renamed to src/faceRecognitionNet/extractParamsFromWeigthMap.ts

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ import {
99

1010
import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';
1111

12-
const DEFAULT_MODEL_NAME = 'face_recognition_model'
13-
1412
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
1513

1614
const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings)
@@ -46,11 +44,10 @@ function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
4644

4745
}
4846

49-
export async function loadQuantizedParams(
50-
uri: string | undefined
51-
): Promise<{ params: NetParams, paramMappings: ParamMapping[] }> {
47+
export function extractParamsFromWeigthMap(
48+
weightMap: tf.NamedTensorMap
49+
): { params: NetParams, paramMappings: ParamMapping[] } {
5250

53-
const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
5451
const paramMappings: ParamMapping[] = []
5552

5653
const {

src/mtcnn/Mtcnn.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ import { FaceLandmarks5 } from '../classes/FaceLandmarks5';
77
import { bgrToRgbTensor } from './bgrToRgbTensor';
88
import { CELL_SIZE } from './config';
99
import { extractParams } from './extractParams';
10+
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
1011
import { getSizesForScale } from './getSizesForScale';
11-
import { loadQuantizedParams } from './loadQuantizedParams';
1212
import { IMtcnnOptions, MtcnnOptions } from './MtcnnOptions';
1313
import { pyramidDown } from './pyramidDown';
1414
import { stage1 } from './stage1';
@@ -146,9 +146,12 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
146146
)
147147
}
148148

149-
// none of the param tensors are quantized yet
150-
protected loadQuantizedParams(uri: string | undefined) {
151-
return loadQuantizedParams(uri)
149+
protected getDefaultModelName(): string {
150+
return 'mtcnn_model'
151+
}
152+
153+
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
154+
return extractParamsFromWeigthMap(weightMap)
152155
}
153156

154157
protected extractParams(weights: Float32Array) {

src/mtcnn/loadQuantizedParams.ts renamed to src/mtcnn/extractParamsFromWeigthMap.ts

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,9 @@
11
import * as tf from '@tensorflow/tfjs-core';
2-
import {
3-
disposeUnusedWeightTensors,
4-
extractWeightEntryFactory,
5-
loadWeightMap,
6-
ParamMapping,
7-
} from 'tfjs-image-recognition-base';
2+
import { disposeUnusedWeightTensors, extractWeightEntryFactory, ParamMapping } from 'tfjs-image-recognition-base';
83
import { ConvParams, FCParams } from 'tfjs-tiny-yolov2';
94

105
import { NetParams, ONetParams, PNetParams, RNetParams, SharedParams } from './types';
116

12-
const DEFAULT_MODEL_NAME = 'mtcnn_model'
13-
147
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
158

169
const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings)
@@ -87,11 +80,10 @@ function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
8780

8881
}
8982

90-
export async function loadQuantizedParams(
91-
uri: string | undefined
92-
): Promise<{ params: NetParams, paramMappings: ParamMapping[] }> {
83+
export function extractParamsFromWeigthMap(
84+
weightMap: tf.NamedTensorMap
85+
): { params: NetParams, paramMappings: ParamMapping[] } {
9386

94-
const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
9587
const paramMappings: ParamMapping[] = []
9688

9789
const {

src/ssdMobilenetv1/SsdMobilenetv1.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { NetInput, NeuralNetwork, Rect, TNetInput, toNetInput } from 'tfjs-image
33

44
import { FaceDetection } from '../classes/FaceDetection';
55
import { extractParams } from './extractParams';
6-
import { loadQuantizedParams } from './loadQuantizedParams';
6+
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
77
import { mobileNetV1 } from './mobileNetV1';
88
import { nonMaxSuppression } from './nonMaxSuppression';
99
import { outputLayer } from './outputLayer';
@@ -116,8 +116,12 @@ export class SsdMobilenetv1 extends NeuralNetwork<NetParams> {
116116
return results
117117
}
118118

119-
protected loadQuantizedParams(uri: string | undefined) {
120-
return loadQuantizedParams(uri)
119+
protected getDefaultModelName(): string {
120+
return 'ssd_mobilenetv1_model'
121+
}
122+
123+
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
124+
return extractParamsFromWeigthMap(weightMap)
121125
}
122126

123127
protected extractParams(weights: Float32Array) {

src/ssdMobilenetv1/loadQuantizedParams.ts renamed to src/ssdMobilenetv1/extractParamsFromWeigthMap.ts

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,12 @@ import {
33
disposeUnusedWeightTensors,
44
extractWeightEntryFactory,
55
isTensor3D,
6-
loadWeightMap,
76
ParamMapping,
87
} from 'tfjs-image-recognition-base';
98
import { ConvParams } from 'tfjs-tiny-yolov2';
109

1110
import { BoxPredictionParams, MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams } from './types';
1211

13-
const DEFAULT_MODEL_NAME = 'ssd_mobilenetv1_model'
14-
1512
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
1613

1714
const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings)
@@ -114,11 +111,10 @@ function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
114111
}
115112
}
116113

117-
export async function loadQuantizedParams(
118-
uri: string | undefined
119-
): Promise<{ params: NetParams, paramMappings: ParamMapping[] }> {
114+
export function extractParamsFromWeigthMap(
115+
weightMap: tf.NamedTensorMap
116+
): { params: NetParams, paramMappings: ParamMapping[] } {
120117

121-
const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
122118
const paramMappings: ParamMapping[] = []
123119

124120
const {

src/tinyFaceDetector/TinyFaceDetector.ts

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
import { Point, TNetInput } from 'tfjs-image-recognition-base';
2-
import { TinyYolov2 as TinyYolov2Base, ITinyYolov2Options } from 'tfjs-tiny-yolov2';
1+
import * as tf from '@tensorflow/tfjs-core';
2+
import { ParamMapping, Point, TNetInput } from 'tfjs-image-recognition-base';
3+
import { ITinyYolov2Options, TinyYolov2 as TinyYolov2Base } from 'tfjs-tiny-yolov2';
4+
import { TinyYolov2NetParams } from 'tfjs-tiny-yolov2/build/commonjs/tinyYolov2/types';
35

46
import { FaceDetection } from '../classes';
5-
import { BOX_ANCHORS, DEFAULT_MODEL_NAME, IOU_THRESHOLD, MEAN_RGB } from './const';
7+
import { BOX_ANCHORS, IOU_THRESHOLD, MEAN_RGB } from './const';
68

79
export class TinyFaceDetector extends TinyYolov2Base {
810

@@ -29,8 +31,11 @@ export class TinyFaceDetector extends TinyYolov2Base {
2931
return objectDetections.map(det => new FaceDetection(det.score, det.relativeBox, { width: det.imageWidth, height: det.imageHeight }))
3032
}
3133

32-
protected loadQuantizedParams(modelUri: string | undefined) {
33-
const defaultModelName = DEFAULT_MODEL_NAME
34-
return super.loadQuantizedParams(modelUri, defaultModelName) as any
34+
protected getDefaultModelName(): string {
35+
return 'tiny_face_detector_model'
36+
}
37+
38+
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap): { params: TinyYolov2NetParams, paramMappings: ParamMapping[] } {
39+
return super.extractParamsFromWeigthMap(weightMap)
3540
}
3641
}

src/tinyFaceDetector/const.ts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,4 @@ export const BOX_ANCHORS = [
1010
new Point(9.041765, 10.66308)
1111
]
1212

13-
export const MEAN_RGB: [number, number, number] = [117.001, 114.697, 97.404]
14-
15-
export const DEFAULT_MODEL_NAME = 'tiny_face_detector_model'
13+
export const MEAN_RGB: [number, number, number] = [117.001, 114.697, 97.404]

src/tinyYolov2/TinyYolov2.ts

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import { Point, TNetInput } from 'tfjs-image-recognition-base';
2-
import { ITinyYolov2Options, TinyYolov2 as TinyYolov2Base } from 'tfjs-tiny-yolov2';
1+
import * as tf from '@tensorflow/tfjs-core';
2+
import { ParamMapping, Point, TNetInput } from 'tfjs-image-recognition-base';
3+
import { ITinyYolov2Options, TinyYolov2 as TinyYolov2Base, TinyYolov2NetParams } from 'tfjs-tiny-yolov2';
34

45
import { FaceDetection } from '../classes';
56
import {
@@ -46,8 +47,11 @@ export class TinyYolov2 extends TinyYolov2Base {
4647
return objectDetections.map(det => new FaceDetection(det.score, det.relativeBox, { width: det.imageWidth, height: det.imageHeight }))
4748
}
4849

49-
protected loadQuantizedParams(modelUri: string | undefined) {
50-
const defaultModelName = this.withSeparableConvs ? DEFAULT_MODEL_NAME_SEPARABLE_CONV : DEFAULT_MODEL_NAME
51-
return super.loadQuantizedParams(modelUri, defaultModelName) as any
50+
protected getDefaultModelName(): string {
51+
return this.withSeparableConvs ? DEFAULT_MODEL_NAME_SEPARABLE_CONV : DEFAULT_MODEL_NAME
52+
}
53+
54+
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap): { params: TinyYolov2NetParams, paramMappings: ParamMapping[] } {
55+
return super.extractParamsFromWeigthMap(weightMap)
5256
}
5357
}

0 commit comments

Comments
 (0)