Skip to content

Commit 43bf688

Browse files
move shared logic for nets working on face images to FaceProcessor
1 parent b3bcb3e commit 43bf688

25 files changed

+187
-202
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import { FaceFeatureExtractor } from '../faceFeatureExtractor/FaceFeatureExtractor';
2+
import { FaceFeatureExtractorParams } from '../faceFeatureExtractor/types';
3+
import { FaceProcessor } from '../faceProcessor/FaceProcessor';
4+
5+
export class FaceExpressionNet extends FaceProcessor<FaceFeatureExtractorParams> {
6+
7+
constructor(faceFeatureExtractor: FaceFeatureExtractor) {
8+
super('FaceExpressionNet', faceFeatureExtractor)
9+
}
10+
11+
public dispose(throwOnRedispose: boolean = true) {
12+
this.faceFeatureExtractor.dispose(throwOnRedispose)
13+
super.dispose(throwOnRedispose)
14+
}
15+
16+
protected getDefaultModelName(): string {
17+
return 'face_expression_model'
18+
}
19+
20+
protected getClassifierChannelsIn(): number {
21+
return 256
22+
}
23+
24+
protected getClassifierChannelsOut(): number {
25+
return 7
26+
}
27+
}

src/faceExpressionNet/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
export * from './FaceExpressionNet';

src/faceFeatureExtractor/FaceFeatureExtractor.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import { ConvParams, SeparableConvParams } from 'tfjs-tiny-yolov2';
55
import { depthwiseSeparableConv } from './depthwiseSeparableConv';
66
import { extractParams } from './extractParams';
77
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
8-
import { DenseBlock4Params, NetParams } from './types';
8+
import { DenseBlock4Params, IFaceFeatureExtractor, FaceFeatureExtractorParams } from './types';
99

1010
function denseBlock(
1111
x: tf.Tensor4D,
@@ -33,7 +33,7 @@ function denseBlock(
3333
})
3434
}
3535

