Skip to content

Commit abf0985

Browse files
committed
Transpose batch with direction before reshaping to [seq, batch, dir*hidden] for RNN ops
1 parent ef67317 commit abf0985

File tree

1 file changed

+40
-7
lines changed

1 file changed

+40
-7
lines changed

Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2941,7 +2941,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src,
29412941
// squeeze direction axis out. This is safe because it is not a bi-directional node.
29422942

29432943
std::vector<int64_t> shape({ (int64_t)NDShape::FreeDimension, BatchSizeProcessor::FreeBatchSize(), hidden_size });
2944-
2944+
29452945
onnxruntime::Node *squeezedLSTMNode = InsertReshapeNodeToCNTKFunction(src, lstmNode, shape, graph, nodeOutputName);
29462946

29472947
functionNodes.emplace(src, squeezedLSTMNode);
@@ -3883,8 +3883,24 @@ onnxruntime::Node *CNTKToONNXHelper::InsertReshapeNodeToCNTKFunction(const Funct
38833883

38843884
// We need to name reshape node's output arg with LSTM output name.
38853885
// Thus we need to give LSTM node output a different name.
3886-
auto outputArgs = node->OutputDefs();
3886+
auto inputNodeArgs = node->OutputDefs();
3887+
3888+
// 1. transpose [seq, dir, batch, feature] to [seq, batch, dir, feature]
3889+
3890+
std::vector<int64_t> perm(inputNodeArgs[0]->Shape()->dim_size());
3891+
std::generate(perm.begin(), perm.end(), [axis = 0]() mutable { return axis++; });
3892+
std::swap(perm[1], perm[2]);
3893+
3894+
std::vector<int64_t> transposeOutputShape = ToINTS(*inputNodeArgs[0]->TypeAsProto());
3895+
std::swap(transposeOutputShape[1], transposeOutputShape[2]);
3896+
onnx::TypeProto transposeOutputArgType = ToTypeProto(transposeOutputShape, false);
3897+
UpdateONNXType(src->Outputs()[0].GetDataType(), transposeOutputArgType);
38873898

3899+
Node* transposeNode = AddTransposeNode(const_cast<onnxruntime::NodeArg &>(*inputNodeArgs[0]),
3900+
graph, perm, transposeOutputArgType, node->Name() + "_transpose_" + nodeOutputName);
3901+
3902+
NodeArg* transposeOutputNodeArg = const_cast<NodeArg *>(transposeNode->OutputDefs().at(0));
3903+
// 2. reshape [seq, batch, dir, feature] to [seq, batch, dir * feature]
38883904
std::string lstmToReshapeNodeArgName = nodeOutputName;
38893905
std::vector<int64_t> reshapeShape = shape;
38903906
std::vector<int64_t> outputShape = shape;
@@ -3902,7 +3918,7 @@ onnxruntime::Node *CNTKToONNXHelper::InsertReshapeNodeToCNTKFunction(const Funct
39023918
onnxruntime::NodeArg *outputArg = &graph->GetOrCreateNodeArg(lstmToReshapeNodeArgName, &typeProto);
39033919

39043920
auto reshapeNode = AddReshapeNodeImpl(graph, nodeName + string("_reshape"),
3905-
const_cast<NodeArg *>(outputArgs.at(0)), outputArg, reshapeShape);
3921+
transposeOutputNodeArg, outputArg, reshapeShape);
39063922

39073923
return reshapeNode;
39083924
}
@@ -7948,11 +7964,28 @@ onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForOptimizedRNNStack(const F
79487964
transposeOutputArgType, inputArgs[0]->Name() + "_transpose_out");
79497965
auto transposedOutputArgs = functionNodeTransposed->OutputDefs();
79507966

7951-
std::vector<int64_t> newShape = Cast<size_t, int64_t>(src->Output().Shape().Dimensions());
7952-
newShape.insert(newShape.begin(), BatchSizeProcessor::FreeBatchSize());
7953-
newShape.insert(newShape.begin(), NDShape::FreeDimension);
7967+
NodeArg& inputNodeArg = const_cast<NodeArg&>(*transposedOutputArgs.at(0));
7968+
7969+
// It is required that the batch dimension can be edited after ONNX export.
7970+
// Here shape[1] (the batch dimension) is set to 0 to indicate that it shall take the input's dimension.
7971+
// This reshape semantic makes it easier for a model editor to change the batch size.
7972+
// The output shape shall still be [sequence, 1, feature].
7973+
7974+
std::vector<int64_t> reshapeShape = Cast<size_t, int64_t>(src->Output().Shape().Dimensions());
7975+
// newShape.insert(newShape.begin(), BatchSizeProcessor::FreeBatchSize());
7976+
reshapeShape.insert(reshapeShape.begin(), 0);
7977+
reshapeShape.insert(reshapeShape.begin(), NDShape::FreeDimension);
7978+
7979+
std::vector<int64_t> reshapeOutputShape = reshapeShape;
7980+
reshapeOutputShape[1] = BatchSizeProcessor::FreeBatchSize();
79547981
const std::string reshapedOutArgName = finalOutputNodeArgName;
7955-
auto functionNodeReshaped = AddReshapeNode(const_cast<NodeArg&>(*transposedOutputArgs.at(0)), newShape, reshapedOutArgName, graph);
7982+
onnx::TypeProto typeProto = ToTypeProto(reshapeOutputShape, false);
7983+
google::protobuf::int32 elemType = inputNodeArg.TypeAsProto()->tensor_type().elem_type();
7984+
typeProto.mutable_tensor_type()->set_elem_type(elemType);
7985+
NodeArg& outputNodeArg = graph->GetOrCreateNodeArg(reshapedOutArgName, &typeProto);
7986+
7987+
auto functionNodeReshaped = AddReshapeNodeImpl(graph, inputNodeArg.Name() + string("_reshape_to_") + reshapedOutArgName,
7988+
&inputNodeArg, &outputNodeArg, reshapeShape);
79567989

79577990
functionNodes.emplace(src, functionNodeReshaped);
79587991
return functionNodeReshaped;

0 commit comments

Comments
 (0)