Skip to content

Commit 04011e5

Browse files
author
Nikhil Thorat
authored
Add TensorLike to chaining API typings and add unit tests. (tensorflow#1413)
FEATURE
1 parent 9b4ce5b commit 04011e5

22 files changed

+932
-412
lines changed

src/ops/arithmetic_test.ts

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,26 @@ describeWithFlags('div', ALL_ENVS, () => {
2929
expectArraysClose(r, [1, 1, 1, 1, 2.5, 6 / 5]);
3030
});
3131

32+
it('TensorLike', () => {
33+
const a = [0, 1, -2, -4, 4, -4];
34+
const b = [0.15, 0.2, 0.25, 0.5, 0.7, 1.2];
35+
const result = tf.div(a, b);
36+
37+
expect(result.shape).toEqual([6]);
38+
expectArraysClose(
39+
result, [0, 5.0, -8.0, -8.0, 5.714285850524902, -3.3333332538604736]);
40+
});
41+
42+
it('TensorLike chained', () => {
43+
const a = tf.tensor1d([0, 1, -2, -4, 4, -4]);
44+
const b = [0.15, 0.2, 0.25, 0.5, 0.7, 1.2];
45+
const result = a.div(b);
46+
47+
expect(result.shape).toEqual(a.shape);
48+
expectArraysClose(
49+
result, [0, 5.0, -8.0, -8.0, 5.714285850524902, -3.3333332538604736]);
50+
});
51+
3252
it('integer division implements floor divide', () => {
3353
const a = tf.tensor1d([-6, -6, -5, -4, -3, -3, 3, 3, 2], 'int32');
3454
const c = tf.tensor1d([-2, 2, 3, 2, -3, 3, 2, 3, 2], 'int32');
@@ -344,6 +364,26 @@ describeWithFlags('mul', ALL_ENVS, () => {
344364
expectArraysClose(result, expected);
345365
});
346366

367+
it('TensorLike', () => {
368+
const a = [[1, 2], [-3, -4]];
369+
const b = [[5, 3], [4, -7]];
370+
const expected = [5, 6, -12, 28];
371+
const result = tf.mul(a, b);
372+
373+
expect(result.shape).toEqual([2, 2]);
374+
expectArraysClose(result, expected);
375+
});
376+
377+
it('TensorLike chained', () => {
378+
const a = tf.tensor2d([1, 2, -3, -4], [2, 2]);
379+
const b = [[5, 3], [4, -7]];
380+
const expected = [5, 6, -12, 28];
381+
const result = a.mul(b);
382+
383+
expect(result.shape).toEqual([2, 2]);
384+
expectArraysClose(result, expected);
385+
});
386+
347387
it('broadcasting tensors', () => {
348388
const a = tf.tensor2d([1, 2, -3, -4], [2, 2]);
349389
const b = tf.scalar(2);
@@ -565,6 +605,28 @@ describeWithFlags('pow', ALL_ENVS, () => {
565605
expectArraysClose(result, expected, 0.01);
566606
});
567607

608+
it('TensorLike', () => {
609+
const a = [1, 2, 3];
610+
const exp = 2;
611+
612+
const result = tf.pow(a, exp);
613+
614+
expect(result.shape).toEqual([3]);
615+
expect(result.dtype).toBe('float32');
616+
expectArraysEqual(result, [1, 4, 9]);
617+
});
618+
619+
it('TensorLike chained', () => {
620+
const a = tf.tensor1d([1, 2, 3]);
621+
const exp = 2;
622+
623+
const result = a.pow(exp);
624+
625+
expect(result.shape).toEqual([3]);
626+
expect(result.dtype).toBe('float32');
627+
expectArraysEqual(result, [1, 4, 9]);
628+
});
629+
568630
it('int32^int32 returns int32', () => {
569631
const a = tf.tensor1d([1, 2, 3], 'int32');
570632
const exp = tf.scalar(2, 'int32');
@@ -879,6 +941,26 @@ describeWithFlags('add', ALL_ENVS, () => {
879941
expectArraysClose(result, expected);
880942
});
881943

944+
it('TensorLike', () => {
945+
const a = [2, 5, 1];
946+
const b = [4, 2, -1];
947+
948+
const result = tf.add(a, b);
949+
950+
const expected = [6, 7, 0];
951+
expectArraysClose(result, expected);
952+
});
953+
954+
it('TensorLike chained', () => {
955+
const a = tf.tensor1d([2, 5, 1]);
956+
const b = [4, 2, -1];
957+
958+
const result = a.add(b);
959+
960+
const expected = [6, 7, 0];
961+
expectArraysClose(result, expected);
962+
});
963+
882964
it('A + B propagates NaNs', () => {
883965
const a = tf.tensor1d([2, 5, NaN]);
884966
const b = tf.tensor1d([4, 2, -1]);
@@ -1185,6 +1267,26 @@ describeWithFlags('sub', ALL_ENVS, () => {
11851267
expectArraysClose(result, expected);
11861268
});
11871269

1270+
it('TensorLike', () => {
1271+
const a = [2, 5, 1];
1272+
const b = [4, 2, -1];
1273+
1274+
const result = tf.sub(a, b);
1275+
1276+
const expected = [-2, 3, 2];
1277+
expectArraysClose(result, expected);
1278+
});
1279+
1280+
it('TensorLike chained', () => {
1281+
const a = tf.tensor1d([2, 5, 1]);
1282+
const b = [4, 2, -1];
1283+
1284+
const result = a.sub(b);
1285+
1286+
const expected = [-2, 3, 2];
1287+
expectArraysClose(result, expected);
1288+
});
1289+
11881290
it('A - B propagates NaNs', () => {
11891291
const a = tf.tensor1d([2, 5, 1]);
11901292
const b = tf.tensor1d([4, NaN, -1]);

src/ops/array_ops.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,8 @@ function pad_<T extends Tensor>(
714714
* @param axis The axis to stack along. Defaults to 0 (the first dim).
715715
*/
716716
/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
717-
function stack_<T extends Tensor>(tensors: T[]|TensorLike[], axis = 0): Tensor {
717+
function stack_<T extends Tensor>(
718+
tensors: Array<T|TensorLike>, axis = 0): Tensor {
718719
const $tensors = convertToTensorArray(tensors, 'tensors', 'stack');
719720

720721
util.assert($tensors.length >= 1, 'Pass at least one tensor to tf.stack');

src/ops/binary_ops.ts

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,11 @@ function addN_<T extends Tensor>(tensors: Array<T|TensorLike>): T {
139139
* @param a The first Tensor to add element-wise.
140140
* @param b The second Tensor to add element-wise.
141141
*/
142-
function addStrict_<T extends Tensor>(a: T, b: T): T {
143-
util.assertShapesMatch(a.shape, b.shape, 'Error in addStrict: ');
144-
return a.add(b);
142+
function addStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
143+
const $a = convertToTensor(a, 'a', 'addStrict');
144+
const $b = convertToTensor(b, 'b', 'addStrict');
145+
util.assertShapesMatch($a.shape, $b.shape, 'Error in addStrict: ');
146+
return $a.add($b);
145147
}
146148

147149
/**
@@ -209,9 +211,11 @@ function sub_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
209211
* @param a The first Tensor to subtract element-wise.
210212
* @param b The second Tensor to subtract element-wise.
211213
*/
212-
function subStrict_<T extends Tensor>(a: T, b: T): T {
213-
util.assertShapesMatch(a.shape, b.shape, 'Error in subStrict: ');
214-
return a.sub(b);
214+
function subStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
215+
const $a = convertToTensor(a, 'a', 'subStrict');
216+
const $b = convertToTensor(b, 'b', 'subStrict');
217+
util.assertShapesMatch($a.shape, $b.shape, 'Error in subStrict: ');
218+
return $a.sub($b);
215219
}
216220

217221
/**
@@ -353,9 +357,11 @@ function mul_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
353357
* @param b The first tensor to multiply. Must have the same
354358
* dtype as `a`.
355359
*/
356-
function mulStrict_<T extends Tensor>(a: T, b: T): T {
357-
util.assertShapesMatch(a.shape, b.shape, 'Error in multiplyStrict: ');
358-
return a.mul(b) as T;
360+
function mulStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
361+
const $a = convertToTensor(a, 'a', 'mul');
362+
const $b = convertToTensor(b, 'b', 'mul');
363+
util.assertShapesMatch($a.shape, $b.shape, 'Error in multiplyStrict: ');
364+
return $a.mul($b) as T;
359365
}
360366

361367
/**
@@ -485,9 +491,11 @@ function floorDiv_<T extends Tensor>(
485491
* @param a The first tensor as the numerator for element-wise division.
486492
* @param b The second tensor as the denominator for element-wise division.
487493
*/
488-
function divStrict_<T extends Tensor>(a: T, b: T): T {
489-
util.assertShapesMatch(a.shape, b.shape, 'Error in divideStrict: ');
490-
return a.div(b) as T;
494+
function divStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
495+
const $a = convertToTensor(a, 'a', 'div');
496+
const $b = convertToTensor(b, 'b', 'div');
497+
util.assertShapesMatch($a.shape, $b.shape, 'Error in divideStrict: ');
498+
return $a.div($b) as T;
491499
}
492500

493501
/**
@@ -553,9 +561,11 @@ function mod_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
553561
* @param a The first tensor.
554562
* @param b The second tensor. Must have the same dtype as `a`.
555563
*/
556-
function modStrict_<T extends Tensor>(a: T, b: T): T {
557-
util.assertShapesMatch(a.shape, b.shape, 'Error in modStrict: ');
558-
return a.mod(b);
564+
function modStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
565+
const $a = convertToTensor(a, 'a', 'modStrict');
566+
const $b = convertToTensor(b, 'b', 'modStrict');
567+
util.assertShapesMatch($a.shape, $b.shape, 'Error in modStrict: ');
568+
return $a.mod($b);
559569
}
560570

561571
/**
@@ -613,9 +623,11 @@ function minimum_<T extends Tensor>(
613623
* @param a The first tensor.
614624
* @param b The second tensor. Must have the same dtype as `a`.
615625
*/
616-
function minimumStrict_<T extends Tensor>(a: T, b: T): T {
617-
util.assertShapesMatch(a.shape, b.shape, 'Error in minimumStrict: ');
618-
return a.minimum(b);
626+
function minimumStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
627+
const $a = convertToTensor(a, 'a', 'minimumStrict');
628+
const $b = convertToTensor(b, 'b', 'minimumStrict');
629+
util.assertShapesMatch($a.shape, $b.shape, 'Error in minimumStrict: ');
630+
return $a.minimum($b);
619631
}
620632

621633
/**
@@ -673,9 +685,11 @@ function maximum_<T extends Tensor>(
673685
* @param a The first tensor.
674686
* @param b The second tensor. Must have the same dtype as `a`.
675687
*/
676-
function maximumStrict_<T extends Tensor>(a: T, b: T): T {
677-
util.assertShapesMatch(a.shape, b.shape, 'Error in maximumStrict: ');
678-
return a.maximum(b);
688+
function maximumStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
689+
const $a = convertToTensor(a, 'a', 'maximumStrict');
690+
const $b = convertToTensor(b, 'b', 'maximumStrict');
691+
util.assertShapesMatch($a.shape, $b.shape, 'Error in maximumStrict: ');
692+
return $a.maximum($b);
679693
}
680694

681695
/**
@@ -731,10 +745,13 @@ function squaredDifference_<T extends Tensor>(
731745
* @param a The first tensor.
732746
* @param b The second tensor. Must have the same type as `a`.
733747
*/
734-
function squaredDifferenceStrict_<T extends Tensor>(a: T, b: T): T {
748+
function squaredDifferenceStrict_<T extends Tensor>(
749+
a: T|TensorLike, b: T|TensorLike): T {
750+
const $a = convertToTensor(a, 'a', 'squaredDifferenceStrict');
751+
const $b = convertToTensor(b, 'b', 'squaredDifferenceStrict');
735752
util.assertShapesMatch(
736-
a.shape, b.shape, 'Error in squaredDifferenceStrict: ');
737-
return a.squaredDifference(b);
753+
$a.shape, $b.shape, 'Error in squaredDifferenceStrict: ');
754+
return $a.squaredDifference($b);
738755
}
739756

740757
/**

0 commit comments

Comments
 (0)