@@ -845,6 +845,12 @@ class CNTKToONNXHelper
845
845
std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
846
846
std::vector<ScanLoop>& scanLoops, int createLoopIndex);
847
847
848
+ static onnxruntime::Node* CreatePoolingNode (const FunctionPtr& src,
849
+ onnxruntime::Graph* graph,
850
+ std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
851
+ std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
852
+ std::vector<ScanLoop>& scanLoops, int createLoopIndex);
853
+
848
854
static onnxruntime::Node* CreateConvolutionNode (const FunctionPtr& src,
849
855
onnxruntime::Graph* graph,
850
856
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
@@ -5629,7 +5635,12 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
5629
5635
else
5630
5636
return CreateConvolutionNode (src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
5631
5637
}
5632
-
5638
+ else if (src->OpName () == L" Pooling" && src->Inputs ()[0 ].HasBatchAxis () && src->Inputs ()[0 ].HasSequenceAxis ())
5639
+ {
5640
+ // in case a Pooling op is created with both batch and sequence axes, we need to reshape its input and output to match
5641
+ // ONNX spec of [N, C, H, W] shape requirement.
5642
+ return CreatePoolingNode (src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
5643
+ }
5633
5644
//
5634
5645
// If this block node equivalent to a primitive ONNX OP, then treated as such.
5635
5646
// And just maps its argument to ONNX node.
@@ -7087,7 +7098,8 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
7087
7098
auto lowerPad = ToINTS (src->Attributes ()[L" lowerPad" ].Value <NDShape>());
7088
7099
auto upperPad = ToINTS (src->Attributes ()[L" upperPad" ].Value <NDShape>());
7089
7100
7090
- if (IsPadValueValid (lowerPad, upperPad, autoPadding, ceilOutDim))
7101
+ // lowerPad and upperPad have incorrect dimension when the op has both batch and sequence axes.
7102
+ if (IsPadValueValid (lowerPad, upperPad, autoPadding, ceilOutDim) && !(src->Inputs ()[0 ].HasBatchAxis () && src->Inputs ()[0 ].HasSequenceAxis ()))
7091
7103
{
7092
7104
if (ceilOutDim)
7093
7105
ValidatePadValueForCeilOutDim (lowerPad, upperPad, autoPadding, kernelShape, inputShape, strides,
@@ -7097,6 +7109,14 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
7097
7109
}
7098
7110
else
7099
7111
{
7112
+ if (src->Inputs ()[0 ].HasBatchAxis () && src->Inputs ()[0 ].HasSequenceAxis ())
7113
+ {
7114
+ if (!std::all_of (lowerPad.begin (), lowerPad.end (), [](int64_t pad) {return pad == 0 ; }) ||
7115
+ !std::all_of (upperPad.begin (), upperPad.end (), [](int64_t pad) {return pad == 0 ; }))
7116
+ {
7117
+ fprintf (stderr, " Warning: Cannot set upperPad and lowerPad with pooling ops. Padding values will be computed according to kernel and input shapes." );
7118
+ }
7119
+ }
7100
7120
if (isPooling)
7101
7121
PutPadAttrInNode (node, autoPadding, kernelShape, inputShape, strides, /* dilation=*/ std::vector<size_t >(kernelShape.Rank (), 1 ),
7102
7122
ceilOutDim, /* transpose=*/ !isPooling);
@@ -8605,6 +8625,54 @@ onnxruntime::Node* ApplyActivationToSequenceConvolution(Node* convNode, const Fu
8605
8625
return activationNode;
8606
8626
}
8607
8627
8628
+ // insert reshape before and after a Pooling op when the CNTK op has both sequence and batch axes.
8629
+ onnxruntime::Node* CNTKToONNXHelper::CreatePoolingNode (const FunctionPtr& src,
8630
+ onnxruntime::Graph* graph,
8631
+ std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
8632
+ std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
8633
+ std::vector<ScanLoop>& scanLoops, int createLoopIndex)
8634
+ {
8635
+ if (!src->Inputs ()[0 ].HasBatchAxis () || !src->Inputs ()[0 ].HasSequenceAxis ())
8636
+ LogicError (" CreatePoolingNode is only to handle MaxPool with batch and sequence dimensions." );
8637
+
8638
+ std::vector<onnxruntime::NodeArg *> inputs;
8639
+ ProcessInputs (src, graph, functionNodes, variableNodes, inputs,
8640
+ scanLoops, createLoopIndex);
8641
+
8642
+ std::vector<onnxruntime::NodeArg *> outputs;
8643
+ ProcessOutputs (src, inputs, outputs, graph);
8644
+
8645
+ // Max/AveragePool takes input of shape [N, C, H, W] or [N, C, D1, D2, ..., Dn]. CNTK input needs to be reshaped to match it.
8646
+ // reshape [#, *][C, H, W] to [-1, C, H, W]
8647
+ // onnx Max/AveragePool
8648
+ // reshape [-1, C_out, H_out, W_out] to [#, *][C_out, H_out, W_out]
8649
+ vector<int64_t > newDimInputToPooling;
8650
+ // collapse extra dims into one axis as N for ONNX Conv
8651
+ newDimInputToPooling.push_back (-1 );
8652
+ for (int i = 2 ; i < inputs[0 ]->Shape ()->dim_size (); i++)
8653
+ {
8654
+ // copy C, H, W
8655
+ if (!inputs[0 ]->Shape ()->dim (i).has_dim_value ())
8656
+ LogicError (" Max/AveragePool: feature dimensions need to have dim value." );
8657
+ newDimInputToPooling.push_back (inputs[0 ]->Shape ()->dim (i).dim_value ());
8658
+ }
8659
+
8660
+ onnxruntime::Node* preReshape = AddReshapeNode (*inputs[0 ], newDimInputToPooling, inputs[0 ]->Name () + " _reshaped_for_max_pool" , graph);
8661
+ const std::vector<onnxruntime::NodeArg *> pooling_inputs ({const_cast <NodeArg *>(preReshape->OutputDefs ()[0 ])});
8662
+ TypeProto poolingOutputTypeProto;
8663
+ UpdateONNXType (src->Outputs ()[0 ].GetDataType (), poolingOutputTypeProto);
8664
+
8665
+ NodeArg *poolingOutputArg = &graph->GetOrCreateNodeArg (outputs[0 ]->Name () + " _pooling_of_reshaped" , &poolingOutputTypeProto);
8666
+
8667
+ onnxruntime::Node* poolingNode = AddNode (src, graph, pooling_inputs, { poolingOutputArg });
8668
+
8669
+ vector<int64_t > newDimOutputFromPooling = ToINTS (*outputs[0 ]->TypeAsProto ());
8670
+ onnxruntime::Node* postReshape = AddReshapeNode (*poolingOutputArg, newDimOutputFromPooling, outputs[0 ]->Name (), graph);
8671
+
8672
+ functionNodes.emplace (src, poolingNode);
8673
+ return postReshape;
8674
+ }
8675
+
8608
8676
onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionNode (const FunctionPtr& src,
8609
8677
onnxruntime::Graph* graph,
8610
8678
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
0 commit comments