22
22
import { IOHandler , ModelArtifacts , SaveResult , TrainingConfig , WeightsManifestEntry } from './types' ;
23
23
24
24
class PassthroughLoader implements IOHandler {
25
- constructor (
26
- private readonly modelTopology ?: { } | ArrayBuffer ,
27
- private readonly weightSpecs ?: WeightsManifestEntry [ ] ,
28
- private readonly weightData ?: ArrayBuffer ,
29
- private readonly trainingConfig ?: TrainingConfig ) { }
25
+ constructor ( private readonly modelArtifacts ?: ModelArtifacts ) { }
30
26
31
27
async load ( ) : Promise < ModelArtifacts > {
32
- let result = { } ;
33
- if ( this . modelTopology != null ) {
34
- result = { modelTopology : this . modelTopology , ...result } ;
35
- }
36
- if ( this . weightSpecs != null && this . weightSpecs . length > 0 ) {
37
- result = { weightSpecs : this . weightSpecs , ...result } ;
38
- }
39
- if ( this . weightData != null && this . weightData . byteLength > 0 ) {
40
- result = { weightData : this . weightData , ...result } ;
41
- }
42
- if ( this . trainingConfig != null ) {
43
- result = { trainingConfig : this . trainingConfig , ...result } ;
44
- }
45
- return result ;
28
+ return this . modelArtifacts ;
46
29
}
47
30
}
48
31
@@ -67,7 +50,7 @@ class PassthroughSaver implements IOHandler {
67
50
* modelTopology, weightSpecs, weightData));
68
51
* ```
69
52
*
70
- * @param modelTopology a object containing model topology (i.e., parsed from
53
+ * @param modelArtifacts a object containing model topology (i.e., parsed from
71
54
* the JSON format).
72
55
* @param weightSpecs An array of `WeightsManifestEntry` objects describing the
73
56
* names, shapes, types, and quantization of the weight data.
@@ -78,13 +61,39 @@ class PassthroughSaver implements IOHandler {
78
61
* @returns A passthrough `IOHandler` that simply loads the provided data.
79
62
*/
80
63
export function fromMemory (
81
- modelTopology : { } , weightSpecs ?: WeightsManifestEntry [ ] ,
64
+ modelArtifacts : { } | ModelArtifacts , weightSpecs ?: WeightsManifestEntry [ ] ,
82
65
weightData ?: ArrayBuffer , trainingConfig ?: TrainingConfig ) : IOHandler {
83
- // TODO(cais): The arguments should probably be consolidated into a single
84
- // object, with proper deprecation process. Even though this function isn't
85
- // documented, it is public and being used by some downstream libraries.
86
- return new PassthroughLoader (
87
- modelTopology , weightSpecs , weightData , trainingConfig ) ;
66
+ if ( arguments . length === 1 ) {
67
+ const isModelArtifacts =
68
+ ( modelArtifacts as ModelArtifacts ) . modelTopology != null ||
69
+ ( modelArtifacts as ModelArtifacts ) . weightSpecs != null ;
70
+ if ( isModelArtifacts ) {
71
+ return new PassthroughLoader ( modelArtifacts as ModelArtifacts ) ;
72
+ } else {
73
+ // Legacy support: with only modelTopology.
74
+ // TODO(cais): Remove this deprecated API.
75
+ console . warn (
76
+ 'Please call tf.io.fromMemory() with only one argument. ' +
77
+ 'The argument should be of type ModelArtifacts. ' +
78
+ 'The multi-argument signature of tf.io.fromMemory() has been ' +
79
+ 'deprecated and will be removed in a future release.' ) ;
80
+ return new PassthroughLoader ( { modelTopology : modelArtifacts as { } } ) ;
81
+ }
82
+ } else {
83
+ // Legacy support.
84
+ // TODO(cais): Remove this deprecated API.
85
+ console . warn (
86
+ 'Please call tf.io.fromMemory() with only one argument. ' +
87
+ 'The argument should be of type ModelArtifacts. ' +
88
+ 'The multi-argument signature of tf.io.fromMemory() has been ' +
89
+ 'deprecated and will be removed in a future release.' ) ;
90
+ return new PassthroughLoader ( {
91
+ modelTopology : modelArtifacts as { } ,
92
+ weightSpecs,
93
+ weightData,
94
+ trainingConfig
95
+ } ) ;
96
+ }
88
97
}
89
98
90
99
/**
0 commit comments