Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit a516745

Browse files
Jakub Kaczmarzykdsmilkov
Jakub Kaczmarzyk
authored andcommitted
Add tf.conv3dTranspose (#1629)
this PR proposes adding a conv3dTranspose op (with tests included). This PR is driven by my group's desire to use a pre-trained 3D U-Net with Tensorflow JS. FEATURE
1 parent e70a33f commit a516745

File tree

4 files changed

+118
-0
lines changed

4 files changed

+118
-0
lines changed

‎.vscode/settings.json‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"[javascript]": {
1919
"editor.formatOnSave": true
2020
},
21+
"editor.defaultFormatter": "xaver.clang-format",
2122
"editor.rulers": [80],
2223
"clang-format.style": "Google",
2324
"files.insertFinalNewline": true,

‎src/ops/conv.ts‎

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,36 @@ function conv3dDerFilter_<T extends Tensor4D|Tensor5D>(
937937
backend => backend.conv3dDerFilter(x5D, dy5D, convInfo), {x5D, dy5D});
938938
}
939939

940+
/**
941+
* Computes the transposed 3D convolution of a volume, also known as a
942+
* deconvolution.
943+
*
944+
* @param x The input image, of rank 5 or rank 4, of shape
945+
* `[batch, depth, height, width, inDepth]`. If rank 4, batch of 1 is assumed.
946+
* @param filter The filter, rank 4, of shape
947+
* `[depth, filterHeight, filterWidth, outDepth, inDepth]`.
948+
* `inDepth` must match `inDepth` in `x`.
949+
* @param outputShape Output shape, of rank 5 or rank 4:
950+
* `[batch, depth, height, width, outDepth]`. If rank 3, batch of 1 is
951+
* assumed.
952+
* @param strides The strides of the original convolution:
953+
* `[strideDepth, strideHeight, strideWidth]`.
954+
* @param pad The type of padding algorithm used in the non-transpose version
955+
* of the op.
956+
*/
957+
/** @doc {heading: 'Operations', subheading: 'Convolution'} */
958+
function conv3dTranspose_<T extends Tensor4D|Tensor5D>(
959+
x: T|TensorLike, filter: Tensor5D|TensorLike,
960+
outputShape:
961+
[number, number, number, number,
962+
number]|[number, number, number, number],
963+
strides: [number, number, number]|number, pad: 'valid'|'same'): T {
964+
const $x = convertToTensor(x, 'x', 'conv3dTranspose');
965+
const $filter = convertToTensor(filter, 'filter', 'conv3dTranspose');
966+
967+
return conv3dDerInput_(outputShape, $x, $filter, strides, pad);
968+
}
969+
940970
export const conv1d = op({conv1d_});
941971
export const conv2d = op({conv2d_});
942972
export const conv3d = op({conv3d_});
@@ -945,3 +975,4 @@ export const conv2dDerInput = op({conv2dDerInput_});
945975
export const depthwiseConv2d = op({depthwiseConv2d_});
946976
export const separableConv2d = op({separableConv2d_});
947977
export const conv2dTranspose = op({conv2dTranspose_});
978+
export const conv3dTranspose = op({conv3dTranspose_});

‎src/ops/conv3d_transpose_test.ts‎

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/**
2+
* @license
3+
* Copyright 2017 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '../index';
19+
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
20+
import {expectArraysClose} from '../test_util';
21+
22+
describeWithFlags('conv3dTranspose', ALL_ENVS, () => {
23+
// Reference Python TensorFlow code
24+
// ```python
25+
// import numpy as np
26+
// import tensorflow as tf
27+
// tf.enable_eager_execution()
28+
// x = np.array([2], dtype = np.float32).reshape(1, 1, 1, 1, 1)
29+
// w = np.array([5, 4, 8, 7, 1, 2, 6, 3], dtype = np.float32).reshape(2, 2, 2,
30+
// 1, 1)
31+
// tf.nn.conv3d_transpose(x, w, output_shape=[1, 2, 2, 2, 1], padding='VALID')
32+
// ```
33+
it('input=2x2x2x1,d2=1,f=2,s=1,p=valid', async () => {
34+
const origInputDepth = 1;
35+
const origOutputDepth = 1;
36+
const inputShape: [number, number, number, number] =
37+
[1, 1, 1, origOutputDepth];
38+
const fSize = 2;
39+
const origPad = 'valid';
40+
const origStride = 1;
41+
42+
const x = tf.tensor4d([2], inputShape);
43+
const w = tf.tensor5d(
44+
[5, 4, 8, 7, 1, 2, 6, 3],
45+
[fSize, fSize, fSize, origInputDepth, origOutputDepth]);
46+
47+
const result = tf.conv3dTranspose(x, w, [2, 2, 2, 1], origStride, origPad);
48+
const expected = [10, 8, 16, 14, 2, 4, 12, 6];
49+
50+
expect(result.shape).toEqual([2, 2, 2, 1]);
51+
expectArraysClose(await result.data(), expected);
52+
});
53+
54+
// Reference Python TensorFlow code
55+
// ```python
56+
// import numpy as np
57+
// import tensorflow as tf
58+
// tf.enable_eager_execution()
59+
// x = np.array([2, 3], dtype = np.float32).reshape(2, 1, 1, 1, 1, 1)
60+
// w = np.array([5, 4, 8, 7, 1, 2, 6, 3], dtype = np.float32).reshape(2,
61+
// 2, 2, 1, 1)
62+
// tf.nn.conv3d_transpose(x, w, output_shape=[2, 2, 2, 2, 1], padding='VALID')
63+
// ```
64+
it('input=2x2x2x1,d2=1,f=2,s=1,p=valid, batch=2', async () => {
65+
const origInputDepth = 1;
66+
const origOutputDepth = 1;
67+
const inputShape: [number, number, number, number, number] =
68+
[2, 1, 1, 1, origOutputDepth];
69+
const fSize = 2;
70+
const origPad = 'valid';
71+
const origStride = 1;
72+
73+
const x = tf.tensor5d([2, 3], inputShape);
74+
const w = tf.tensor5d(
75+
[5, 4, 8, 7, 1, 2, 6, 3],
76+
[fSize, fSize, fSize, origInputDepth, origOutputDepth]);
77+
78+
const result =
79+
tf.conv3dTranspose(x, w, [2, 2, 2, 2, 1], origStride, origPad);
80+
const expected = [10, 8, 16, 14, 2, 4, 12, 6, 15, 12, 24, 21, 3, 6, 18, 9];
81+
82+
expect(result.shape).toEqual([2, 2, 2, 2, 1]);
83+
expectArraysClose(await result.data(), expected);
84+
});
85+
});

‎src/tests.ts‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ import './ops/conv2d_separable_test';
5555
import './ops/conv2d_test';
5656
import './ops/conv2d_transpose_test';
5757
import './ops/conv3d_test';
58+
import './ops/conv3d_transpose_test';
5859
import './ops/conv_util_test';
5960
import './ops/diag_test';
6061
import './ops/dropout_test';

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /