Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 5aa35a3

Browse files
authored
[io] Add field user-defined metadata; Tweak API for tf.io.fromMemory() (#1864)
FEATURE - This is the first PR for adding support for user-defined metadata in model artifacts. Design doc has been circulated and discussed. - Add the field `userDefinedMetadata` to `ModelArtifacts` and `ModelJSON`. - Deprecate the old API of `tf.io.fromMemory()` which consisted of multiple arguments. The arguments are consolidated into on in the new API. - Add unit tests. Towards tensorflow/tfjs#1596
1 parent 4432e82 commit 5aa35a3

File tree

3 files changed

+89
-29
lines changed

3 files changed

+89
-29
lines changed

‎src/io/passthrough.ts‎

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,10 @@
2222
import {IOHandler, ModelArtifacts, SaveResult, TrainingConfig, WeightsManifestEntry} from './types';
2323

2424
class PassthroughLoader implements IOHandler {
25-
constructor(
26-
private readonly modelTopology?: {}|ArrayBuffer,
27-
private readonly weightSpecs?: WeightsManifestEntry[],
28-
private readonly weightData?: ArrayBuffer,
29-
private readonly trainingConfig?: TrainingConfig) {}
25+
constructor(private readonly modelArtifacts?: ModelArtifacts) {}
3026

3127
async load(): Promise<ModelArtifacts> {
32-
let result = {};
33-
if (this.modelTopology != null) {
34-
result = {modelTopology: this.modelTopology, ...result};
35-
}
36-
if (this.weightSpecs != null && this.weightSpecs.length > 0) {
37-
result = {weightSpecs: this.weightSpecs, ...result};
38-
}
39-
if (this.weightData != null && this.weightData.byteLength > 0) {
40-
result = {weightData: this.weightData, ...result};
41-
}
42-
if (this.trainingConfig != null) {
43-
result = {trainingConfig: this.trainingConfig, ...result};
44-
}
45-
return result;
28+
return this.modelArtifacts;
4629
}
4730
}
4831

@@ -67,7 +50,7 @@ class PassthroughSaver implements IOHandler {
6750
* modelTopology, weightSpecs, weightData));
6851
* ```
6952
*
70-
* @param modelTopology a object containing model topology (i.e., parsed from
53+
* @param modelArtifacts a object containing model topology (i.e., parsed from
7154
* the JSON format).
7255
* @param weightSpecs An array of `WeightsManifestEntry` objects describing the
7356
* names, shapes, types, and quantization of the weight data.
@@ -78,13 +61,39 @@ class PassthroughSaver implements IOHandler {
7861
* @returns A passthrough `IOHandler` that simply loads the provided data.
7962
*/
8063
export function fromMemory(
81-
modelTopology: {}, weightSpecs?: WeightsManifestEntry[],
64+
modelArtifacts: {}|ModelArtifacts, weightSpecs?: WeightsManifestEntry[],
8265
weightData?: ArrayBuffer, trainingConfig?: TrainingConfig): IOHandler {
83-
// TODO(cais): The arguments should probably be consolidated into a single
84-
// object, with proper deprecation process. Even though this function isn't
85-
// documented, it is public and being used by some downstream libraries.
86-
return new PassthroughLoader(
87-
modelTopology, weightSpecs, weightData, trainingConfig);
66+
if (arguments.length === 1) {
67+
const isModelArtifacts =
68+
(modelArtifacts as ModelArtifacts).modelTopology != null ||
69+
(modelArtifacts as ModelArtifacts).weightSpecs != null;
70+
if (isModelArtifacts) {
71+
return new PassthroughLoader(modelArtifacts as ModelArtifacts);
72+
} else {
73+
// Legacy support: with only modelTopology.
74+
// TODO(cais): Remove this deprecated API.
75+
console.warn(
76+
'Please call tf.io.fromMemory() with only one argument. ' +
77+
'The argument should be of type ModelArtifacts. ' +
78+
'The multi-argument signature of tf.io.fromMemory() has been ' +
79+
'deprecated and will be removed in a future release.');
80+
return new PassthroughLoader({modelTopology: modelArtifacts as {}});
81+
}
82+
} else {
83+
// Legacy support.
84+
// TODO(cais): Remove this deprecated API.
85+
console.warn(
86+
'Please call tf.io.fromMemory() with only one argument. ' +
87+
'The argument should be of type ModelArtifacts. ' +
88+
'The multi-argument signature of tf.io.fromMemory() has been ' +
89+
'deprecated and will be removed in a future release.');
90+
return new PassthroughLoader({
91+
modelTopology: modelArtifacts as {},
92+
weightSpecs,
93+
weightData,
94+
trainingConfig
95+
});
96+
}
8897
}
8998

9099
/**

‎src/io/passthrough_test.ts‎

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,20 +114,61 @@ describeWithFlags('Passthrough Saver', BROWSER_ENVS, () => {
114114
});
115115

116116
describeWithFlags('Passthrough Loader', BROWSER_ENVS, () => {
117+
it('load topology and weights: legacy signature', async () => {
118+
const passthroughHandler = tf.io.fromMemory(
119+
modelTopology1, weightSpecs1, weightData1);
120+
const modelArtifacts = await passthroughHandler.load();
121+
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
122+
expect(modelArtifacts.weightSpecs).toEqual(weightSpecs1);
123+
expect(modelArtifacts.weightData).toEqual(weightData1);
124+
expect(modelArtifacts.userDefinedMetadata).toEqual(undefined);
125+
});
126+
117127
it('load topology and weights', async () => {
118-
const passthroughHandler =
119-
tf.io.fromMemory(modelTopology1, weightSpecs1, weightData1);
128+
const passthroughHandler = tf.io.fromMemory({
129+
modelTopology: modelTopology1,
130+
weightSpecs: weightSpecs1,
131+
weightData: weightData1
132+
});
120133
const modelArtifacts = await passthroughHandler.load();
121134
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
122135
expect(modelArtifacts.weightSpecs).toEqual(weightSpecs1);
123136
expect(modelArtifacts.weightData).toEqual(weightData1);
137+
expect(modelArtifacts.userDefinedMetadata).toEqual(undefined);
124138
});
125139

126-
it('load model topology only', async () => {
140+
it('load model topology only: legacy signature', async () => {
127141
const passthroughHandler = tf.io.fromMemory(modelTopology1);
128142
const modelArtifacts = await passthroughHandler.load();
129143
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
130144
expect(modelArtifacts.weightSpecs).toEqual(undefined);
131145
expect(modelArtifacts.weightData).toEqual(undefined);
146+
expect(modelArtifacts.userDefinedMetadata).toEqual(undefined);
147+
});
148+
149+
it('load model topology only', async () => {
150+
const passthroughHandler = tf.io.fromMemory({
151+
modelTopology: modelTopology1
152+
});
153+
const modelArtifacts = await passthroughHandler.load();
154+
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
155+
expect(modelArtifacts.weightSpecs).toEqual(undefined);
156+
expect(modelArtifacts.weightData).toEqual(undefined);
157+
expect(modelArtifacts.userDefinedMetadata).toEqual(undefined);
158+
});
159+
160+
it('load topology, weights, and user-defined metadata', async () => {
161+
const userDefinedMetadata: {} = {'fooField': 'fooValue'};
162+
const passthroughHandler = tf.io.fromMemory({
163+
modelTopology: modelTopology1,
164+
weightSpecs: weightSpecs1,
165+
weightData: weightData1,
166+
userDefinedMetadata
167+
});
168+
const modelArtifacts = await passthroughHandler.load();
169+
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
170+
expect(modelArtifacts.weightSpecs).toEqual(weightSpecs1);
171+
expect(modelArtifacts.weightData).toEqual(weightData1);
172+
expect(modelArtifacts.userDefinedMetadata).toEqual(userDefinedMetadata);
132173
});
133174
});

‎src/io/types.ts‎

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,11 @@ export declare interface ModelArtifacts {
275275
* `tf.LayersModel` instance.)
276276
*/
277277
convertedBy?: string|null;
278+
279+
/**
280+
* User-defined metadata about the model.
281+
*/
282+
userDefinedMetadata?: {};
278283
}
279284

280285
/**
@@ -330,6 +335,11 @@ export declare interface ModelJSON {
330335
* `tf.LayersModel` instance.)
331336
*/
332337
convertedBy?: string|null;
338+
339+
/**
340+
* User-defined metadata about the model.
341+
*/
342+
userDefinedMetadata?: {};
333343
}
334344

335345
/**

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /