Skip to content

Commit 6a689e9

Browse files
Vivek Ranejhseu
authored andcommitted
MKL support for the transpose op (tensorflow#7748)
* Adding MKL support for the transpose op * Fixed build issues for bazel test * Ran buildifier on core/kernels/BUILD * Adding MKL support for the transpose op * Fixed build issues for bazel test * Ran buildifier on core/kernels/BUILD * Fixed rebase issue (buildifier) with build file * BUILD file fixes
1 parent ed1a779 commit 6a689e9

File tree

5 files changed

+105
-5
lines changed

5 files changed

+105
-5
lines changed

tensorflow/core/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -702,8 +702,8 @@ cc_library(
702702
"//tensorflow/core/kernels:math_not_windows",
703703
"//tensorflow/core/kernels:quantized_ops",
704704
]) + if_mkl([
705-
"//tensorflow/core/kernels:mkl_ops",
706705
"//tensorflow/core/kernels:mkl_conv_op",
706+
"//tensorflow/core/kernels:mkl_matmul_op",
707707
"//tensorflow/core/kernels:mkl_tfconv_op",
708708
]),
709709
)
@@ -2040,7 +2040,7 @@ if_mkl(
20402040
"//tensorflow/cc:scope",
20412041
"//tensorflow/cc:sendrecv_ops",
20422042
"//tensorflow/core/kernels:mkl_conv_op",
2043-
"//tensorflow/core/kernels:mkl_ops",
2043+
"//tensorflow/core/kernels:mkl_matmul_op",
20442044
"//tensorflow/core/kernels:mkl_tfconv_op",
20452045
"//tensorflow/core/kernels:ops_util",
20462046
"//third_party/eigen3",

tensorflow/core/kernels/BUILD

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -690,8 +690,15 @@ tf_kernel_library(
690690

691691
tf_kernel_library(
692692
name = "transpose_op",
693-
prefix = "transpose_op",
694-
deps = ARRAY_DEPS,
693+
srcs = [
694+
"transpose_op.cc",
695+
] + if_mkl([
696+
"mkl_transpose_op.cc",
697+
]),
698+
hdrs = ["transpose_op.h"],
699+
deps = ARRAY_DEPS + if_mkl([
700+
"//third_party/mkl:intel_binary_blob",
701+
]),
695702
)
696703

697704
tf_kernel_library(
@@ -4365,7 +4372,7 @@ tf_cc_test(
43654372

43664373
if_mkl(
43674374
tf_kernel_library(
4368-
name = "mkl_ops",
4375+
name = "mkl_matmul_op",
43694376
prefix = "mkl_matmul",
43704377
deps = [
43714378
":math",
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
// See docs in ../ops/array_ops.cc.
17+
18+
#ifdef INTEL_MKL
19+
#define EIGEN_USE_THREADS
20+
21+
#include "tensorflow/core/kernels/transpose_op.h"
22+
#include "tensorflow/core/kernels/transpose_functor.h"
23+
#include "third_party/mkl/include/mkl_trans.h"
24+
25+
namespace tensorflow {
26+
27+
// output = TransposeOp(T<any> input, T<int32> perm) takes a tensor
28+
// of type T and rank N, and a permutation of 0, 1, ..., N-1. It
29+
// shuffles the dimensions of the input tensor according to permutation.
30+
//
31+
// Specifically, the returned tensor output meets the following condition:
32+
// 1) output.dims() == input.dims();
33+
// 2) output.dim_size(i) == input.dim_size(perm[i]);
34+
// 3) output.tensor<T, N>(i_0, i_1, ..., i_N-1) ==
35+
// input.tensor<T, N>(j_0, j_1, ..., j_N-1),
36+
// where i_s == j_{perm[s]}
37+
//
38+
// REQUIRES: perm is a vector of int32.
39+
// REQUIRES: input.dims() == perm.size().
40+
// REQUIRES: perm is a permutation.
41+
42+
Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
43+
gtl::ArraySlice<int32> perm,
44+
Tensor* out) {
45+
if (in.dims() == 2 && in.dtype() == DT_FLOAT) {
46+
float* user_o = out->flat<float>().data();
47+
const float* user_i = in.flat<float>().data();
48+
49+
// Documentation here: https://software.intel.com/en-us/node/520863
50+
// Parameters: (ordering:row-major, operation:transpose, num_rows, num_cols,
51+
// alpha (for scaling), array, dist_bet_adjacent_cols/rows
52+
// (source), array, dist_bet_adjacent_cols/rows (dest))
53+
mkl_somatcopy('R', 'T', in.dim_size(0), in.dim_size(1), 1,
54+
user_i, in.dim_size(1),
55+
user_o, in.dim_size(0));
56+
57+
return Status::OK();
58+
}
59+
60+
// Fallback to eigen if transpose parameters not supported by MKL
61+
typedef Eigen::ThreadPoolDevice CPUDevice;
62+
return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
63+
out);
64+
} // MklTransposeCpuOp::DoTranspose
65+
} // namespace tensorflow
66+
67+
#endif // INTEL_MKL

tensorflow/core/kernels/transpose_op.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,20 @@ Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
180180
out);
181181
}
182182

183+
#ifdef INTEL_MKL
184+
#define REGISTER(T) \
185+
REGISTER_KERNEL_BUILDER(Name("Transpose") \
186+
.Device(DEVICE_CPU) \
187+
.TypeConstraint<T>("T") \
188+
.TypeConstraint<int32>("Tperm") \
189+
.HostMemory("perm"), \
190+
MklTransposeCpuOp);
191+
TF_CALL_ALL_TYPES(REGISTER);
192+
REGISTER(bfloat16);
193+
#undef REGISTER
194+
195+
#else // INTEL_MKL
196+
183197
#define REGISTER(T) \
184198
REGISTER_KERNEL_BUILDER(Name("Transpose") \
185199
.Device(DEVICE_CPU) \
@@ -190,6 +204,7 @@ Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
190204
TF_CALL_ALL_TYPES(REGISTER)
191205
REGISTER(bfloat16);
192206
#undef REGISTER
207+
#endif // INTEL_MKL
193208

194209
#if GOOGLE_CUDA
195210
Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,

tensorflow/core/kernels/transpose_op.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ class TransposeCpuOp : public TransposeOp {
4141
gtl::ArraySlice<int32> perm, Tensor* out) override;
4242
};
4343

44+
#ifdef INTEL_MKL
45+
class MklTransposeCpuOp : public TransposeOp {
46+
public:
47+
explicit MklTransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
48+
49+
protected:
50+
Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
51+
gtl::ArraySlice<int32> perm, Tensor* out) override;
52+
};
53+
#endif // INTEL_MKL
54+
4455
class TransposeGpuOp : public TransposeOp {
4556
public:
4657
explicit TransposeGpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}

0 commit comments

Comments
 (0)