Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit a516745

Browse files
Jakub Kaczmarzykdsmilkov
authored andcommitted
Add tf.conv3dTranspose (#1629)
this PR proposes adding a conv3dTranspose op (with tests included). This PR is driven by my group's desire to use a pre-trained 3D U-Net with Tensorflow JS. FEATURE
1 parent e70a33f commit a516745

File tree

4 files changed

+118
-0
lines changed

4 files changed

+118
-0
lines changed

.vscode/settings.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"[javascript]": {
1919
"editor.formatOnSave": true
2020
},
21+
"editor.defaultFormatter": "xaver.clang-format",
2122
"editor.rulers": [80],
2223
"clang-format.style": "Google",
2324
"files.insertFinalNewline": true,

src/ops/conv.ts

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,36 @@ function conv3dDerFilter_<T extends Tensor4D|Tensor5D>(
937937
backend => backend.conv3dDerFilter(x5D, dy5D, convInfo), {x5D, dy5D});
938938
}
939939

940+
/**
941+
* Computes the transposed 3D convolution of a volume, also known as a
942+
* deconvolution.
943+
*
944+
* @param x The input image, of rank 5 or rank 4, of shape
945+
* `[batch, depth, height, width, inDepth]`. If rank 4, batch of 1 is assumed.
946+
* @param filter The filter, rank 4, of shape
947+
* `[depth, filterHeight, filterWidth, outDepth, inDepth]`.
948+
* `inDepth` must match `inDepth` in `x`.
949+
* @param outputShape Output shape, of rank 5 or rank 4:
950+
* `[batch, depth, height, width, outDepth]`. If rank 3, batch of 1 is
951+
* assumed.
952+
* @param strides The strides of the original convolution:
953+
* `[strideDepth, strideHeight, strideWidth]`.
954+
* @param pad The type of padding algorithm used in the non-transpose version
955+
* of the op.
956+
*/
957+
/** @doc {heading: 'Operations', subheading: 'Convolution'} */
958+
function conv3dTranspose_<T extends Tensor4D|Tensor5D>(
959+
x: T|TensorLike, filter: Tensor5D|TensorLike,
960+
outputShape:
961+
[number, number, number, number,
962+
number]|[number, number, number, number],
963+
strides: [number, number, number]|number, pad: 'valid'|'same'): T {
964+
const $x = convertToTensor(x, 'x', 'conv3dTranspose');
965+
const $filter = convertToTensor(filter, 'filter', 'conv3dTranspose');
966+
967+
return conv3dDerInput_(outputShape, $x, $filter, strides, pad);
968+
}
969+
940970
export const conv1d = op({conv1d_});
941971
export const conv2d = op({conv2d_});
942972
export const conv3d = op({conv3d_});
@@ -945,3 +975,4 @@ export const conv2dDerInput = op({conv2dDerInput_});
945975
export const depthwiseConv2d = op({depthwiseConv2d_});
946976
export const separableConv2d = op({separableConv2d_});
947977
export const conv2dTranspose = op({conv2dTranspose_});
978+
export const conv3dTranspose = op({conv3dTranspose_});

src/ops/conv3d_transpose_test.ts

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/**
2+
* @license
3+
* Copyright 2017 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '../index';
19+
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
20+
import {expectArraysClose} from '../test_util';
21+
22+
describeWithFlags('conv3dTranspose', ALL_ENVS, () => {
23+
// Reference Python TensorFlow code
24+
// ```python
25+
// import numpy as np
26+
// import tensorflow as tf
27+
// tf.enable_eager_execution()
28+
// x = np.array([2], dtype = np.float32).reshape(1, 1, 1, 1, 1)
29+
// w = np.array([5, 4, 8, 7, 1, 2, 6, 3], dtype = np.float32).reshape(2, 2, 2,
30+
// 1, 1)
31+
// tf.nn.conv3d_transpose(x, w, output_shape=[1, 2, 2, 2, 1], padding='VALID')
32+
// ```
33+
it('input=2x2x2x1,d2=1,f=2,s=1,p=valid', async () => {
34+
const origInputDepth = 1;
35+
const origOutputDepth = 1;
36+
const inputShape: [number, number, number, number] =
37+
[1, 1, 1, origOutputDepth];
38+
const fSize = 2;
39+
const origPad = 'valid';
40+
const origStride = 1;
41+
42+
const x = tf.tensor4d([2], inputShape);
43+
const w = tf.tensor5d(
44+
[5, 4, 8, 7, 1, 2, 6, 3],
45+
[fSize, fSize, fSize, origInputDepth, origOutputDepth]);
46+
47+
const result = tf.conv3dTranspose(x, w, [2, 2, 2, 1], origStride, origPad);
48+
const expected = [10, 8, 16, 14, 2, 4, 12, 6];
49+
50+
expect(result.shape).toEqual([2, 2, 2, 1]);
51+
expectArraysClose(await result.data(), expected);
52+
});
53+
54+
// Reference Python TensorFlow code
55+
// ```python
56+
// import numpy as np
57+
// import tensorflow as tf
58+
// tf.enable_eager_execution()
59+
// x = np.array([2, 3], dtype = np.float32).reshape(2, 1, 1, 1, 1, 1)
60+
// w = np.array([5, 4, 8, 7, 1, 2, 6, 3], dtype = np.float32).reshape(2,
61+
// 2, 2, 1, 1)
62+
// tf.nn.conv3d_transpose(x, w, output_shape=[2, 2, 2, 2, 1], padding='VALID')
63+
// ```
64+
it('input=2x2x2x1,d2=1,f=2,s=1,p=valid, batch=2', async () => {
65+
const origInputDepth = 1;
66+
const origOutputDepth = 1;
67+
const inputShape: [number, number, number, number, number] =
68+
[2, 1, 1, 1, origOutputDepth];
69+
const fSize = 2;
70+
const origPad = 'valid';
71+
const origStride = 1;
72+
73+
const x = tf.tensor5d([2, 3], inputShape);
74+
const w = tf.tensor5d(
75+
[5, 4, 8, 7, 1, 2, 6, 3],
76+
[fSize, fSize, fSize, origInputDepth, origOutputDepth]);
77+
78+
const result =
79+
tf.conv3dTranspose(x, w, [2, 2, 2, 2, 1], origStride, origPad);
80+
const expected = [10, 8, 16, 14, 2, 4, 12, 6, 15, 12, 24, 21, 3, 6, 18, 9];
81+
82+
expect(result.shape).toEqual([2, 2, 2, 2, 1]);
83+
expectArraysClose(await result.data(), expected);
84+
});
85+
});

src/tests.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ import './ops/conv2d_separable_test';
5555
import './ops/conv2d_test';
5656
import './ops/conv2d_transpose_test';
5757
import './ops/conv3d_test';
58+
import './ops/conv3d_transpose_test';
5859
import './ops/conv_util_test';
5960
import './ops/diag_test';
6061
import './ops/dropout_test';

0 commit comments

Comments
 (0)