Commit 8db48a6

Allow different dtypes in binary math ops (tensorflow#1432)
Allow users to provide different dtypes in binary arithmetic ops (add/sub/mul/div/...) and matmul, just like in NumPy. The dtype of the result is upcast, e.g. matMul(float32, int32) => float32. This will go out as patch release 0.14.1, fixing the breakage in 0.14.0 caused by tensorflow#1408: with the improved dtype inference, tensor(new Int32Array()) is now inferred as int32, whereas it used to be float32. Fixes tensorflow/tfjs#934, tensorflow/tfjs#966
1 parent: 2ff431b
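A quick, hedged sketch of the user-facing behavior this change enables; the dtype results are taken from the new tests further down, and the import path assumes the usual tfjs-core entry point:

```ts
import * as tf from '@tensorflow/tfjs-core';

// Mixed-dtype binary ops now upcast instead of throwing.
const sum = tf.add(tf.scalar(1, 'int32'), tf.scalar(1.5, 'float32'));
console.log(sum.dtype);  // 'float32'

// bool is upcast to int32 when paired with an int32 tensor.
const prod = tf.mul(tf.scalar(2, 'int32'), tf.scalar(true, 'bool'));
console.log(prod.dtype);  // 'int32'

// matMul upcasts as well: float32 x int32 => float32.
const a = tf.tensor2d([[1, 2], [3, 4]], [2, 2], 'float32');
const b = tf.tensor2d([[1, 0], [0, 1]], [2, 2], 'int32');
console.log(tf.matMul(a, b).dtype);  // 'float32'
```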

File tree

10 files changed, +328 / -128 lines

src/kernels/backend_cpu.ts

Lines changed: 4 additions & 5 deletions
```diff
@@ -407,8 +407,8 @@ export class MathBackendCPU implements KernelBackend {
         [b.strides[1], 1, b.strides[0]];

     const size = leftDim * rightDim;
-    const result = new Float32Array(batchDim * size);
-
+    const result = buffer([batchDim, leftDim, rightDim], a.dtype);
+    const resVals = result.values as TypedArray;
     const blockSize = this.blockSize;

     for (let b = 0; b < batchDim; b++) {
@@ -428,15 +428,14 @@ export class MathBackendCPU implements KernelBackend {
                   sum += aValues[b * aBatch + i * aOuterStep + k * aInnerStep] *
                       bValues[k * bInnerStep + j * bOuterStep + b * bBatch];
                 }
-                result[b * size + (i * rightDim + j)] += sum;
+                resVals[b * size + (i * rightDim + j)] += sum;
              }
            }
          }
        }
      }
    }
-
-    return ops.tensor3d(result, [batchDim, leftDim, rightDim]);
+    return result.toTensor() as Tensor3D;
   }

   multiply(a: Tensor, b: Tensor): Tensor {
```
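The key point of this change is that the CPU batchMatMul no longer hard-codes a Float32Array for its output: it allocates a buffer with the (upcast) result dtype and accumulates into that buffer's backing values. A minimal standalone sketch of the same pattern using the public tf.buffer API (the diff uses the internal buffer() helper, but the idea is identical):

```ts
import * as tf from '@tensorflow/tfjs-core';

// Allocate a dtype-aware accumulator instead of a hard-coded Float32Array.
const out = tf.buffer([2, 2], 'int32');   // TensorBuffer backed by an Int32Array
const vals = out.values as Int32Array;    // write directly into the backing store
for (let i = 0; i < vals.length; i++) {
  vals[i] += i * i;                       // accumulate, as the matmul inner loop does
}
const t = out.toTensor();                 // the result keeps the buffer's dtype
console.log(t.dtype);                     // 'int32'
```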

src/kernels/backend_webgl.ts

Lines changed: 21 additions & 12 deletions
```diff
@@ -35,7 +35,6 @@ import {DataId, Scalar, setTensorTracker, Tensor, Tensor1D, Tensor2D, Tensor3D,
 import {DataType, DataTypeMap, DataValues, NumericDataType, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../types';
 import * as util from '../util';
 import {getTypedArrayFromDType, sizeFromShape} from '../util';
-
 import {DataMover, DataStorage, KernelBackend} from './backend';
 import * as backend_util from './backend_util';
 import {mergeRealAndImagArrays} from './complex_util';
@@ -682,6 +681,8 @@ export class MathBackendWebGL implements KernelBackend {
       return this.multiply(a3D, b3D).sum(axis, true /* keepDims */);
     }

+    const dtype = upcastType(a.dtype, b.dtype);
+
     // TODO(https://github.com/tensorflow/tfjs/issues/693): Support 3D tensors
     if (batch === 1) {
       const aSqueezed = a.as2D(a.shape[1], a.shape[2]);
@@ -690,13 +691,17 @@ export class MathBackendWebGL implements KernelBackend {
       const program = new MatMulPackedProgram(
           aSqueezed.shape, bSqueezed.shape, [outerShapeA, outerShapeB],
           transposeA, transposeB);
+      const output =
+          this.makePackedTensor(program.outputShape, dtype) as Tensor2D;
       const result =
-          this.compileAndRun<Tensor2D>(program, [aSqueezed, bSqueezed]);
-
+          this.compileAndRun<Tensor2D>(program, [aSqueezed, bSqueezed], output);
       return result.reshape([1, result.shape[0], result.shape[1]]);
     } else {
-      return this.compileAndRun(
-          new MatMulProgram(a.shape, b.shape, transposeA, transposeB), [a, b]);
+      const program =
+          new MatMulProgram(a.shape, b.shape, transposeA, transposeB);
+      const output =
+          this.makeOutputArray(program.outputShape, dtype) as Tensor3D;
+      return this.compileAndRun(program, [a, b], output);
     }
   }
@@ -1517,7 +1522,8 @@ export class MathBackendWebGL implements KernelBackend {
         convInfo.outChannels / convInfo.inChannels === 1) {
       program = new DepthwiseConvPacked2DProgram(convInfo);
       return this.compileAndRun(
-          program, [x, filter], this.makePackedTensor(convInfo.outShape));
+          program, [x, filter],
+          this.makePackedTensor(convInfo.outShape, x.dtype));
     }

     program = new DepthwiseConv2DProgram(convInfo);
@@ -1769,16 +1775,17 @@ export class MathBackendWebGL implements KernelBackend {
     return Tensor.make(shape, {}, dtype) as T;
   }

-  private makePackedTensor<T extends Tensor>(shape: number[]): T {
-    const packedTensor = Tensor.make(shape, {});
+  private makePackedTensor<T extends Tensor>(shape: number[], dtype: DataType):
+      T {
+    const packedTensor = Tensor.make(shape, {}, dtype);
     this.texData.get(packedTensor.dataId).isPacked = true;
     return packedTensor as T;
   }

   private unpackTensor<T extends Tensor>(input: T): T {
     const program = new UnpackProgram(input.shape);
     return this.compileAndRun(
-        program, [input], Tensor.make(program.outputShape, {}));
+        program, [input], Tensor.make(program.outputShape, {}, input.dtype));
   }

   private getBatchDim(shape: number[], dimsToSkip = 2): number {
@@ -1815,7 +1822,8 @@ export class MathBackendWebGL implements KernelBackend {
       pageToCpu = true): K {
     if (output == null) {
       if (program.usesPackedTextures) {
-        output = this.makePackedTensor(program.outputShape) as {} as K;
+        output = this.makePackedTensor(program.outputShape, inputs[0].dtype) as
+            {} as K;
       } else {
         output = this.makeOutputArray(program.outputShape, inputs[0].dtype) as
             {} as K;
@@ -1872,11 +1880,12 @@ export class MathBackendWebGL implements KernelBackend {
         preProcessProgram = new UnpackProgram(input.shape);
         processedInput = this.compileAndRun(
             preProcessProgram, [input],
-            Tensor.make(preProcessProgram.outputShape, {}));
+            Tensor.make(preProcessProgram.outputShape, {}, input.dtype));
       } else {
         preProcessProgram = new PackProgram(input.shape);
         processedInput = this.compileAndRun(
-            preProcessProgram, [input], this.makePackedTensor(input.shape));
+            preProcessProgram, [input],
+            this.makePackedTensor(input.shape, input.dtype));
       }

       texData = this.texData.get(processedInput.dataId);
```
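On the WebGL side the pattern is the same: compute the result dtype up front with upcastType and thread it through to the output tensor (makePackedTensor / makeOutputArray) instead of letting everything default to float32. As an illustration only (this is not the library's actual upcastType implementation), the lattice the new tests exercise is bool < int32 < float32 < complex64:

```ts
// Toy re-statement of the upcast rule used above; the dtype names are real,
// but this helper is hypothetical and exists only to show the ordering.
type DType = 'bool'|'int32'|'float32'|'complex64';
const order: Record<DType, number> = {bool: 0, int32: 1, float32: 2, complex64: 3};

function upcastSketch(a: DType, b: DType): DType {
  return order[a] >= order[b] ? a : b;
}

console.log(upcastSketch('bool', 'int32'));        // 'int32'
console.log(upcastSketch('int32', 'float32'));     // 'float32'
console.log(upcastSketch('float32', 'complex64')); // 'complex64'
```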

src/ops/arithmetic_test.ts

Lines changed: 58 additions & 26 deletions
```diff
@@ -102,12 +102,14 @@ describeWithFlags('div', ALL_ENVS, () => {
     expectArraysClose(result, expected);
   });

-  it('throws when passed tensors of different types', () => {
-    const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
-    const b = tf.tensor2d([1, 2, 3, 4, 2, 5], [2, 3], 'int32');
+  it('upcasts when dtypes dont match', () => {
+    let res = tf.div(tf.scalar(6, 'int32'), tf.scalar(3, 'float32'));
+    expect(res.dtype).toBe('float32');
+    expectArraysClose(res, [2]);

-    expect(() => tf.div(a, b)).toThrowError();
-    expect(() => tf.div(b, a)).toThrowError();
+    res = tf.div(tf.scalar(6, 'int32'), tf.scalar(true, 'bool'));
+    expect(res.dtype).toBe('int32');
+    expectArraysClose(res, [6]);
   });

   it('throws when passed tensors of different shapes', () => {
@@ -580,11 +582,18 @@ describeWithFlags('mul', ALL_ENVS, () => {
     expect(() => tf.mul(tf.scalar(1), {} as tf.Tensor))
         .toThrowError(/Argument 'b' passed to 'mul' must be a Tensor/);
   });
-  it('throws when dtypes dont match', () => {
-    expect(() => tf.mul(tf.scalar(1, 'int32'), tf.scalar(1)))
-        .toThrowError(
-            // tslint:disable-next-line:max-line-length
-            /The dtypes of the first\(int32\) and second\(float32\) input must match/);
+  it('upcasts when dtypes dont match', () => {
+    let res = tf.mul(tf.scalar(2, 'int32'), tf.scalar(3, 'float32'));
+    expect(res.dtype).toBe('float32');
+    expectArraysClose(res, [6]);
+
+    res = tf.mul(tf.scalar(2, 'int32'), tf.scalar(true, 'bool'));
+    expect(res.dtype).toBe('int32');
+    expectArraysClose(res, [2]);
+
+    res = tf.mul(tf.scalar(2, 'int32'), tf.scalar(false, 'bool'));
+    expect(res.dtype).toBe('int32');
+    expectArraysClose(res, [0]);
   });

   it('accepts a tensor-like object', () => {
@@ -1149,11 +1158,26 @@ describeWithFlags('add', ALL_ENVS, () => {
         .toThrowError(/Argument 'b' passed to 'add' must be a Tensor/);
   });

-  it('throws when dtypes dont match', () => {
-    expect(() => tf.add(tf.scalar(1, 'int32'), tf.scalar(1)))
-        .toThrowError(
-            // tslint:disable-next-line:max-line-length
-            /The dtypes of the first\(int32\) and second\(float32\) input must match/);
+  it('upcasts when dtypes dont match', () => {
+    let res = tf.add(tf.scalar(1, 'int32'), tf.scalar(1, 'float32'));
+    expect(res.dtype).toBe('float32');
+    expectArraysClose(res, [2]);
+
+    res = tf.add(tf.scalar(1, 'int32'), tf.scalar(true, 'bool'));
+    expect(res.dtype).toBe('int32');
+    expectArraysClose(res, [2]);
+
+    res = tf.add(tf.scalar(1, 'int32'), tf.scalar(false, 'bool'));
+    expect(res.dtype).toBe('int32');
+    expectArraysClose(res, [1]);
+
+    res = tf.add(tf.complex(4, 7), tf.scalar(1, 'float32'));
+    expect(res.dtype).toBe('complex64');
+    expectArraysClose(res, [5, 7]);
+
+    res = tf.add(tf.complex(4, 7), tf.scalar(1, 'int32'));
+    expect(res.dtype).toBe('complex64');
+    expectArraysClose(res, [5, 7]);
   });

   it('accepts a tensor-like object', () => {
@@ -1495,18 +1519,26 @@ describeWithFlags('sub', ALL_ENVS, () => {
     expect(() => tf.sub(tf.scalar(1), {} as tf.Tensor))
         .toThrowError(/Argument 'b' passed to 'sub' must be a Tensor/);
   });
-  it('throws when dtypes dont match', () => {
-    expect(() => tf.sub(tf.scalar(1, 'int32'), tf.scalar(1)))
-        .toThrowError(
-            // tslint:disable-next-line:max-line-length
-            /The dtypes of the first\(int32\) and second\(float32\) input must match/);
-  });
+  it('upcasts when dtypes dont match', () => {
+    let res = tf.sub(tf.scalar(1, 'int32'), tf.scalar(1, 'float32'));
+    expect(res.dtype).toBe('float32');
+    expectArraysClose(res, [0]);

-  it('throws when dtypes dont match', () => {
-    expect(() => tf.sub(tf.scalar(1, 'float32'), tf.complex(1, 2)))
-        .toThrowError(
-            // tslint:disable-next-line:max-line-length
-            /The dtypes of the first\(float32\) and second\(complex64\) input must match/);
+    res = tf.sub(tf.scalar(1, 'int32'), tf.scalar(true, 'bool'));
+    expect(res.dtype).toBe('int32');
+    expectArraysClose(res, [0]);
+
+    res = tf.sub(tf.scalar(1, 'int32'), tf.scalar(false, 'bool'));
+    expect(res.dtype).toBe('int32');
+    expectArraysClose(res, [1]);
+
+    res = tf.sub(tf.complex(4, 7), tf.scalar(1, 'float32'));
+    expect(res.dtype).toBe('complex64');
+    expectArraysClose(res, [3, 7]);
+
+    res = tf.sub(tf.complex(4, 7), tf.scalar(1, 'int32'));
+    expect(res.dtype).toBe('complex64');
+    expectArraysClose(res, [3, 7]);
   });

   it('accepts a tensor-like object', () => {
```

src/ops/binary_ops.ts

Lines changed: 29 additions & 31 deletions
```diff
@@ -19,7 +19,7 @@ import {ENV} from '../environment';
 import {KernelBackend} from '../kernels/backend';
 import {Tensor} from '../tensor';
 import {NamedTensorMap} from '../tensor_types';
-import {assertTypesMatch} from '../tensor_util';
+import {makeTypesMatch} from '../tensor_util';
 import {convertToTensor} from '../tensor_util_env';
 import {TensorLike, upcastType} from '../types';
 import * as util from '../util';
@@ -53,9 +53,9 @@ import {neg} from './unary_ops';
  */
 /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
 function add_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
-  const $a = convertToTensor(a, 'a', 'add');
-  const $b = convertToTensor(b, 'b', 'add');
-  assertTypesMatch($a, $b);
+  let $a = convertToTensor(a, 'a', 'add');
+  let $b = convertToTensor(b, 'b', 'add');
+  [$a, $b] = makeTypesMatch($a, $b);

   const outShape =
       broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
@@ -172,9 +172,9 @@ function addStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
  */
 /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
 function sub_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
-  const $a = convertToTensor(a, 'a', 'sub');
-  const $b = convertToTensor(b, 'b', 'sub');
-  assertTypesMatch($a, $b);
+  let $a = convertToTensor(a, 'a', 'sub');
+  let $b = convertToTensor(b, 'b', 'sub');
+  [$a, $b] = makeTypesMatch($a, $b);

   const outShape =
       broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
@@ -318,9 +318,9 @@ function powStrict_<T extends Tensor>(base: T, exp: Tensor): T {
  */
 /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
 function mul_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
-  const $a = convertToTensor(a, 'a', 'mul');
-  const $b = convertToTensor(b, 'b', 'mul');
-  assertTypesMatch($a, $b);
+  let $a = convertToTensor(a, 'a', 'mul');
+  let $b = convertToTensor(b, 'b', 'mul');
+  [$a, $b] = makeTypesMatch($a, $b);

   const outShape =
       broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
@@ -391,9 +391,9 @@ function mulStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
  */
 /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
 function div_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
-  const $a = convertToTensor(a, 'a', 'div');
-  const $b = convertToTensor(b, 'b', 'div');
-  assertTypesMatch($a, $b);
+  let $a = convertToTensor(a, 'a', 'div');
+  let $b = convertToTensor(b, 'b', 'div');
+  [$a, $b] = makeTypesMatch($a, $b);

   let forwardFunc: (backend: KernelBackend) => Tensor;
   if ($a.dtype === 'int32' && $b.dtype === 'int32') {
@@ -454,9 +454,9 @@ function div_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
 /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
 function floorDiv_<T extends Tensor>(
     a: Tensor|TensorLike, b: Tensor|TensorLike): T {
-  const $a = convertToTensor(a, 'a', 'floorDiv');
-  const $b = convertToTensor(b, 'b', 'floorDiv');
-  assertTypesMatch($a, $b);
+  let $a = convertToTensor(a, 'a', 'floorDiv');
+  let $b = convertToTensor(b, 'b', 'floorDiv');
+  [$a, $b] = makeTypesMatch($a, $b);

   const forwardFunc = (backend: KernelBackend) => backend.floorDiv($a, $b);
   const outShape =
@@ -526,9 +526,9 @@ function divStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
  */
 /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
 function mod_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
-  const $a = convertToTensor(a, 'a', 'mod');
-  const $b = convertToTensor(b, 'b', 'mod');
-  assertTypesMatch($a, $b);
+  let $a = convertToTensor(a, 'a', 'mod');
+  let $b = convertToTensor(b, 'b', 'mod');
+  [$a, $b] = makeTypesMatch($a, $b);

   const outShape =
       broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
@@ -598,14 +598,13 @@ function minimum_<T extends Tensor>(
     a: Tensor|TensorLike, b: Tensor|TensorLike): T {
   let $a = convertToTensor(a, 'a', 'minimum');
   let $b = convertToTensor(b, 'b', 'minimum');
-  assertTypesMatch($a, $b);
+  [$a, $b] = makeTypesMatch($a, $b);

   if ($a.dtype === 'bool') {
     $a = $a.toInt();
-  }
-  if ($b.dtype === 'bool') {
     $b = $b.toInt();
   }
+
   broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
   const der = (dy: Tensor) => {
     const derA = () => dy.mul($a.lessEqual($b).toFloat());
@@ -660,14 +659,13 @@ function maximum_<T extends Tensor>(
     a: Tensor|TensorLike, b: Tensor|TensorLike): T {
   let $a = convertToTensor(a, 'a', 'maximum');
   let $b = convertToTensor(b, 'b', 'maximum');
-  assertTypesMatch($a, $b);
+  [$a, $b] = makeTypesMatch($a, $b);

   if ($a.dtype === 'bool') {
     $a = $a.toInt();
-  }
-  if ($b.dtype === 'bool') {
     $b = $b.toInt();
   }
+
   broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
   const der = (dy: Tensor) => {
     const derA = () => dy.mul($a.greaterEqual($b).toFloat());
@@ -721,9 +719,9 @@ function maximumStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
 /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
 function squaredDifference_<T extends Tensor>(
     a: Tensor|TensorLike, b: Tensor|TensorLike): T {
-  const $a = convertToTensor(a, 'a', 'squaredDifference');
-  const $b = convertToTensor(b, 'b', 'squaredDifference');
-  assertTypesMatch($a, $b);
+  let $a = convertToTensor(a, 'a', 'squaredDifference');
+  let $b = convertToTensor(b, 'b', 'squaredDifference');
+  [$a, $b] = makeTypesMatch($a, $b);

   broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
   const der = (dy: Tensor) => {
@@ -772,9 +770,9 @@ function squaredDifferenceStrict_<T extends Tensor>(
 /** @doc {heading: 'Operations', subheading: 'Basic math'} */
 function atan2_<T extends Tensor>(
     a: Tensor|TensorLike, b: Tensor|TensorLike): T {
-  const $a = convertToTensor(a, 'a', 'atan2');
-  const $b = convertToTensor(b, 'b', 'atan2');
-  assertTypesMatch($a, $b);
+  let $a = convertToTensor(a, 'a', 'atan2');
+  let $b = convertToTensor(b, 'b', 'atan2');
+  [$a, $b] = makeTypesMatch($a, $b);

   const outShape =
       broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
```
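The one-line change repeated across all of these ops, swapping assertTypesMatch for makeTypesMatch, is what turns the old dtype-mismatch errors into implicit upcasts. The helper's body is not part of this diff; a plausible sketch, assuming it lives next to assertTypesMatch in src/tensor_util.ts and leans on upcastType plus Tensor.cast:

```ts
import {Tensor} from './tensor';
import {upcastType} from './types';

// Hedged sketch of makeTypesMatch (not the actual implementation from this
// commit): cast both operands to their common upcast dtype so every binary
// kernel sees matching input dtypes.
export function makeTypesMatch<T extends Tensor>(a: T, b: T): [T, T] {
  if (a.dtype === b.dtype) {
    return [a, b];
  }
  const dtype = upcastType(a.dtype, b.dtype);
  return [a.cast(dtype), b.cast(dtype)] as [T, T];
}
```

Returning a pair rather than asserting keeps the op bodies almost unchanged: `[$a, $b] = makeTypesMatch($a, $b)` slots into the same place the assertion used to occupy, which is why the only other edit in each op is `const` becoming `let`.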