36-
export class FaceFeatureExtractor extends NeuralNetwork<NetParams> {
36+
export class FaceFeatureExtractor extends NeuralNetwork<FaceFeatureExtractorParams> implements IFaceFeatureExtractor<FaceFeatureExtractorParams> {
3737

3838
constructor() {
3939
super('FaceFeatureExtractor')

src/faceFeatureExtractor/TinyFaceFeatureExtractor.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import { ConvParams, SeparableConvParams } from 'tfjs-tiny-yolov2';
55
import { depthwiseSeparableConv } from './depthwiseSeparableConv';
66
import { extractParamsFromWeigthMapTiny } from './extractParamsFromWeigthMapTiny';
77
import { extractParamsTiny } from './extractParamsTiny';
8-
import { DenseBlock3Params, TinyNetParams } from './types';
8+
import { DenseBlock3Params, IFaceFeatureExtractor, TinyFaceFeatureExtractorParams } from './types';
99

1010
function denseBlock(
1111
x: tf.Tensor4D,
@@ -30,7 +30,7 @@ function denseBlock(
3030
})
3131
}
3232

33-
export class TinyFaceFeatureExtractor extends NeuralNetwork<TinyNetParams> {
33+
export class TinyFaceFeatureExtractor extends NeuralNetwork<TinyFaceFeatureExtractorParams> implements IFaceFeatureExtractor<TinyFaceFeatureExtractorParams> {
3434

3535
constructor() {
3636
super('TinyFaceFeatureExtractor')

src/faceFeatureExtractor/extractParams.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import { extractWeightsFactory, ParamMapping } from 'tfjs-image-recognition-base';
22

33
import { extractorsFactory } from './extractorsFactory';
4-
import { NetParams } from './types';
4+
import { FaceFeatureExtractorParams } from './types';
55

6-
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
6+
export function extractParams(weights: Float32Array): { params: FaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
77

88
const paramMappings: ParamMapping[] = []
99

src/faceFeatureExtractor/extractParamsFromWeigthMap.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ import * as tf from '@tensorflow/tfjs-core';
22
import { disposeUnusedWeightTensors, ParamMapping } from 'tfjs-image-recognition-base';
33

44
import { loadParamsFactory } from './loadParamsFactory';
5-
import { NetParams } from './types';
5+
import { FaceFeatureExtractorParams } from './types';
66

77
export function extractParamsFromWeigthMap(
88
weightMap: tf.NamedTensorMap
9-
): { params: NetParams, paramMappings: ParamMapping[] } {
9+
): { params: FaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
1010

1111
const paramMappings: ParamMapping[] = []
1212

src/faceFeatureExtractor/extractParamsFromWeigthMapTiny.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ import * as tf from '@tensorflow/tfjs-core';
22
import { disposeUnusedWeightTensors, ParamMapping } from 'tfjs-image-recognition-base';
33

44
import { loadParamsFactory } from './loadParamsFactory';
5-
import { TinyNetParams } from './types';
5+
import { TinyFaceFeatureExtractorParams } from './types';
66

77
export function extractParamsFromWeigthMapTiny(
88
weightMap: tf.NamedTensorMap
9-
): { params: TinyNetParams, paramMappings: ParamMapping[] } {
9+
): { params: TinyFaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
1010

1111
const paramMappings: ParamMapping[] = []
1212

src/faceFeatureExtractor/extractParamsTiny.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import { extractWeightsFactory, ParamMapping } from 'tfjs-image-recognition-base';
22

33
import { extractorsFactory } from './extractorsFactory';
4-
import { TinyNetParams } from './types';
4+
import { TinyFaceFeatureExtractorParams } from './types';
55

6-
export function extractParamsTiny(weights: Float32Array): { params: TinyNetParams, paramMappings: ParamMapping[] } {
6+
export function extractParamsTiny(weights: Float32Array): { params: TinyFaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
77

88
const paramMappings: ParamMapping[] = []
99

src/faceFeatureExtractor/types.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import * as tf from '@tensorflow/tfjs-core';
2+
import { NetInput, NeuralNetwork } from 'tfjs-image-recognition-base';
23
import { ConvParams, SeparableConvParams } from 'tfjs-tiny-yolov2';
34

45
export type ConvWithBatchNormParams = BatchNormParams & {
@@ -27,16 +28,19 @@ export type DenseBlock4Params = DenseBlock3Params & {
2728
conv3: SeparableConvParams
2829
}
2930

30-
export type TinyNetParams = {
31+
export type TinyFaceFeatureExtractorParams = {
3132
dense0: DenseBlock3Params
3233
dense1: DenseBlock3Params
3334
dense2: DenseBlock3Params
3435
}
3536

36-
export type NetParams = {
37+
export type FaceFeatureExtractorParams = {
3738
dense0: DenseBlock4Params
3839
dense1: DenseBlock4Params
3940
dense2: DenseBlock4Params
4041
dense3: DenseBlock4Params
4142
}
4243

44+
export interface IFaceFeatureExtractor<TNetParams extends TinyFaceFeatureExtractorParams | FaceFeatureExtractorParams> extends NeuralNetwork<TNetParams> {
45+
forward(input: NetInput): tf.Tensor4D
46+
}
Lines changed: 6 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,18 @@
1-
import * as tf from '@tensorflow/tfjs-core';
2-
import { NetInput } from 'tfjs-image-recognition-base';
3-
41
import { FaceFeatureExtractor } from '../faceFeatureExtractor/FaceFeatureExtractor';
5-
import { extractParams } from './extractParams';
6-
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
2+
import { FaceFeatureExtractorParams } from '../faceFeatureExtractor/types';
73
import { FaceLandmark68NetBase } from './FaceLandmark68NetBase';
8-
import { fullyConnectedLayer } from './fullyConnectedLayer';
9-
import { NetParams } from './types';
10-
import { seperateWeightMaps } from './util';
11-
12-
export class FaceLandmark68Net extends FaceLandmark68NetBase<NetParams> {
13-
14-
private static classifierNumFilters: number = 256
15-
16-
private _faceFeatureExtractor: FaceFeatureExtractor
17-
18-
19-
constructor(faceFeatureExtractor: FaceFeatureExtractor) {
20-
super('FaceLandmark68Net')
21-
this._faceFeatureExtractor = faceFeatureExtractor
22-
}
23-
24-
public get faceFeatureExtractor(): FaceFeatureExtractor {
25-
return this._faceFeatureExtractor
26-
}
27-
28-
public runNet(input: NetInput | tf.Tensor4D): tf.Tensor2D {
29-
30-
const { params } = this
314

32-
if (!params) {
33-
throw new Error('FaceLandmark68Net - load model before inference')
34-
}
5+
export class FaceLandmark68Net extends FaceLandmark68NetBase<FaceFeatureExtractorParams> {
356

36-
if (!this.faceFeatureExtractor.isLoaded) {
37-
throw new Error('FaceLandmark68Net - load face feature extractor model before inference')
38-
}
39-
40-
return tf.tidy(() => {
41-
const bottleneckFeatures = input instanceof NetInput
42-
? this.faceFeatureExtractor.forward(input)
43-
: input
44-
return fullyConnectedLayer(bottleneckFeatures.as2D(bottleneckFeatures.shape[0], -1), params.fc)
45-
})
46-
}
47-
48-
public dispose(throwOnRedispose: boolean = true) {
49-
this.faceFeatureExtractor.dispose(throwOnRedispose)
50-
super.dispose(throwOnRedispose)
7+
constructor(faceFeatureExtractor: FaceFeatureExtractor = new FaceFeatureExtractor()) {
8+
super('FaceLandmark68Net', faceFeatureExtractor)
519
}
5210

5311
protected getDefaultModelName(): string {
5412
return 'face_landmark_68_model'
5513
}
5614

57-
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
58-
59-
const { featureExtractorMap, classifierMap } = seperateWeightMaps(weightMap)
60-
61-
this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap)
62-
63-
return extractParamsFromWeigthMap(classifierMap)
64-
}
65-
66-
protected extractParams(weights: Float32Array) {
67-
68-
const classifierWeightSize = 136 * FaceLandmark68Net.classifierNumFilters + 136
69-
70-
const featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize)
71-
const classifierWeights = weights.slice(weights.length - classifierWeightSize)
72-
73-
this.faceFeatureExtractor.extractWeights(featureExtractorWeights)
74-
return extractParams(classifierWeights, FaceLandmark68Net.classifierNumFilters)
15+
protected getClassifierChannelsIn(): number {
16+
return 256
7517
}
7618
}

src/faceLandmarkNet/FaceLandmark68NetBase.ts

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
11
import * as tf from '@tensorflow/tfjs-core';
2-
import { IDimensions, isEven, NetInput, NeuralNetwork, Point, TNetInput, toNetInput } from 'tfjs-image-recognition-base';
2+
import { IDimensions, isEven, NetInput, Point, TNetInput, toNetInput } from 'tfjs-image-recognition-base';
33

44
import { FaceLandmarks68 } from '../classes/FaceLandmarks68';
5+
import { FaceFeatureExtractorParams, TinyFaceFeatureExtractorParams } from '../faceFeatureExtractor/types';
6+
import { FaceProcessor } from '../faceProcessor/FaceProcessor';
57

6-
export abstract class FaceLandmark68NetBase<NetParams> extends NeuralNetwork<NetParams> {
7-
8-
// TODO: make super.name protected
9-
private __name: string
10-
11-
constructor(_name: string) {
12-
super(_name)
13-
this.__name = _name
14-
}
15-
16-
public abstract runNet(netInput: NetInput): tf.Tensor2D
8+
export abstract class FaceLandmark68NetBase<
9+
TExtractorParams extends FaceFeatureExtractorParams | TinyFaceFeatureExtractorParams
10+
>
11+
extends FaceProcessor<TExtractorParams> {
1712

1813
public postProcess(output: tf.Tensor2D, inputSize: number, originalDimensions: IDimensions[]): tf.Tensor2D {
1914

@@ -103,4 +98,8 @@ export abstract class FaceLandmark68NetBase<NetParams> extends NeuralNetwork<Net
10398
? landmarksForBatch
10499
: landmarksForBatch[0]
105100
}
101+
102+
protected getClassifierChannelsOut(): number {
103+
return 136
104+
}
106105
}
Lines changed: 6 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,19 @@
1-
import * as tf from '@tensorflow/tfjs-core';
2-
import { NetInput } from 'tfjs-image-recognition-base';
1+
import { TinyFaceFeatureExtractorParams } from 'src/faceFeatureExtractor/types';
32

43
import { TinyFaceFeatureExtractor } from '../faceFeatureExtractor/TinyFaceFeatureExtractor';
5-
import { extractParams } from './extractParams';
6-
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
74
import { FaceLandmark68NetBase } from './FaceLandmark68NetBase';
8-
import { fullyConnectedLayer } from './fullyConnectedLayer';
9-
import { NetParams } from './types';
10-
import { seperateWeightMaps } from './util';
115

12-
export class FaceLandmark68TinyNet extends FaceLandmark68NetBase<NetParams> {
6+
export class FaceLandmark68TinyNet extends FaceLandmark68NetBase<TinyFaceFeatureExtractorParams> {
137

14-
private static classifierNumFilters: number = 128
15-
16-
private _faceFeatureExtractor: TinyFaceFeatureExtractor
17-
18-
constructor(faceFeatureExtractor: TinyFaceFeatureExtractor) {
19-
super('FaceLandmark68TinyNet')
20-
this._faceFeatureExtractor = faceFeatureExtractor
21-
}
22-
23-
public get faceFeatureExtractor(): TinyFaceFeatureExtractor {
24-
return this._faceFeatureExtractor
25-
}
26-
27-
public runNet(input: NetInput | tf.Tensor4D): tf.Tensor2D {
28-
29-
const { params } = this
30-
31-
if (!params) {
32-
throw new Error('FaceLandmark68TinyNet - load model before inference')
33-
}
34-
35-
if (!this.faceFeatureExtractor.isLoaded) {
36-
throw new Error('FaceLandmark68TinyNet - load face feature extractor model before inference')
37-
}
38-
39-
return tf.tidy(() => {
40-
const bottleneckFeatures = input instanceof NetInput
41-
? this.faceFeatureExtractor.forward(input)
42-
: input
43-
return fullyConnectedLayer(bottleneckFeatures.as2D(bottleneckFeatures.shape[0], -1), params.fc)
44-
})
45-
}
46-
47-
public dispose(throwOnRedispose: boolean = true) {
48-
this.faceFeatureExtractor.dispose(throwOnRedispose)
49-
super.dispose(throwOnRedispose)
8+
constructor(faceFeatureExtractor: TinyFaceFeatureExtractor = new TinyFaceFeatureExtractor()) {
9+
super('FaceLandmark68TinyNet', faceFeatureExtractor)
5010
}
5111

5212
protected getDefaultModelName(): string {
5313
return 'face_landmark_68_tiny_model'
5414
}
5515

56-
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
57-
58-
const { featureExtractorMap, classifierMap } = seperateWeightMaps(weightMap)
59-
60-
this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap)
61-
62-
return extractParamsFromWeigthMap(classifierMap)
63-
}
64-
65-
protected extractParams(weights: Float32Array) {
66-
67-
const classifierWeightSize = 136 * FaceLandmark68TinyNet.classifierNumFilters + 136
68-
69-
const featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize)
70-
const classifierWeights = weights.slice(weights.length - classifierWeightSize)
71-
72-
this.faceFeatureExtractor.extractWeights(featureExtractorWeights)
73-
return extractParams(classifierWeights, FaceLandmark68TinyNet.classifierNumFilters)
16+
protected getClassifierChannelsIn(): number {
17+
return 128
7418
}
7519
}

src/faceLandmarkNet/extractParams.ts

Lines changed: 0 additions & 41 deletions
This file was deleted.

0 commit comments

Comments
 (0)