@@ -19,7 +19,7 @@ import {ENV} from '../environment';
19
19
import { KernelBackend } from '../kernels/backend' ;
20
20
import { Tensor } from '../tensor' ;
21
21
import { NamedTensorMap } from '../tensor_types' ;
22
- import { assertTypesMatch } from '../tensor_util' ;
22
+ import { makeTypesMatch } from '../tensor_util' ;
23
23
import { convertToTensor } from '../tensor_util_env' ;
24
24
import { TensorLike , upcastType } from '../types' ;
25
25
import * as util from '../util' ;
@@ -53,9 +53,9 @@ import {neg} from './unary_ops';
53
53
*/
54
54
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
55
55
function add_ < T extends Tensor > ( a : Tensor | TensorLike , b : Tensor | TensorLike ) : T {
56
- const $a = convertToTensor ( a , 'a' , 'add' ) ;
57
- const $b = convertToTensor ( b , 'b' , 'add' ) ;
58
- assertTypesMatch ( $a , $b ) ;
56
+ let $a = convertToTensor ( a , 'a' , 'add' ) ;
57
+ let $b = convertToTensor ( b , 'b' , 'add' ) ;
58
+ [ $a , $b ] = makeTypesMatch ( $a , $b ) ;
59
59
60
60
const outShape =
61
61
broadcast_util . assertAndGetBroadcastShape ( $a . shape , $b . shape ) ;
@@ -172,9 +172,9 @@ function addStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
172
172
*/
173
173
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
174
174
function sub_ < T extends Tensor > ( a : Tensor | TensorLike , b : Tensor | TensorLike ) : T {
175
- const $a = convertToTensor ( a , 'a' , 'sub' ) ;
176
- const $b = convertToTensor ( b , 'b' , 'sub' ) ;
177
- assertTypesMatch ( $a , $b ) ;
175
+ let $a = convertToTensor ( a , 'a' , 'sub' ) ;
176
+ let $b = convertToTensor ( b , 'b' , 'sub' ) ;
177
+ [ $a , $b ] = makeTypesMatch ( $a , $b ) ;
178
178
179
179
const outShape =
180
180
broadcast_util . assertAndGetBroadcastShape ( $a . shape , $b . shape ) ;
@@ -318,9 +318,9 @@ function powStrict_<T extends Tensor>(base: T, exp: Tensor): T {
318
318
*/
319
319
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
320
320
function mul_ < T extends Tensor > ( a : Tensor | TensorLike , b : Tensor | TensorLike ) : T {
321
- const $a = convertToTensor ( a , 'a' , 'mul' ) ;
322
- const $b = convertToTensor ( b , 'b' , 'mul' ) ;
323
- assertTypesMatch ( $a , $b ) ;
321
+ let $a = convertToTensor ( a , 'a' , 'mul' ) ;
322
+ let $b = convertToTensor ( b , 'b' , 'mul' ) ;
323
+ [ $a , $b ] = makeTypesMatch ( $a , $b ) ;
324
324
325
325
const outShape =
326
326
broadcast_util . assertAndGetBroadcastShape ( $a . shape , $b . shape ) ;
@@ -391,9 +391,9 @@ function mulStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
391
391
*/
392
392
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
393
393
function div_ < T extends Tensor > ( a : Tensor | TensorLike , b : Tensor | TensorLike ) : T {
394
- const $a = convertToTensor ( a , 'a' , 'div' ) ;
395
- const $b = convertToTensor ( b , 'b' , 'div' ) ;
396
- assertTypesMatch ( $a , $b ) ;
394
+ let $a = convertToTensor ( a , 'a' , 'div' ) ;
395
+ let $b = convertToTensor ( b , 'b' , 'div' ) ;
396
+ [ $a , $b ] = makeTypesMatch ( $a , $b ) ;
397
397
398
398
let forwardFunc : ( backend : KernelBackend ) => Tensor ;
399
399
if ( $a . dtype === 'int32' && $b . dtype === 'int32' ) {
@@ -454,9 +454,9 @@ function div_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
454
454
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
455
455
function floorDiv_ < T extends Tensor > (
456
456
a : Tensor | TensorLike , b : Tensor | TensorLike ) : T {
457
- const $a = convertToTensor ( a , 'a' , 'floorDiv' ) ;
458
- const $b = convertToTensor ( b , 'b' , 'floorDiv' ) ;
459
- assertTypesMatch ( $a , $b ) ;
457
+ let $a = convertToTensor ( a , 'a' , 'floorDiv' ) ;
458
+ let $b = convertToTensor ( b , 'b' , 'floorDiv' ) ;
459
+ [ $a , $b ] = makeTypesMatch ( $a , $b ) ;
460
460
461
461
const forwardFunc = ( backend : KernelBackend ) => backend . floorDiv ( $a , $b ) ;
462
462
const outShape =
@@ -526,9 +526,9 @@ function divStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
526
526
*/
527
527
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
528
528
function mod_ < T extends Tensor > ( a : Tensor | TensorLike , b : Tensor | TensorLike ) : T {
529
- const $a = convertToTensor ( a , 'a' , 'mod' ) ;
530
- const $b = convertToTensor ( b , 'b' , 'mod' ) ;
531
- assertTypesMatch ( $a , $b ) ;
529
+ let $a = convertToTensor ( a , 'a' , 'mod' ) ;
530
+ let $b = convertToTensor ( b , 'b' , 'mod' ) ;
531
+ [ $a , $b ] = makeTypesMatch ( $a , $b ) ;
532
532
533
533
const outShape =
534
534
broadcast_util . assertAndGetBroadcastShape ( $a . shape , $b . shape ) ;
@@ -598,14 +598,13 @@ function minimum_<T extends Tensor>(
598
598
a : Tensor | TensorLike , b : Tensor | TensorLike ) : T {
599
599
let $a = convertToTensor ( a , 'a' , 'minimum' ) ;
600
600
let $b = convertToTensor ( b , 'b' , 'minimum' ) ;
601
- assertTypesMatch ( $a , $b ) ;
601
+ [ $a , $b ] = makeTypesMatch ( $a , $b ) ;
602
602
603
603
if ( $a . dtype === 'bool' ) {
604
604
$a = $a . toInt ( ) ;
605
- }
606
- if ( $b . dtype === 'bool' ) {
607
605
$b = $b . toInt ( ) ;
608
606
}
607
+
609
608
broadcast_util . assertAndGetBroadcastShape ( $a . shape , $b . shape ) ;
610
609
const der = ( dy : Tensor ) => {
611
610
const derA = ( ) => dy . mul ( $a . lessEqual ( $b ) . toFloat ( ) ) ;
@@ -660,14 +659,13 @@ function maximum_<T extends Tensor>(
660
659
a : Tensor | TensorLike , b : Tensor | TensorLike ) : T {
661
660
let $a = convertToTensor ( a , 'a' , 'maximum' ) ;
662
661
let $b = convertToTensor ( b , 'b' , 'maximum' ) ;
663
- assertTypesMatch ( $a , $b ) ;
662
+ [ $a , $b ] = makeTypesMatch ( $a , $b ) ;
664
663
665
664
if ( $a . dtype === 'bool' ) {
666
665
$a = $a . toInt ( ) ;
667
- }
668
- if ( $b . dtype === 'bool' ) {
669
666
$b = $b . toInt ( ) ;
670
667
}
668
+
671
669
broadcast_util . assertAndGetBroadcastShape ( $a . shape , $b . shape ) ;
672
670
const der = ( dy : Tensor ) => {
673
671
const derA = ( ) => dy . mul ( $a . greaterEqual ( $b ) . toFloat ( ) ) ;
@@ -721,9 +719,9 @@ function maximumStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
721
719
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
722
720
function squaredDifference_ < T extends Tensor > (
723
721
a : Tensor | TensorLike , b : Tensor | TensorLike ) : T {
724
- const $a = convertToTensor ( a , 'a' , 'squaredDifference' ) ;
725
- const $b = convertToTensor ( b , 'b' , 'squaredDifference' ) ;
726
- assertTypesMatch ( $a , $b ) ;
722
+ let $a = convertToTensor ( a , 'a' , 'squaredDifference' ) ;
723
+ let $b = convertToTensor ( b , 'b' , 'squaredDifference' ) ;
724
+ [ $a , $b ] = makeTypesMatch ( $a , $b ) ;
727
725
728
726
broadcast_util . assertAndGetBroadcastShape ( $a . shape , $b . shape ) ;
729
727
const der = ( dy : Tensor ) => {
@@ -772,9 +770,9 @@ function squaredDifferenceStrict_<T extends Tensor>(
772
770
/** @doc {heading: 'Operations', subheading: 'Basic math'} */
773
771
function atan2_ < T extends Tensor > (
774
772
a : Tensor | TensorLike , b : Tensor | TensorLike ) : T {
775
- const $a = convertToTensor ( a , 'a' , 'atan2' ) ;
776
- const $b = convertToTensor ( b , 'b' , 'atan2' ) ;
777
- assertTypesMatch ( $a , $b ) ;
773
+ let $a = convertToTensor ( a , 'a' , 'atan2' ) ;
774
+ let $b = convertToTensor ( b , 'b' , 'atan2' ) ;
775
+ [ $a , $b ] = makeTypesMatch ( $a , $b ) ;
778
776
779
777
const outShape =
780
778
broadcast_util . assertAndGetBroadcastShape ( $a . shape , $b . shape ) ;
0 commit comments