@@ -845,6 +845,12 @@ class CNTKToONNXHelper
         std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
         std::vector<ScanLoop>& scanLoops, int createLoopIndex);
 
+    static onnxruntime::Node* CreatePoolingNode(const FunctionPtr& src,
+        onnxruntime::Graph* graph,
+        std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
+        std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
+        std::vector<ScanLoop>& scanLoops, int createLoopIndex);
+
     static onnxruntime::Node* CreateConvolutionNode(const FunctionPtr& src,
         onnxruntime::Graph* graph,
         std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
@@ -5396,6 +5402,9 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
     std::string cntkOpName = ToLegacyString(ToUTF8(src->OpName()));
     std::string onnxOpName = ToOPName(src);
 
+    if (src->OpName() == L"Pooling")
+        std::cout << ""; // no-op; presumably left in as a debugger breakpoint anchor
+
     // TODO: uncomment this code once bidirectional LSTM is supported.
     // if (cntkOpName == "Splice")
     // {
@@ -5629,7 +5638,10 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
         else
             return CreateConvolutionNode(src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
     }
-
+    else if (src->OpName() == L"Pooling" && src->Inputs()[0].HasBatchAxis() && src->Inputs()[0].HasSequenceAxis())
+    {
+        return CreatePoolingNode(src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
+    }
     //
     // If this block node is equivalent to a primitive ONNX OP, treat it as such
     // and just map its arguments to the ONNX node.
@@ -7087,7 +7099,7 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
         auto lowerPad = ToINTS(src->Attributes()[L"lowerPad"].Value<NDShape>());
         auto upperPad = ToINTS(src->Attributes()[L"upperPad"].Value<NDShape>());
 
-        if (IsPadValueValid(lowerPad, upperPad, autoPadding, ceilOutDim))
+        if (IsPadValueValid(lowerPad, upperPad, autoPadding, ceilOutDim) && !(src->Inputs()[0].HasBatchAxis() && src->Inputs()[0].HasSequenceAxis()))
         {
             if (ceilOutDim)
                 ValidatePadValueForCeilOutDim(lowerPad, upperPad, autoPadding, kernelShape, inputShape, strides,
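
With the condition above, explicit lowerPad/upperPad values are bypassed whenever the pooling input carries both a batch and a sequence axis: the converter falls through to the auto-padding branch, and the reshape-based pooling path added below owns the shape handling. A minimal standalone restatement of the new predicate (the helper name is an assumption for this sketch, not part of the commit):

    // Sketch only: explicit pads are honored when they are valid AND the input
    // does not carry both dynamic axes ([#, *]).
    static bool useExplicitPads(bool padsValid, bool hasBatchAxis, bool hasSequenceAxis)
    {
        return padsValid && !(hasBatchAxis && hasSequenceAxis);
    }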
@@ -8605,6 +8617,52 @@ onnxruntime::Node* ApplyActivationToSequenceConvolution(Node* convNode, const Fu
     return activationNode;
 }
 
+onnxruntime::Node* CNTKToONNXHelper::CreatePoolingNode(const FunctionPtr& src,
+    onnxruntime::Graph* graph,
+    std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
+    std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
+    std::vector<ScanLoop>& scanLoops, int createLoopIndex)
+{
+    if (!src->Inputs()[0].HasBatchAxis() || !src->Inputs()[0].HasSequenceAxis())
+        LogicError("CreatePoolingNode is only meant to handle pooling ops whose input has both batch and sequence axes.");
+
+    std::vector<onnxruntime::NodeArg *> inputs;
+    ProcessInputs(src, graph, functionNodes, variableNodes, inputs,
+                  scanLoops, createLoopIndex);
+
+    std::vector<onnxruntime::NodeArg *> outputs;
+    ProcessOutputs(src, inputs, outputs, graph);
+
+    // ONNX Max/AveragePool takes input of shape [N, C, H, W] or [N, C, D1, D2, ..., Dn],
+    // so the CNTK input needs to be reshaped to match:
+    //   reshape [#, *][C, H, W] to [-1, C, H, W]
+    //   run ONNX Max/AveragePool
+    //   reshape [-1, C_out, H_out, W_out] back to [#, *][C_out, H_out, W_out]
+    vector<int64_t> newDimInputToPooling;
+    // Collapse the batch and sequence axes into one axis that serves as N for ONNX pooling.
+    newDimInputToPooling.push_back(-1);
+    for (int i = 2; i < inputs[0]->Shape()->dim_size(); i++)
+    {
+        // copy C, H, W
+        if (!inputs[0]->Shape()->dim(i).has_dim_value())
+            LogicError("Max/AveragePool: feature dimensions need to have dim value.");
+        newDimInputToPooling.push_back(inputs[0]->Shape()->dim(i).dim_value());
+    }
+
+    onnxruntime::Node* preReshape = AddReshapeNode(*inputs[0], newDimInputToPooling, inputs[0]->Name() + "_reshaped_for_max_pool", graph);
+    const std::vector<onnxruntime::NodeArg *> pooling_inputs({const_cast<NodeArg *>(preReshape->OutputDefs()[0])});
+    TypeProto poolingOutputTypeProto;
+    UpdateONNXType(src->Outputs()[0].GetDataType(), poolingOutputTypeProto);
+
+    NodeArg *poolingOutputArg = &graph->GetOrCreateNodeArg(outputs[0]->Name() + "_pooling_of_reshaped", &poolingOutputTypeProto);
+
+    onnxruntime::Node* poolingNode = AddNode(src, graph, pooling_inputs, { poolingOutputArg });
+
+    vector<int64_t> newDimOutputFromPooling = ToINTS(*outputs[0]->TypeAsProto());
+    // The post-pooling Reshape carries the final output name, so it is what downstream
+    // consumers connect to; the pooling node itself is returned as the node mapped to src.
+    onnxruntime::Node* postReshape = AddReshapeNode(*poolingOutputArg, newDimOutputFromPooling, outputs[0]->Name(), graph);
+
+    return poolingNode;
+}
+
 onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionNode(const FunctionPtr& src,
     onnxruntime::Graph* graph,
     std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
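
For reference, the shape arithmetic behind the reshape → Max/AveragePool → reshape sequence in CreatePoolingNode can be sketched in isolation. The helper below is illustrative only: its name and the plain-vector shape encoding (the two dynamic axes first, with unknown extents written as -1) are assumptions for the sketch, not CNTK or onnxruntime APIs. It mirrors how newDimInputToPooling is built: collapse [#, *] into a single -1 batch dimension and copy the static feature dimensions.

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    // Illustrative only: fullShape = [sequence, batch, C, H, W, ...], with the two
    // leading dynamic axes of unknown extent (encoded here as -1).
    std::vector<int64_t> collapseDynamicAxes(const std::vector<int64_t>& fullShape)
    {
        std::vector<int64_t> collapsed;
        collapsed.push_back(-1); // sequence * batch folded into a single N
        for (std::size_t i = 2; i < fullShape.size(); ++i)
        {
            if (fullShape[i] <= 0) // mirrors the has_dim_value() check in the commit
                throw std::runtime_error("feature dimensions need to have dim value");
            collapsed.push_back(fullShape[i]);
        }
        return collapsed;
    }

    int main()
    {
        // [#, *][3, 32, 32] -> [-1, 3, 32, 32]
        for (int64_t d : collapseDynamicAxes({-1, -1, 3, 32, 32}))
            std::cout << d << ' ';
        std::cout << '\n'; // prints: -1 3 32 32
    }

After the pooling node runs on the collapsed input, the second Reshape (built from ToINTS(*outputs[0]->TypeAsProto())) restores the original [#, *] layout, so downstream consumers still see the batch and sequence axes.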