Skip to content

Commit 9af582a

Browse files
use env wrapper for browser related code
1 parent 3ed6b10 commit 9af582a

16 files changed

+84
-77
lines changed

src/dom/drawLandmarks.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { getContext2dOrThrow, getDefaultDrawOptions, resolveInput } from 'tfjs-image-recognition-base';
1+
import { env, getContext2dOrThrow, getDefaultDrawOptions, resolveInput } from 'tfjs-image-recognition-base';
22

33
import { FaceLandmarks } from '../classes/FaceLandmarks';
44
import { FaceLandmarks68 } from '../classes/FaceLandmarks68';
@@ -11,7 +11,7 @@ export function drawLandmarks(
1111
options?: DrawLandmarksOptions
1212
) {
1313
const canvas = resolveInput(canvasArg)
14-
if (!(canvas instanceof HTMLCanvasElement)) {
14+
if (!(canvas instanceof env.getEnv().Canvas)) {
1515
throw new Error('drawLandmarks - expected canvas to be of type: HTMLCanvasElement')
1616
}
1717

src/dom/extractFaces.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import {
88
} from 'tfjs-image-recognition-base';
99

1010
import { FaceDetection } from '../classes/FaceDetection';
11+
import { env } from 'tfjs-image-recognition-base';
1112

1213
/**
1314
* Extracts the image regions containing the detected faces.
@@ -21,17 +22,19 @@ export async function extractFaces(
2122
detections: Array<FaceDetection | Rect>
2223
): Promise<HTMLCanvasElement[]> {
2324

25+
const { Canvas } = env.getEnv()
26+
2427
let canvas = input as HTMLCanvasElement
2528

26-
if (!(input instanceof HTMLCanvasElement)) {
29+
if (!(input instanceof Canvas)) {
2730
const netInput = await toNetInput(input)
2831

2932
if (netInput.batchSize > 1) {
3033
throw new Error('extractFaces - batchSize > 1 not supported')
3134
}
3235

3336
const tensorOrCanvas = netInput.getInput(0)
34-
canvas = tensorOrCanvas instanceof HTMLCanvasElement
37+
canvas = tensorOrCanvas instanceof Canvas
3538
? tensorOrCanvas
3639
: await imageTensorToCanvas(tensorOrCanvas)
3740
}

src/faceLandmarkNet/FaceLandmark68Net.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ import { ConvParams, SeparableConvParams } from 'tfjs-tiny-yolov2';
44

55
import { depthwiseSeparableConv } from './depthwiseSeparableConv';
66
import { extractParams } from './extractParams';
7+
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
78
import { FaceLandmark68NetBase } from './FaceLandmark68NetBase';
89
import { fullyConnectedLayer } from './fullyConnectedLayer';
9-
import { loadQuantizedParams } from './loadQuantizedParams';
1010
import { DenseBlock4Params, NetParams } from './types';
1111

1212
function denseBlock(
@@ -64,10 +64,13 @@ export class FaceLandmark68Net extends FaceLandmark68NetBase<NetParams> {
6464
})
6565
}
6666

67-
protected loadQuantizedParams(uri: string | undefined) {
68-
return loadQuantizedParams(uri)
67+
protected getDefaultModelName(): string {
68+
return 'face_landmark_68_model'
6969
}
7070

71+
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
72+
return extractParamsFromWeigthMap(weightMap)
73+
}
7174

7275
protected extractParams(weights: Float32Array) {
7376
return extractParams(weights)

src/faceLandmarkNet/FaceLandmark68NetBase.ts

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { IDimensions, isEven, NetInput, NeuralNetwork, Point, TNetInput, toNetIn
33

44
import { FaceLandmarks68 } from '../classes/FaceLandmarks68';
55

6-
export class FaceLandmark68NetBase<NetParams> extends NeuralNetwork<NetParams> {
6+
export abstract class FaceLandmark68NetBase<NetParams> extends NeuralNetwork<NetParams> {
77

88
// TODO: make super.name protected
99
private __name: string
@@ -13,9 +13,7 @@ export class FaceLandmark68NetBase<NetParams> extends NeuralNetwork<NetParams> {
1313
this.__name = _name
1414
}
1515

16-
public runNet(_: NetInput): tf.Tensor2D {
17-
throw new Error(`${this.__name} - runNet not implemented`)
18-
}
16+
public abstract runNet(netInput: NetInput): tf.Tensor2D
1917

2018
public postProcess(output: tf.Tensor2D, inputSize: number, originalDimensions: IDimensions[]): tf.Tensor2D {
2119

src/faceLandmarkNet/FaceLandmark68TinyNet.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import { depthwiseSeparableConv } from './depthwiseSeparableConv';
66
import { extractParamsTiny } from './extractParamsTiny';
77
import { FaceLandmark68NetBase } from './FaceLandmark68NetBase';
88
import { fullyConnectedLayer } from './fullyConnectedLayer';
9-
import { loadQuantizedParamsTiny } from './loadQuantizedParamsTiny';
9+
import { extractParamsFromWeigthMapTiny } from './extractParamsFromWeigthMapTiny';
1010
import { DenseBlock3Params, TinyNetParams } from './types';
1111

1212
function denseBlock(
@@ -60,8 +60,12 @@ export class FaceLandmark68TinyNet extends FaceLandmark68NetBase<TinyNetParams>
6060
})
6161
}
6262

63-
protected loadQuantizedParams(uri: string | undefined) {
64-
return loadQuantizedParamsTiny(uri)
63+
protected getDefaultModelName(): string {
64+
return 'face_landmark_68_tiny_model'
65+
}
66+
67+
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
68+
return extractParamsFromWeigthMapTiny(weightMap)
6569
}
6670

6771
protected extractParams(weights: Float32Array) {

src/faceLandmarkNet/loadQuantizedParams.ts renamed to src/faceLandmarkNet/extractParamsFromWeigthMap.ts

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
import { disposeUnusedWeightTensors, loadWeightMap, ParamMapping } from 'tfjs-image-recognition-base';
1+
import * as tf from '@tensorflow/tfjs-core';
2+
import { disposeUnusedWeightTensors, ParamMapping } from 'tfjs-image-recognition-base';
23

34
import { loadParamsFactory } from './loadParamsFactory';
45
import { NetParams } from './types';
56

6-
const DEFAULT_MODEL_NAME = 'face_landmark_68_model'
7+
export function extractParamsFromWeigthMap(
8+
weightMap: tf.NamedTensorMap
9+
): { params: NetParams, paramMappings: ParamMapping[] } {
710

8-
export async function loadQuantizedParams(
9-
uri: string | undefined
10-
): Promise<{ params: NetParams, paramMappings: ParamMapping[] }> {
11-
12-
const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
1311
const paramMappings: ParamMapping[] = []
1412

1513
const {

src/faceLandmarkNet/loadQuantizedParamsTiny.ts renamed to src/faceLandmarkNet/extractParamsFromWeigthMapTiny.ts

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
import { disposeUnusedWeightTensors, loadWeightMap, ParamMapping } from 'tfjs-image-recognition-base';
1+
import * as tf from '@tensorflow/tfjs-core';
2+
import { disposeUnusedWeightTensors, ParamMapping } from 'tfjs-image-recognition-base';
23

34
import { loadParamsFactory } from './loadParamsFactory';
45
import { TinyNetParams } from './types';
56

6-
const DEFAULT_MODEL_NAME = 'face_landmark_68_tiny_model'
7+
export function extractParamsFromWeigthMapTiny(
8+
weightMap: tf.NamedTensorMap
9+
): { params: TinyNetParams, paramMappings: ParamMapping[] } {
710

8-
export async function loadQuantizedParamsTiny(
9-
uri: string | undefined
10-
): Promise<{ params: TinyNetParams, paramMappings: ParamMapping[] }> {
11-
12-
const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
1311
const paramMappings: ParamMapping[] = []
1412

1513
const {

src/faceRecognitionNet/FaceRecognitionNet.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { NetInput, NeuralNetwork, normalize, TNetInput, toNetInput } from 'tfjs-
33

44
import { convDown } from './convLayer';
55
import { extractParams } from './extractParams';
6-
import { loadQuantizedParams } from './loadQuantizedParams';
6+
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
77
import { residual, residualDown } from './residualLayer';
88
import { NetParams } from './types';
99

@@ -78,8 +78,12 @@ export class FaceRecognitionNet extends NeuralNetwork<NetParams> {
7878
: faceDescriptorsForBatch[0]
7979
}
8080

81-
protected loadQuantizedParams(uri: string | undefined) {
82-
return loadQuantizedParams(uri)
81+
protected getDefaultModelName(): string {
82+
return 'face_recognition_model'
83+
}
84+
85+
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
86+
return extractParamsFromWeigthMap(weightMap)
8387
}
8488

8589
protected extractParams(weights: Float32Array) {

src/faceRecognitionNet/loadQuantizedParams.ts renamed to src/faceRecognitionNet/extractParamsFromWeigthMap.ts

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ import {
99

1010
import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';
1111

12-
const DEFAULT_MODEL_NAME = 'face_recognition_model'
13-
1412
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
1513

1614
const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings)
@@ -46,11 +44,10 @@ function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
4644

4745
}
4846

49-
export async function loadQuantizedParams(
50-
uri: string | undefined
51-
): Promise<{ params: NetParams, paramMappings: ParamMapping[] }> {
47+
export function extractParamsFromWeigthMap(
48+
weightMap: tf.NamedTensorMap
49+
): { params: NetParams, paramMappings: ParamMapping[] } {
5250

53-
const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
5451
const paramMappings: ParamMapping[] = []
5552

5653
const {

src/mtcnn/Mtcnn.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ import { FaceLandmarks5 } from '../classes/FaceLandmarks5';
77
import { bgrToRgbTensor } from './bgrToRgbTensor';
88
import { CELL_SIZE } from './config';
99
import { extractParams } from './extractParams';
10+
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
1011
import { getSizesForScale } from './getSizesForScale';
11-
import { loadQuantizedParams } from './loadQuantizedParams';
1212
import { IMtcnnOptions, MtcnnOptions } from './MtcnnOptions';
1313
import { pyramidDown } from './pyramidDown';
1414
import { stage1 } from './stage1';
@@ -146,9 +146,12 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
146146
)
147147
}
148148

149-
// none of the param tensors are quantized yet
150-
protected loadQuantizedParams(uri: string | undefined) {
151-
return loadQuantizedParams(uri)
149+
protected getDefaultModelName(): string {
150+
return 'mtcnn_model'
151+
}
152+
153+
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
154+
return extractParamsFromWeigthMap(weightMap)
152155
}
153156

154157
protected extractParams(weights: Float32Array) {

0 commit comments

Comments
 (0)