@@ -34,8 +34,8 @@ function generateCaseInputs(totalSizeTensor: number, totalSizeFilter: number) {
34
34
return { input : inp , filter : filt } ;
35
35
}
36
36
37
- describeWithFlags ( 'im2col ' , PACKED_ENVS , ( ) => {
38
- it ( 'should not leak memory' , ( ) => {
37
+ describeWithFlags ( 'conv to matmul ' , PACKED_ENVS , ( ) => {
38
+ it ( 'im2col should not leak memory' , ( ) => {
39
39
const inputDepth = 1 ;
40
40
const inputShape : [ number , number , number ] = [ 2 , 2 , inputDepth ] ;
41
41
const outputDepth = 1 ;
@@ -55,6 +55,26 @@ describeWithFlags('im2col', PACKED_ENVS, () => {
55
55
56
56
expect ( endNumBytes - startNumBytes ) . toEqual ( 4 ) ;
57
57
} ) ;
58
+
59
+ it ( 'pointwise conv should work when matmul is unpacked' , ( ) => {
60
+ const inputDepth =
61
+ 1001 ; // this number must be greater than MATMUL_SHARED_DIM_THRESHOLD
62
+ // for matmul to be unpacked
63
+ const inputShape : [ number , number , number ] = [ 3 , 3 , inputDepth ] ;
64
+ const outputDepth = 1 ;
65
+ const fSize = 1 ;
66
+ const pad = 'same' ;
67
+ const stride : [ number , number ] = [ 1 , 1 ] ;
68
+
69
+ let x = tf . randomNormal ( inputShape ) as tf . Tensor3D ;
70
+ x = x . add ( 1 ) ; // this packs x so we can test the case where we mistakenly
71
+ // want to avoid expensive reshape in pointwise conv2d even
72
+ // though matmul is unpacked
73
+ const w =
74
+ tf . randomNormal ( [ fSize , fSize , inputDepth , outputDepth ] ) as tf . Tensor4D ;
75
+
76
+ expect ( ( ) => tf . conv2d ( x , w , stride , pad ) ) . not . toThrow ( ) ;
77
+ } ) ;
58
78
} ) ;
59
79
60
80
describeWithFlags ( 'conv2d' , ALL_ENVS , ( ) => {
0 commit comments