Skip to content

Commit 706d164

Browse files
author
Kangyi Zhang
authored
Export types used in tfjs-data (tensorflow#1626)
BUG * save * save
1 parent fd6934d commit 706d164

File tree

3 files changed

+16
-56
lines changed

3 files changed

+16
-56
lines changed

src/index.ts

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import * as math from './math';
3030
import * as browser from './ops/browser';
3131
import * as serialization from './serialization';
3232
import {setOpHandler} from './tensor';
33+
import * as tensor_util from './tensor_util';
3334
import * as test_util from './test_util';
3435
import * as util from './util';
3536
import {version} from './version';
@@ -46,8 +47,8 @@ export {Optimizer} from './optimizers/optimizer';
4647
export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
4748
export {SGDOptimizer} from './optimizers/sgd_optimizer';
4849
export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer, variable, Variable} from './tensor';
49-
export {GradSaveFunc, NamedTensorMap} from './tensor_types';
50-
export {DataType, DataTypeMap, DataValues, Rank, ShapeMap} from './types';
50+
export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types';
51+
export {DataType, DataTypeMap, DataValues, Rank, ShapeMap, TensorLike} from './types';
5152

5253
export * from './ops/ops';
5354
export {LSTMCellFunc} from './ops/lstm';
@@ -76,7 +77,17 @@ export {
7677
};
7778

7879
// Second level exports.
79-
export {browser, environment, io, math, serialization, test_util, util, webgl};
80+
export {
81+
browser,
82+
environment,
83+
io,
84+
math,
85+
serialization,
86+
test_util,
87+
util,
88+
webgl,
89+
tensor_util
90+
};
8091

8192
// Backend specific.
8293
export {KernelBackend, BackendTimingInfo, DataMover, DataStorage} from './kernels/backend';

src/tensor_util.ts

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717

1818
import {Tensor} from './tensor';
19-
import {NamedTensorMap, TensorContainer, TensorContainerArray} from './tensor_types';
19+
import {TensorContainer, TensorContainerArray} from './tensor_types';
2020
import {upcastType} from './types';
2121
import {assert} from './util';
2222

@@ -44,33 +44,6 @@ export function isTensorInList(tensor: Tensor, tensorList: Tensor[]): boolean {
4444
return false;
4545
}
4646

47-
export function flattenNameArrayMap(
48-
nameArrayMap: Tensor|NamedTensorMap, keys?: string[]): Tensor[] {
49-
const xs: Tensor[] = [];
50-
if (nameArrayMap instanceof Tensor) {
51-
xs.push(nameArrayMap);
52-
} else {
53-
const xMap = nameArrayMap as {[xName: string]: Tensor};
54-
for (let i = 0; i < keys.length; i++) {
55-
xs.push(xMap[keys[i]]);
56-
}
57-
}
58-
return xs;
59-
}
60-
61-
export function unflattenToNameArrayMap(
62-
keys: string[], flatArrays: Tensor[]): NamedTensorMap {
63-
if (keys.length !== flatArrays.length) {
64-
throw new Error(
65-
`Cannot unflatten Tensor[], keys and arrays are not of same length.`);
66-
}
67-
const result: NamedTensorMap = {};
68-
for (let i = 0; i < keys.length; i++) {
69-
result[keys[i]] = flatArrays[i];
70-
}
71-
return result;
72-
}
73-
7447
/**
7548
* Extracts any `Tensor`s found within the provided object.
7649
*

src/tensor_util_test.ts

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
import * as tf from './index';
1919
import {describeWithFlags} from './jasmine_util';
2020
import {Tensor} from './tensor';
21-
import {NamedTensorMap} from './tensor_types';
22-
import {flattenNameArrayMap, getTensorsInContainer, isTensorInList, unflattenToNameArrayMap} from './tensor_util';
21+
import {getTensorsInContainer, isTensorInList} from './tensor_util';
2322
import {convertToTensor} from './tensor_util_env';
2423
import {ALL_ENVS, expectArraysClose, expectArraysEqual} from './test_util';
2524

@@ -39,29 +38,6 @@ describe('tensor_util.isTensorInList', () => {
3938
});
4039
});
4140

42-
describe('tensor_util.flattenNameArrayMap', () => {
43-
it('basic', () => {
44-
const a = tf.scalar(1);
45-
const b = tf.scalar(3);
46-
const c = tf.tensor1d([1, 2, 3]);
47-
48-
const map: NamedTensorMap = {a, b, c};
49-
expect(flattenNameArrayMap(map, Object.keys(map))).toEqual([a, b, c]);
50-
});
51-
});
52-
53-
describe('tensor_util.unflattenToNameArrayMap', () => {
54-
it('basic', () => {
55-
const a = tf.scalar(1);
56-
const b = tf.scalar(3);
57-
const c = tf.tensor1d([1, 2, 3]);
58-
59-
expect(unflattenToNameArrayMap(['a', 'b', 'c'], [
60-
a, b, c
61-
])).toEqual({a, b, c});
62-
});
63-
});
64-
6541
describe('getTensorsInContainer', () => {
6642
it('null input returns empty tensor', () => {
6743
const results = getTensorsInContainer(null);

0 commit comments

Comments
 (0)