Skip to content

Commit 656602f

Browse files
fixed loadweightmap always fetching from relative url
1 parent 7c3c58f commit 656602f

File tree

2 files changed

+44
-25
lines changed

2 files changed

+44
-25
lines changed

src/commons/loadWeightMap.ts

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,35 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

33
export function getModelUris(uri: string | undefined, defaultModelName: string) {
4-
const parts = (uri || '').split('/')
5-
6-
const modelBaseUri = (
7-
(uri || '').endsWith('.json')
8-
? parts.slice(0, parts.length - 1)
9-
: parts
10-
).filter(s => s).join('/')
11-
124
const defaultManifestFilename = `${defaultModelName}-weights_manifest.json`
13-
const manifestUri = !uri || !modelBaseUri
14-
? defaultManifestFilename
15-
: (
16-
uri.endsWith('.json')
17-
? uri
18-
: `${modelBaseUri}/${defaultManifestFilename}`
19-
)
20-
21-
return { manifestUri, modelBaseUri }
5+
6+
if (!uri) {
7+
return {
8+
modelBaseUri: '',
9+
manifestUri: defaultManifestFilename
10+
}
11+
}
12+
13+
if (uri === '/') {
14+
return {
15+
modelBaseUri: '/',
16+
manifestUri: `/${defaultManifestFilename}`
17+
}
18+
}
19+
20+
const parts = uri.split('/').filter(s => s)
21+
22+
const manifestFile = uri.endsWith('.json')
23+
? parts[parts.length - 1]
24+
: defaultManifestFilename
25+
26+
let modelBaseUri = (uri.endsWith('.json') ? parts.slice(0, parts.length - 1) : parts).join('/')
27+
modelBaseUri = uri.startsWith('/') ? `/${modelBaseUri}` : modelBaseUri
28+
29+
return {
30+
modelBaseUri,
31+
manifestUri: modelBaseUri === '/' ? `/${manifestFile}` : `${modelBaseUri}/${manifestFile}`
32+
}
2233
}
2334

2435
export async function loadWeightMap(

test/tests/commons/loadWeightMap.test.ts

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
11
import { getModelUris } from '../../../src/commons/loadWeightMap';
22

3-
const FAKE_DEFAULT_MODEL_NAME = 'default_model_name'
3+
const FAKE_DEFAULT_MODEL_NAME = 'fake_model_name'
44

55
describe('loadWeightMap', () => {
66

77
describe('getModelUris', () => {
88

9-
it('returns uris from top level url if no argument passed', () => {
9+
it('returns uris from relative url if no argument passed', () => {
1010
const result = getModelUris(undefined, FAKE_DEFAULT_MODEL_NAME)
1111

1212
expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
1313
expect(result.modelBaseUri).toEqual('')
1414
})
1515

16-
it('returns uris from top level url for empty string', () => {
16+
it('returns uris from relative url for empty string', () => {
1717
const result = getModelUris('', FAKE_DEFAULT_MODEL_NAME)
1818

1919
expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
2020
expect(result.modelBaseUri).toEqual('')
2121
})
2222

23-
it('returns uris for top level url', () => {
23+
it('returns uris for top level url, leading slash preserved', () => {
2424
const result = getModelUris('/', FAKE_DEFAULT_MODEL_NAME)
2525

26-
expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
27-
expect(result.modelBaseUri).toEqual('')
26+
expect(result.manifestUri).toEqual(`/${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
27+
expect(result.modelBaseUri).toEqual('/')
2828
})
2929

3030
it('returns uris, given url path', () => {
@@ -35,8 +35,8 @@ describe('loadWeightMap', () => {
3535
expect(result.modelBaseUri).toEqual(uri)
3636
})
3737

38-
it('returns uris, given url path, leading slash', () => {
39-
const uri = 'path/to/modelfiles'
38+
it('returns uris, given url path, leading slash preserved', () => {
39+
const uri = '/path/to/modelfiles'
4040
const result = getModelUris(`/${uri}`, FAKE_DEFAULT_MODEL_NAME)
4141

4242
expect(result.manifestUri).toEqual(`${uri}/${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
@@ -51,6 +51,14 @@ describe('loadWeightMap', () => {
5151
expect(result.modelBaseUri).toEqual('path/to/modelfiles')
5252
})
5353

54+
it('returns uris, given manifest uri, leading slash preserved', () => {
55+
const uri = '/path/to/modelfiles/model-weights_manifest.json'
56+
const result = getModelUris(uri, FAKE_DEFAULT_MODEL_NAME)
57+
58+
expect(result.manifestUri).toEqual(uri)
59+
expect(result.modelBaseUri).toEqual('/path/to/modelfiles')
60+
})
61+
5462
})
5563

5664
})

0 commit comments

Comments
 (0)