Skip to content

Commit 7c7cf2a

Browse files
PredictAgeAndGenderTask
1 parent 4205fc2 commit 7c7cf2a

10 files changed

+320
-54
lines changed

src/ageGenderNet/AgeGenderNet.ts

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import { seperateWeightMaps } from '../faceProcessor/util';
66
import { TinyXception } from '../xception/TinyXception';
77
import { extractParams } from './extractParams';
88
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
9-
import { NetOutput, NetParams } from './types';
9+
import { AgeAndGenderPrediction, Gender, NetOutput, NetParams } from './types';
1010

1111
export class AgeGenderNet extends NeuralNetwork<NetParams> {
1212

@@ -50,17 +50,32 @@ export class AgeGenderNet extends NeuralNetwork<NetParams> {
5050
return this.forwardInput(await toNetInput(input))
5151
}
5252

53-
public async predictAgeAndGender(input: TNetInput): Promise<{ age: number, gender: string, genderProbability: number }> {
53+
public async predictAgeAndGender(input: TNetInput): Promise<AgeAndGenderPrediction | AgeAndGenderPrediction[]> {
5454
const netInput = await toNetInput(input)
5555
const out = await this.forwardInput(netInput)
56-
const age = (await out.age.data())[0]
57-
const probMale = (await out.gender.data())[0]
5856

59-
const isMale = probMale > 0.5
60-
const gender = isMale ? 'male' : 'female'
61-
const genderProbability = isMale ? probMale : (1 - probMale)
62-
63-
return { age, gender, genderProbability }
57+
const ages = tf.unstack(out.age)
58+
const genders = tf.unstack(out.gender)
59+
const ageAndGenderTensors = ages.map((ageTensor, i) => ({
60+
ageTensor,
61+
genderTensor: genders[i]
62+
}))
63+
64+
const predictionsByBatch = await Promise.all(
65+
ageAndGenderTensors.map(async ({ ageTensor, genderTensor }) => {
66+
const age = (await ageTensor.data())[0]
67+
const probMale = (await out.gender.data())[0]
68+
const isMale = probMale > 0.5
69+
const gender = isMale ? Gender.MALE : Gender.FEMALE
70+
const genderProbability = isMale ? probMale : (1 - probMale)
71+
72+
return { age, gender, genderProbability }
73+
})
74+
)
75+
76+
return netInput.isBatchInput
77+
? predictionsByBatch
78+
: predictionsByBatch[0]
6479
}
6580

6681
protected getDefaultModelName(): string {

src/ageGenderNet/types.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
import * as tf from '@tensorflow/tfjs-core';
22
import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
33

4+
export type AgeAndGenderPrediction = {
5+
age: number
6+
gender: Gender
7+
genderProbability: number
8+
}
9+
10+
export enum Gender {
11+
FEMALE = 'female',
12+
MALE = 'male'
13+
}
14+
415
export type NetOutput = { age: tf.Tensor1D, gender: tf.Tensor2D }
516

617
export type NetParams = {

src/factories/WithAge.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
export type WithAge<TSource> = TSource & {
2+
age: number
3+
}
4+
5+
export function isWithAge(obj: any): obj is WithAge<{}> {
6+
return typeof obj['age'] === 'number'
7+
}
8+
9+
export function extendWithAge<
10+
TSource
11+
> (
12+
sourceObj: TSource,
13+
age: number
14+
): WithAge<TSource> {
15+
16+
const extension = { age }
17+
return Object.assign({}, sourceObj, extension)
18+
}

src/factories/WithGender.ts

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import { isValidProbablitiy } from 'tfjs-image-recognition-base';
2+
3+
import { Gender } from '../ageGenderNet/types';
4+
5+
export type WithGender<TSource> = TSource & {
6+
gender: Gender
7+
genderProbability: number
8+
}
9+
10+
export function isWithGender(obj: any): obj is WithGender<{}> {
11+
return (obj['gender'] === Gender.MALE || obj['gender'] === Gender.FEMALE)
12+
&& isValidProbablitiy(obj['genderProbability'])
13+
}
14+
15+
export function extendWithGender<
16+
TSource
17+
> (
18+
sourceObj: TSource,
19+
gender: Gender,
20+
genderProbability: number
21+
): WithGender<TSource> {
22+
23+
const extension = { gender, genderProbability }
24+
return Object.assign({}, sourceObj, extension)
25+
}
Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1-
import * as tf from '@tensorflow/tfjs-core';
21
import { TNetInput } from 'tfjs-image-recognition-base';
32

4-
import { extractFaces, extractFaceTensors } from '../dom';
53
import { extendWithFaceDescriptor, WithFaceDescriptor } from '../factories/WithFaceDescriptor';
64
import { WithFaceDetection } from '../factories/WithFaceDetection';
75
import { WithFaceLandmarks } from '../factories/WithFaceLandmarks';
86
import { ComposableTask } from './ComposableTask';
7+
import { extractAllFacesAndComputeResults, extractSingleFaceAndComputeResult } from './extractFacesAndComputeResults';
98
import { nets } from './nets';
9+
import {
10+
PredictAllAgeAndGenderWithFaceAlignmentTask,
11+
PredictSingleAgeAndGenderWithFaceAlignmentTask,
12+
} from './PredictAgeAndGenderTask';
13+
import {
14+
PredictAllFaceExpressionsWithFaceAlignmentTask,
15+
PredictSingleFaceExpressionsWithFaceAlignmentTask,
16+
} from './PredictFaceExpressionsTask';
1017

1118
export class ComputeFaceDescriptorsTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> {
1219
constructor(
@@ -25,19 +32,25 @@ export class ComputeAllFaceDescriptorsTask<
2532

2633
const parentResults = await this.parentTask
2734

28-
const dlibAlignedRects = parentResults.map(({ landmarks }) => landmarks.align(null, { useDlibAlignment: true }))
29-
const dlibAlignedFaces: Array<HTMLCanvasElement | tf.Tensor3D> = this.input instanceof tf.Tensor
30-
? await extractFaceTensors(this.input, dlibAlignedRects)
31-
: await extractFaces(this.input, dlibAlignedRects)
35+
const descriptors = await extractAllFacesAndComputeResults<TSource, Float32Array[]>(
36+
parentResults,
37+
this.input,
38+
faces => Promise.all(faces.map(face =>
39+
nets.faceRecognitionNet.computeFaceDescriptor(face) as Promise<Float32Array>
40+
)),
41+
null,
42+
parentResult => parentResult.landmarks.align(null, { useDlibAlignment: true })
43+
)
3244

33-
const results = await Promise.all(parentResults.map(async (parentResult, i) => {
34-
const descriptor = await nets.faceRecognitionNet.computeFaceDescriptor(dlibAlignedFaces[i]) as Float32Array
35-
return extendWithFaceDescriptor<TSource>(parentResult, descriptor)
36-
}))
45+
return descriptors.map((descriptor, i) => extendWithFaceDescriptor<TSource>(parentResults[i], descriptor))
46+
}
3747

38-
dlibAlignedFaces.forEach(f => f instanceof tf.Tensor && f.dispose())
48+
withFaceExpressions(): PredictAllFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>> {
49+
return new PredictAllFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>>(this, this.input)
50+
}
3951

40-
return results
52+
withAgeAndGender(): PredictAllAgeAndGenderWithFaceAlignmentTask<WithFaceLandmarks<TSource>> {
53+
return new PredictAllAgeAndGenderWithFaceAlignmentTask<WithFaceLandmarks<TSource>>(this, this.input)
4154
}
4255
}
4356

@@ -51,15 +64,22 @@ export class ComputeSingleFaceDescriptorTask<
5164
if (!parentResult) {
5265
return
5366
}
67+
const descriptor = await extractSingleFaceAndComputeResult<TSource, Float32Array>(
68+
parentResult,
69+
this.input,
70+
face => nets.faceRecognitionNet.computeFaceDescriptor(face) as Promise<Float32Array>,
71+
null,
72+
parentResult => parentResult.landmarks.align(null, { useDlibAlignment: true })
73+
)
5474

55-
const dlibAlignedRect = parentResult.landmarks.align(null, { useDlibAlignment: true })
56-
const alignedFaces: Array<HTMLCanvasElement | tf.Tensor3D> = this.input instanceof tf.Tensor
57-
? await extractFaceTensors(this.input, [dlibAlignedRect])
58-
: await extractFaces(this.input, [dlibAlignedRect])
59-
const descriptor = await nets.faceRecognitionNet.computeFaceDescriptor(alignedFaces[0]) as Float32Array
75+
return extendWithFaceDescriptor(parentResult, descriptor)
76+
}
6077

61-
alignedFaces.forEach(f => f instanceof tf.Tensor && f.dispose())
78+
withFaceExpressions(): PredictSingleFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>> {
79+
return new PredictSingleFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>>(this, this.input)
80+
}
6281

63-
return extendWithFaceDescriptor(parentResult, descriptor)
82+
withAgeAndGender(): PredictSingleAgeAndGenderWithFaceAlignmentTask<WithFaceLandmarks<TSource>> {
83+
return new PredictSingleAgeAndGenderWithFaceAlignmentTask<WithFaceLandmarks<TSource>>(this, this.input)
6484
}
6585
}

src/globalApi/DetectFaceLandmarksTasks.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ import { extendWithFaceLandmarks, WithFaceLandmarks } from '../factories/WithFac
1010
import { ComposableTask } from './ComposableTask';
1111
import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks';
1212
import { nets } from './nets';
13+
import {
14+
PredictAllAgeAndGenderWithFaceAlignmentTask,
15+
PredictSingleAgeAndGenderWithFaceAlignmentTask,
16+
} from './PredictAgeAndGenderTask';
1317
import {
1418
PredictAllFaceExpressionsWithFaceAlignmentTask,
1519
PredictSingleFaceExpressionsWithFaceAlignmentTask,
@@ -59,6 +63,10 @@ export class DetectAllFaceLandmarksTask<
5963
return new PredictAllFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>>(this, this.input)
6064
}
6165

66+
withAgeAndGender(): PredictAllAgeAndGenderWithFaceAlignmentTask<WithFaceLandmarks<TSource>> {
67+
return new PredictAllAgeAndGenderWithFaceAlignmentTask<WithFaceLandmarks<TSource>>(this, this.input)
68+
}
69+
6270
withFaceDescriptors(): ComputeAllFaceDescriptorsTask<WithFaceLandmarks<TSource>> {
6371
return new ComputeAllFaceDescriptorsTask<WithFaceLandmarks<TSource>>(this, this.input)
6472
}
@@ -91,6 +99,10 @@ export class DetectSingleFaceLandmarksTask<
9199
return new PredictSingleFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>>(this, this.input)
92100
}
93101

102+
withAgeAndGender(): PredictSingleAgeAndGenderWithFaceAlignmentTask<WithFaceLandmarks<TSource>> {
103+
return new PredictSingleAgeAndGenderWithFaceAlignmentTask<WithFaceLandmarks<TSource>>(this, this.input)
104+
}
105+
94106
withFaceDescriptor(): ComputeSingleFaceDescriptorTask<WithFaceLandmarks<TSource>> {
95107
return new ComputeSingleFaceDescriptorTask<WithFaceLandmarks<TSource>>(this, this.input)
96108
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
import { TNetInput } from 'tfjs-image-recognition-base';
3+
4+
import { AgeAndGenderPrediction } from '../ageGenderNet/types';
5+
import { extendWithAge, WithAge } from '../factories/WithAge';
6+
import { WithFaceDetection } from '../factories/WithFaceDetection';
7+
import { WithFaceLandmarks } from '../factories/WithFaceLandmarks';
8+
import { extendWithGender, WithGender } from '../factories/WithGender';
9+
import { ComposableTask } from './ComposableTask';
10+
import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks';
11+
import { extractAllFacesAndComputeResults, extractSingleFaceAndComputeResult } from './extractFacesAndComputeResults';
12+
import { nets } from './nets';
13+
import {
14+
PredictAllFaceExpressionsWithFaceAlignmentTask,
15+
PredictSingleFaceExpressionsWithFaceAlignmentTask,
16+
} from './PredictFaceExpressionsTask';
17+
18+
export class PredictAgeAndGenderTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> {
19+
constructor(
20+
protected parentTask: ComposableTask<TParentReturn> | Promise<TParentReturn>,
21+
protected input: TNetInput,
22+
protected extractedFaces?: Array<HTMLCanvasElement | tf.Tensor3D>
23+
) {
24+
super()
25+
}
26+
}
27+
28+
export class PredictAllAgeAndGenderTask<
29+
TSource extends WithFaceDetection<{}>
30+
> extends PredictAgeAndGenderTaskBase<WithAge<WithGender<TSource>>[], TSource[]> {
31+
32+
public async run(): Promise<WithAge<WithGender<TSource>>[]> {
33+
34+
const parentResults = await this.parentTask
35+
36+
const ageAndGenderByFace = await extractAllFacesAndComputeResults<TSource, AgeAndGenderPrediction[]>(
37+
parentResults,
38+
this.input,
39+
async faces => await Promise.all(faces.map(
40+
face => nets.ageGenderNet.predictAgeAndGender(face) as Promise<AgeAndGenderPrediction>
41+
)),
42+
this.extractedFaces
43+
)
44+
45+
return parentResults.map((parentResult, i) => {
46+
const { age, gender, genderProbability } = ageAndGenderByFace[i]
47+
return extendWithAge(extendWithGender(parentResult, gender, genderProbability), age)
48+
})
49+
}
50+
}
51+
52+
export class PredictSingleAgeAndGenderTask<
53+
TSource extends WithFaceDetection<{}>
54+
> extends PredictAgeAndGenderTaskBase<WithAge<WithGender<TSource>> | undefined, TSource | undefined> {
55+
56+
public async run(): Promise<WithAge<WithGender<TSource>> | undefined> {
57+
58+
const parentResult = await this.parentTask
59+
if (!parentResult) {
60+
return
61+
}
62+
63+
const { age, gender, genderProbability } = await extractSingleFaceAndComputeResult<TSource, AgeAndGenderPrediction>(
64+
parentResult,
65+
this.input,
66+
face => nets.ageGenderNet.predictAgeAndGender(face) as Promise<AgeAndGenderPrediction>,
67+
this.extractedFaces
68+
)
69+
70+
return extendWithAge(extendWithGender(parentResult, gender, genderProbability), age)
71+
}
72+
}
73+
74+
export class PredictAllAgeAndGenderWithFaceAlignmentTask<
75+
TSource extends WithFaceLandmarks<WithFaceDetection<{}>>
76+
> extends PredictAllAgeAndGenderTask<TSource> {
77+
78+
withFaceExpressions(): PredictAllFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>> {
79+
return new PredictAllFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>>(this, this.input)
80+
}
81+
82+
withFaceDescriptors(): ComputeAllFaceDescriptorsTask<WithFaceLandmarks<TSource>> {
83+
return new ComputeAllFaceDescriptorsTask<WithFaceLandmarks<TSource>>(this, this.input)
84+
}
85+
}
86+
87+
export class PredictSingleAgeAndGenderWithFaceAlignmentTask<
88+
TSource extends WithFaceLandmarks<WithFaceDetection<{}>>
89+
> extends PredictSingleAgeAndGenderTask<TSource> {
90+
91+
withFaceExpressions(): PredictSingleFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>> {
92+
return new PredictSingleFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>>(this, this.input)
93+
}
94+
95+
withFaceDescriptor(): ComputeSingleFaceDescriptorTask<WithFaceLandmarks<TSource>> {
96+
return new ComputeSingleFaceDescriptorTask<WithFaceLandmarks<TSource>>(this, this.input)
97+
}
98+
}

0 commit comments

Comments
 (0)