Skip to content

Commit 9c30d35

Browse files
Merge pull request justadudewhohacks#60 from justadudewhohacks/fixes
fixed loadweightmap always fetching from relative url + bump tfjs-core version
2 parents 7c3c58f + 3eebacf commit 9c30d35

File tree

10 files changed

+124
-49
lines changed

10 files changed

+124
-49
lines changed

build/commons/loadWeightMap.d.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
export declare function getModelUris(uri: string | undefined, defaultModelName: string): {
2-
manifestUri: string;
32
modelBaseUri: string;
3+
manifestUri: string;
44
};
55
export declare function loadWeightMap(uri: string | undefined, defaultModelName: string): Promise<any>;

build/commons/loadWeightMap.js

Lines changed: 22 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

build/commons/loadWeightMap.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/face-api.js

Lines changed: 22 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/face-api.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/face-api.min.js

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/package-lock.json

Lines changed: 31 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"author": "justadudewhohacks",
66
"license": "MIT",
77
"dependencies": {
8+
"@tensorflow/tfjs-core": "^0.12.7",
89
"express": "^4.16.3",
910
"request": "^2.87.0"
1011
}

src/commons/loadWeightMap.ts

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,35 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

33
export function getModelUris(uri: string | undefined, defaultModelName: string) {
4-
const parts = (uri || '').split('/')
5-
6-
const modelBaseUri = (
7-
(uri || '').endsWith('.json')
8-
? parts.slice(0, parts.length - 1)
9-
: parts
10-
).filter(s => s).join('/')
11-
124
const defaultManifestFilename = `${defaultModelName}-weights_manifest.json`
13-
const manifestUri = !uri || !modelBaseUri
14-
? defaultManifestFilename
15-
: (
16-
uri.endsWith('.json')
17-
? uri
18-
: `${modelBaseUri}/${defaultManifestFilename}`
19-
)
20-
21-
return { manifestUri, modelBaseUri }
5+
6+
if (!uri) {
7+
return {
8+
modelBaseUri: '',
9+
manifestUri: defaultManifestFilename
10+
}
11+
}
12+
13+
if (uri === '/') {
14+
return {
15+
modelBaseUri: '/',
16+
manifestUri: `/${defaultManifestFilename}`
17+
}
18+
}
19+
20+
const parts = uri.split('/').filter(s => s)
21+
22+
const manifestFile = uri.endsWith('.json')
23+
? parts[parts.length - 1]
24+
: defaultManifestFilename
25+
26+
let modelBaseUri = (uri.endsWith('.json') ? parts.slice(0, parts.length - 1) : parts).join('/')
27+
modelBaseUri = uri.startsWith('/') ? `/${modelBaseUri}` : modelBaseUri
28+
29+
return {
30+
modelBaseUri,
31+
manifestUri: modelBaseUri === '/' ? `/${manifestFile}` : `${modelBaseUri}/${manifestFile}`
32+
}
2233
}
2334

2435
export async function loadWeightMap(

test/tests/commons/loadWeightMap.test.ts

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
11
import { getModelUris } from '../../../src/commons/loadWeightMap';
22

3-
const FAKE_DEFAULT_MODEL_NAME = 'default_model_name'
3+
const FAKE_DEFAULT_MODEL_NAME = 'fake_model_name'
44

55
describe('loadWeightMap', () => {
66

77
describe('getModelUris', () => {
88

9-
it('returns uris from top level url if no argument passed', () => {
9+
it('returns uris from relative url if no argument passed', () => {
1010
const result = getModelUris(undefined, FAKE_DEFAULT_MODEL_NAME)
1111

1212
expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
1313
expect(result.modelBaseUri).toEqual('')
1414
})
1515

16-
it('returns uris from top level url for empty string', () => {
16+
it('returns uris from relative url for empty string', () => {
1717
const result = getModelUris('', FAKE_DEFAULT_MODEL_NAME)
1818

1919
expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
2020
expect(result.modelBaseUri).toEqual('')
2121
})
2222

23-
it('returns uris for top level url', () => {
23+
it('returns uris for top level url, leading slash preserved', () => {
2424
const result = getModelUris('/', FAKE_DEFAULT_MODEL_NAME)
2525

26-
expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
27-
expect(result.modelBaseUri).toEqual('')
26+
expect(result.manifestUri).toEqual(`/${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
27+
expect(result.modelBaseUri).toEqual('/')
2828
})
2929

3030
it('returns uris, given url path', () => {
@@ -35,8 +35,8 @@ describe('loadWeightMap', () => {
3535
expect(result.modelBaseUri).toEqual(uri)
3636
})
3737

38-
it('returns uris, given url path, leading slash', () => {
39-
const uri = 'path/to/modelfiles'
38+
it('returns uris, given url path, leading slash preserved', () => {
39+
const uri = '/path/to/modelfiles'
4040
const result = getModelUris(`/${uri}`, FAKE_DEFAULT_MODEL_NAME)
4141

4242
expect(result.manifestUri).toEqual(`${uri}/${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
@@ -51,6 +51,14 @@ describe('loadWeightMap', () => {
5151
expect(result.modelBaseUri).toEqual('path/to/modelfiles')
5252
})
5353

54+
it('returns uris, given manifest uri, leading slash preserved', () => {
55+
const uri = '/path/to/modelfiles/model-weights_manifest.json'
56+
const result = getModelUris(uri, FAKE_DEFAULT_MODEL_NAME)
57+
58+
expect(result.manifestUri).toEqual(uri)
59+
expect(result.modelBaseUri).toEqual('/path/to/modelfiles')
60+
})
61+
5462
})
5563

5664
})

0 commit comments

Comments
 (0)