Skip to content

Commit f217045

Browse files
authored
Add string dtype to Tensor (tensorflow#1408)
Add `string` dtype to `Tensor`. This opens the door for adding Python's [`string_ops`](https://www.tensorflow.org/api_docs/python/tf/strings) to TensorFlow.js, which are used for text-based models, and for adding pre-processing layers that operate on strings. Details: - dtype was not added as a generic to the Tensor class in order to keep compiler errors simple and code backwards compatible. - dataSync() can be optionally typed to cast its result. E.g. `t.dataSync<'string'>()` returns `string[]` while `t.dataSync()` returns `TypedArray` for backwards compatibility. - `layers` and `converter` pass with this build. `node` has 30ish failed tests since `string` is an unknown dtype. - Only `clone`, `reshape` and `cast` work with strings at this point to keep this PR small. Other ops will get the functionality in a follow-up PR. - Added unit tests to assert that numeric ops throw on string tensors. - Backends now should support dtype `string` in their `register/write/read` methods. - Added a vscode config to do debugging directly from vscode FEATURE
1 parent a25776b commit f217045

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1764
-468
lines changed

.vscode/launch.json

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
{
2+
// Use IntelliSense to learn about possible attributes.
3+
// Hover to view descriptions of existing attributes.
4+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5+
"version": "0.2.0",
6+
"configurations": [
7+
{
8+
"type": "chrome",
9+
"request": "attach",
10+
"name": "Attach Karma Chrome",
11+
"address": "localhost",
12+
"port": 9333,
13+
"pathMapping": {
14+
"/": "${workspaceRoot}",
15+
"/base/": "${workspaceRoot}/"
16+
}
17+
}
18+
]
19+
}

karma.conf.js

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ module.exports = function(config) {
8989
chrome_with_swift_shader: {
9090
base: 'Chrome',
9191
flags: ['--blacklist-accelerated-compositing', '--blacklist-webgl']
92-
}
92+
},
93+
chrome_debugging:
94+
{base: 'Chrome', flags: ['--remote-debugging-port=9333']}
9395
},
9496
client: {jasmine: {random: false}, args: args}
9597
});

src/buffer_test.ts

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from './index';
19+
import {describeWithFlags} from './jasmine_util';
20+
import {ALL_ENVS, expectArraysClose, expectArraysEqual} from './test_util';
21+
22+
describeWithFlags('tf.buffer', ALL_ENVS, () => {
23+
it('float32', () => {
24+
const buff = tf.buffer([1, 2, 3], 'float32');
25+
buff.set(1.3, 0, 0, 0);
26+
buff.set(2.9, 0, 1, 0);
27+
expect(buff.get(0, 0, 0)).toBeCloseTo(1.3);
28+
expect(buff.get(0, 0, 1)).toBeCloseTo(0);
29+
expect(buff.get(0, 0, 2)).toBeCloseTo(0);
30+
expect(buff.get(0, 1, 0)).toBeCloseTo(2.9);
31+
expect(buff.get(0, 1, 1)).toBeCloseTo(0);
32+
expect(buff.get(0, 1, 2)).toBeCloseTo(0);
33+
expectArraysClose(buff.toTensor(), [1.3, 0, 0, 2.9, 0, 0]);
34+
expectArraysClose(buff.values, new Float32Array([1.3, 0, 0, 2.9, 0, 0]));
35+
});
36+
37+
it('int32', () => {
38+
const buff = tf.buffer([2, 3], 'int32');
39+
buff.set(1.3, 0, 0);
40+
buff.set(2.1, 1, 1);
41+
expect(buff.get(0, 0)).toEqual(1);
42+
expect(buff.get(0, 1)).toEqual(0);
43+
expect(buff.get(0, 2)).toEqual(0);
44+
expect(buff.get(1, 0)).toEqual(0);
45+
expect(buff.get(1, 1)).toEqual(2);
46+
expect(buff.get(1, 2)).toEqual(0);
47+
expectArraysClose(buff.toTensor(), [1, 0, 0, 0, 2, 0]);
48+
expectArraysClose(buff.values, new Int32Array([1, 0, 0, 0, 2, 0]));
49+
});
50+
51+
it('bool', () => {
52+
const buff = tf.buffer([4], 'bool');
53+
buff.set(true, 1);
54+
buff.set(true, 2);
55+
expect(buff.get(0)).toBeFalsy();
56+
expect(buff.get(1)).toBeTruthy();
57+
expect(buff.get(2)).toBeTruthy();
58+
expect(buff.get(3)).toBeFalsy();
59+
expectArraysClose(buff.toTensor(), [0, 1, 1, 0]);
60+
expectArraysClose(buff.values, new Uint8Array([0, 1, 1, 0]));
61+
});
62+
63+
it('string', () => {
64+
const buff = tf.buffer([2, 2], 'string');
65+
buff.set('first', 0, 0);
66+
buff.set('third', 1, 0);
67+
expect(buff.get(0, 0)).toEqual('first');
68+
expect(buff.get(0, 1)).toBeFalsy();
69+
expect(buff.get(1, 0)).toEqual('third');
70+
expect(buff.get(1, 1)).toBeFalsy();
71+
expectArraysEqual(buff.toTensor(), ['first', null, 'third', null]);
72+
});
73+
});

