@@ -2941,7 +2941,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src,
2941
2941
// Squeeze the direction axis out. This is safe because it is not a bi-directional node.
2942
2942
2943
2943
std::vector<int64_t > shape ({ (int64_t )NDShape::FreeDimension, BatchSizeProcessor::FreeBatchSize (), hidden_size });
2944
-
2944
+
2945
2945
onnxruntime::Node *squeezedLSTMNode = InsertReshapeNodeToCNTKFunction (src, lstmNode, shape, graph, nodeOutputName);
2946
2946
2947
2947
functionNodes.emplace (src, squeezedLSTMNode);
@@ -3883,8 +3883,24 @@ onnxruntime::Node *CNTKToONNXHelper::InsertReshapeNodeToCNTKFunction(const Funct
3883
3883
3884
3884
// We need to name reshape node's output arg with LSTM output name.
3885
3885
// Thus we need to give LSTM node output a different name.
3886
- auto outputArgs = node->OutputDefs ();
3886
+ auto inputNodeArgs = node->OutputDefs ();
3887
+
3888
+ // 1. transpose [seq, dir, batch, feature] to [seq, batch, dir, feature]
3889
+
3890
+ std::vector<int64_t > perm (inputNodeArgs[0 ]->Shape ()->dim_size ());
3891
+ std::generate (perm.begin (), perm.end (), [axis = 0 ]() mutable { return axis++; });
3892
+ std::swap (perm[1 ], perm[2 ]);
3893
+
3894
+ std::vector<int64_t > transposeOutputShape = ToINTS (*inputNodeArgs[0 ]->TypeAsProto ());
3895
+ std::swap (transposeOutputShape[1 ], transposeOutputShape[2 ]);
3896
+ onnx::TypeProto transposeOutputArgType = ToTypeProto (transposeOutputShape, false );
3897
+ UpdateONNXType (src->Outputs ()[0 ].GetDataType (), transposeOutputArgType);
3887
3898
3899
+ Node* transposeNode = AddTransposeNode (const_cast <onnxruntime::NodeArg &>(*inputNodeArgs[0 ]),
3900
+ graph, perm, transposeOutputArgType, node->Name () + " _transpose_" + nodeOutputName);
3901
+
3902
+ NodeArg* transposeOutputNodeArg = const_cast <NodeArg *>(transposeNode->OutputDefs ().at (0 ));
3903
+ // 2. reshape [seq, batch, dir, feature] to [seq, batch, dir * feature]
3888
3904
std::string lstmToReshapeNodeArgName = nodeOutputName;
3889
3905
std::vector<int64_t > reshapeShape = shape;
3890
3906
std::vector<int64_t > outputShape = shape;
@@ -3902,7 +3918,7 @@ onnxruntime::Node *CNTKToONNXHelper::InsertReshapeNodeToCNTKFunction(const Funct
3902
3918
onnxruntime::NodeArg *outputArg = &graph->GetOrCreateNodeArg (lstmToReshapeNodeArgName, &typeProto);
3903
3919
3904
3920
auto reshapeNode = AddReshapeNodeImpl (graph, nodeName + string (" _reshape" ),
3905
- const_cast <NodeArg *>(outputArgs. at ( 0 )) , outputArg, reshapeShape);
3921
+ transposeOutputNodeArg , outputArg, reshapeShape);
3906
3922
3907
3923
return reshapeNode;
3908
3924
}
@@ -7948,11 +7964,28 @@ onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForOptimizedRNNStack(const F
7948
7964
transposeOutputArgType, inputArgs[0 ]->Name () + " _transpose_out" );
7949
7965
auto transposedOutputArgs = functionNodeTransposed->OutputDefs ();
7950
7966
7951
- std::vector<int64_t > newShape = Cast<size_t , int64_t >(src->Output ().Shape ().Dimensions ());
7952
- newShape.insert (newShape.begin (), BatchSizeProcessor::FreeBatchSize ());
7953
- newShape.insert (newShape.begin (), NDShape::FreeDimension);
7967
+ NodeArg& inputNodeArg = const_cast <NodeArg&>(*transposedOutputArgs.at (0 ));
7968
+
7969
// It is required that the batch dimension can be edited after ONNX export.
7970
// Here shape[1] (the batch dimension) is set to 0 to indicate that it shall take the input's dimension.
7971
// This reshape semantic makes it easier for a model editor to change the batch size.
7972
// The output shape shall still be [sequence, 1, feature].
7973
+
7974
+ std::vector<int64_t > reshapeShape = Cast<size_t , int64_t >(src->Output ().Shape ().Dimensions ());
7975
+ // newShape.insert(newShape.begin(), BatchSizeProcessor::FreeBatchSize());
7976
+ reshapeShape.insert (reshapeShape.begin (), 0 );
7977
+ reshapeShape.insert (reshapeShape.begin (), NDShape::FreeDimension);
7978
+
7979
+ std::vector<int64_t > reshapeOutputShape = reshapeShape;
7980
+ reshapeOutputShape[1 ] = BatchSizeProcessor::FreeBatchSize ();
7954
7981
const std::string reshapedOutArgName = finalOutputNodeArgName;
7955
- auto functionNodeReshaped = AddReshapeNode (const_cast <NodeArg&>(*transposedOutputArgs.at (0 )), newShape, reshapedOutArgName, graph);
7982
+ onnx::TypeProto typeProto = ToTypeProto (reshapeOutputShape, false );
7983
+ google::protobuf::int32 elemType = inputNodeArg.TypeAsProto ()->tensor_type ().elem_type ();
7984
+ typeProto.mutable_tensor_type ()->set_elem_type (elemType);
7985
+ NodeArg& outputNodeArg = graph->GetOrCreateNodeArg (reshapedOutArgName, &typeProto);
7986
+
7987
+ auto functionNodeReshaped = AddReshapeNodeImpl (graph, inputNodeArg.Name () + string (" _reshape_to_" ) + reshapedOutArgName,
7988
+ &inputNodeArg, &outputNodeArg, reshapeShape);
7956
7989
7957
7990
functionNodes.emplace (src, functionNodeReshaped);
7958
7991
return functionNodeReshaped;
0 commit comments