src/engine.ts

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ import {backpropagateGradients, getFilteredNodesXToY, NamedGradientMap, TapeNode
2121
import {DataId, Tensor, Tensor3D, Variable} from './tensor';
2222
import {NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';
2323
import {getTensorsInContainer, isTensorInList} from './tensor_util';
24-
import {DataType, TypedArray} from './types';
24+
import {DataType, DataValues} from './types';
2525
import * as util from './util';
26-
import {makeOnesTypedArray, now, sizeFromShape} from './util';
26+
import {bytesFromStringArray, makeOnesTypedArray, now, sizeFromShape} from './util';
2727

2828
/**
2929
* A function that computes an output. The save function is for saving tensors
@@ -42,7 +42,7 @@ export type CustomGradientFunc<T extends Tensor> = (...args: Tensor[]) => {
4242

4343
export type MemoryInfo = {
4444
numTensors: number; numDataBuffers: number; numBytes: number;
45-
unreliable?: boolean;
45+
unreliable?: boolean; reasons: string[];
4646
};
4747

4848
type KernelProfile = {
@@ -85,6 +85,7 @@ export class Engine implements TensorManager, DataMover {
8585
private nextTapeNodeId = 0;
8686
private numBytes = 0;
8787
private numTensors = 0;
88+
private numStringTensors = 0;
8889
private numDataBuffers = 0;
8990

9091
private profiling = false;
@@ -102,6 +103,7 @@ export class Engine implements TensorManager, DataMover {
102103

103104
private tensorInfo = new WeakMap<DataId, {
104105
backend: KernelBackend,
106+
bytes: number,
105107
dtype: DataType,
106108
shape: number[],
107109
refCount: number
@@ -250,18 +252,26 @@ export class Engine implements TensorManager, DataMover {
250252
this.tensorInfo.get(a.dataId).refCount :
251253
0;
252254
this.numTensors++;
255+
if (a.dtype === 'string') {
256+
this.numStringTensors++;
257+
}
253258
if (refCount === 0) {
254259
this.numDataBuffers++;
255260

256-
// Don't count bytes for complex numbers as they are counted by their
257-
// components.
258-
if (a.dtype !== 'complex64') {
259-
this.numBytes +=
260-
util.sizeFromShape(a.shape) * util.bytesPerElement(a.dtype);
261+
// Bytes for complex numbers are counted by their components. Bytes for
262+
// string tensors are counted when writing values.
263+
let bytes = 0;
264+
if (a.dtype !== 'complex64' && a.dtype !== 'string') {
265+
bytes = util.sizeFromShape(a.shape) * util.bytesPerElement(a.dtype);
261266
}
262-
this.tensorInfo.set(
263-
a.dataId,
264-
{backend: this.backend, dtype: a.dtype, shape: a.shape, refCount: 0});
267+
this.tensorInfo.set(a.dataId, {
268+
backend: this.backend,
269+
dtype: a.dtype,
270+
shape: a.shape,
271+
bytes,
272+
refCount: 0
273+
});
274+
this.numBytes += bytes;
265275
this.backend.register(a.dataId, a.shape, a.dtype);
266276
}
267277
this.tensorInfo.get(a.dataId).refCount++;
@@ -285,17 +295,19 @@ export class Engine implements TensorManager, DataMover {
285295
this.keepTensors.delete(a.id);
286296
}
287297
this.numTensors--;
288-
const refCount = this.tensorInfo.get(a.dataId).refCount;
298+
if (a.dtype === 'string') {
299+
this.numStringTensors--;
300+
}
301+
const info = this.tensorInfo.get(a.dataId);
302+
const refCount = info.refCount;
289303
if (refCount <= 1) {
290-
const info = this.tensorInfo.get(a.dataId);
291-
info.backend.disposeData(a.dataId);
292-
this.numDataBuffers--;
293304
// Don't count bytes for complex numbers as they are counted by their
294305
// components.
295306
if (a.dtype !== 'complex64') {
296-
this.numBytes -=
297-
util.sizeFromShape(a.shape) * util.bytesPerElement(a.dtype);
307+
this.numBytes -= info.bytes;
298308
}
309+
this.numDataBuffers--;
310+
info.backend.disposeData(a.dataId);
299311
this.tensorInfo.delete(a.dataId);
300312
} else {
301313
this.tensorInfo.get(a.dataId).refCount--;
@@ -318,6 +330,15 @@ export class Engine implements TensorManager, DataMover {
318330
info.numTensors = this.numTensors;
319331
info.numDataBuffers = this.numDataBuffers;
320332
info.numBytes = this.numBytes;
333+
if (this.numStringTensors > 0) {
334+
info.unreliable = true;
335+
if (info.reasons == null) {
336+
info.reasons = [];
337+
}
338+
info.reasons.push(
339+
'Memory usage by string tensors is approximate ' +
340+
'(2 bytes per character)');
341+
}
321342
return info;
322343
}
323344

@@ -457,6 +478,9 @@ export class Engine implements TensorManager, DataMover {
457478
f: () => T, xs: Tensor[], dy?: T,
458479
allowNoGradients = false): {value: T, grads: Tensor[]} {
459480
util.assert(xs.length > 0, 'gradients() received an empty list of xs.');
481+
if (dy != null && dy.dtype !== 'float32') {
482+
throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
483+
}
460484

461485
return this.tidy('gradients', () => {
462486
const y = f();
@@ -537,8 +561,15 @@ export class Engine implements TensorManager, DataMover {
537561
}
538562

539563
// Forwarding to backend.
540-
write(dataId: DataId, values: TypedArray): void {
564+
write(dataId: DataId, values: DataValues): void {
541565
const info = this.tensorInfo.get(dataId);
566+
// Bytes for string tensors are counted when writing.
567+
if (info.dtype === 'string') {
568+
const newBytes = bytesFromStringArray(values as string[]);
569+
this.numBytes += newBytes - info.bytes;
570+
info.bytes = newBytes;
571+
}
572+
542573
if (this.backend !== info.backend) {
543574
// Delete the tensor from the old backend and move it to the new backend.
544575
info.backend.disposeData(dataId);
@@ -547,12 +578,12 @@ export class Engine implements TensorManager, DataMover {
547578
}
548579
this.backend.write(dataId, values);
549580
}
550-
readSync(dataId: DataId): TypedArray {
581+
readSync(dataId: DataId): DataValues {
551582
// Route the read to the correct backend.
552583
const info = this.tensorInfo.get(dataId);
553584
return info.backend.readSync(dataId);
554585
}
555-
read(dataId: DataId): Promise<TypedArray> {
586+
read(dataId: DataId): Promise<DataValues> {
556587
// Route the read to the correct backend.
557588
const info = this.tensorInfo.get(dataId);
558589
return info.backend.read(dataId);

src/engine_test.ts

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import {describeWithFlags} from './jasmine_util';
2020
import {MathBackendCPU} from './kernels/backend_cpu';
2121
import {MathBackendWebGL} from './kernels/backend_webgl';
2222
import {Tensor} from './tensor';
23-
import {ALL_ENVS, expectArraysClose, expectArraysEqual, expectNumbersClose, WEBGL_ENVS} from './test_util';
23+
import {ALL_ENVS, CPU_ENVS, expectArraysClose, expectArraysEqual, expectNumbersClose, WEBGL_ENVS} from './test_util';
2424

2525
describeWithFlags('fromPixels + regular math op', WEBGL_ENVS, () => {
2626
it('debug mode does not error when no nans', () => {
@@ -361,6 +361,74 @@ describeWithFlags('memory', ALL_ENVS, () => {
361361
expect(sum.dtype).toBe('int32');
362362
expectArraysClose(sum, [1 + 1 + 0 + 1]);
363363
});
364+
365+
it('string tensor', () => {
366+
const a = tf.tensor([['a', 'bb'], ['c', 'd']]);
367+
368+
expect(tf.memory().numTensors).toBe(1);
369+
expect(tf.memory().numBytes).toBe(10); // 5 letters, each 2 bytes.
370+
371+
a.dispose();
372+
373+
expect(tf.memory().numTensors).toBe(0);
374+
expect(tf.memory().numBytes).toBe(0);
375+
});
376+
377+
it('unreliable is true for string tensors', () => {
378+
tf.tensor('a');
379+
const mem = tf.memory();
380+
expect(mem.unreliable).toBe(true);
381+
const expectedReason = 'Memory usage by string tensors is approximate ' +
382+
'(2 bytes per character)';
383+
expect(mem.reasons.indexOf(expectedReason) >= 0).toBe(true);
384+
});
385+
});
386+
387+
describeWithFlags('memory webgl', WEBGL_ENVS, () => {
388+
it('unreliable is falsy/not present when all tensors are numeric', () => {
389+
tf.tensor(1);
390+
const mem = tf.memory();
391+
expect(mem.numTensors).toBe(1);
392+
expect(mem.numDataBuffers).toBe(1);
393+
expect(mem.numBytes).toBe(4);
394+
expect(mem.unreliable).toBeFalsy();
395+
});
396+
});
397+
398+
describeWithFlags('memory cpu', CPU_ENVS, () => {
399+
it('unreliable is true due to auto gc', () => {
400+
tf.tensor(1);
401+
const mem = tf.memory();
402+
expect(mem.numTensors).toBe(1);
403+
expect(mem.numDataBuffers).toBe(1);
404+
expect(mem.numBytes).toBe(4);
405+
expect(mem.unreliable).toBe(true);
406+
407+
const expectedReason =
408+
'The reported memory is an upper bound. Due to automatic garbage ' +
409+
'collection, the true allocated memory may be less.';
410+
expect(mem.reasons.indexOf(expectedReason) >= 0).toBe(true);
411+
});
412+
413+
it('unreliable is true due to both auto gc and string tensors', () => {
414+
tf.tensor(1);
415+
tf.tensor('a');
416+
417+
const mem = tf.memory();
418+
expect(mem.numTensors).toBe(2);
419+
expect(mem.numDataBuffers).toBe(2);
420+
expect(mem.numBytes).toBe(6);
421+
expect(mem.unreliable).toBe(true);
422+
423+
const expectedReasonGC =
424+
'The reported memory is an upper bound. Due to automatic garbage ' +
425+
'collection, the true allocated memory may be less.';
426+
expect(mem.reasons.indexOf(expectedReasonGC) >= 0).toBe(true);
427+
const expectedReasonString =
428+
'Memory usage by string tensors is approximate ' +
429+
'(2 bytes per character)';
430+
expect(mem.reasons.indexOf(expectedReasonString) >= 0).toBe(true);
431+
});
364432
});
365433

366434
describeWithFlags('profile', ALL_ENVS, () => {

src/environment.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,10 @@ export class Environment {
101101
* (undisposed) at this time, which is ≤ the number of tensors
102102
* (e.g. `a.reshape(newShape)` makes a new Tensor that shares the same
103103
* data buffer with `a`).
104-
* - `unreliable`: `Optional` `boolean`:
105-
* - On WebGL, not present (always reliable).
106-
* - On CPU, true. Due to automatic garbage collection, these numbers
107-
* represent undisposed tensors, i.e. not wrapped in `tidy()`, or
108-
* lacking a call to `tensor.dispose()`.
104+
* - `unreliable`: True if the memory usage is unreliable. See `reasons` when
105+
* `unrealible` is true.
106+
* - `reasons`: `string[]`, reasons why the memory is unreliable, present if
107+
* `unreliable` is true.
109108
*/
110109
/** @doc {heading: 'Performance', subheading: 'Memory'} */
111110
static memory(): MemoryInfo {

src/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
4848
export {SGDOptimizer} from './optimizers/sgd_optimizer';
4949
export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer, variable, Variable} from './tensor';
5050
export {NamedTensorMap} from './tensor_types';
51-
export {DataType, Rank, ShapeMap} from './types';
51+
export {DataType, DataTypeMap, DataValues, Rank, ShapeMap} from './types';
5252

5353
export * from './ops/ops';
5454
export {LSTMCellFunc} from './ops/lstm';

src/jasmine_util.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ export interface TestEnv {
9393

9494
export let TEST_ENVS: TestEnv[] = [
9595
{
96-
name: 'test-webgl1',
96+
name: 'webgl1',
9797
factory: () => new MathBackendWebGL(),
9898
features: {
9999
'WEBGL_VERSION': 1,
@@ -102,7 +102,7 @@ export let TEST_ENVS: TestEnv[] = [
102102
}
103103
},
104104
{
105-
name: 'test-webgl2',
105+
name: 'webgl2',
106106
factory: () => new MathBackendWebGL(),
107107
features: {
108108
'WEBGL_VERSION': 2,
@@ -111,7 +111,7 @@ export let TEST_ENVS: TestEnv[] = [
111111
}
112112
},
113113
{
114-
name: 'test-cpu',
114+
name: 'cpu',
115115
factory: () => new MathBackendCPU(),
116116
features: {'HAS_WEBGL': false}
117117
}

0 commit comments

Comments
 (0)