@@ -5637,7 +5637,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
5637
5637
}
5638
5638
else if (src->OpName () == L" Pooling" && src->Inputs ()[0 ].HasBatchAxis () && src->Inputs ()[0 ].HasSequenceAxis ())
5639
5639
{
5640
- // in case a Pooling op is created with bother batch and sequence axes, we need to reshape its input and output to match
5640
+ // in case a Pooling op is created with both batch and sequence axes, we need to reshape its input and output to match
5641
5641
// ONNX spec of [N, C, H, W] shape requirement.
5642
5642
return CreatePoolingNode (src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
5643
5643
}
@@ -7109,6 +7109,14 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
7109
7109
}
7110
7110
else
7111
7111
{
7112
+ if (src->Inputs ()[0 ].HasBatchAxis () && src->Inputs ()[0 ].HasSequenceAxis ())
7113
+ {
7114
+ if (!std::all_of (lowerPad.begin (), lowerPad.end (), [](int64_t pad) {return pad == 0 ; }) ||
7115
+ !std::all_of (upperPad.begin (), upperPad.end (), [](int64_t pad) {return pad == 0 ; }))
7116
+ {
7117
+ fprintf (stderr, " Warning: Cannot set upperPad and lowerPad with pooling ops. Padding values will be computed according to kernel and input shapes." );
7118
+ }
7119
+ }
7112
7120
if (isPooling)
7113
7121
PutPadAttrInNode (node, autoPadding, kernelShape, inputShape, strides, /* dilation=*/ std::vector<size_t >(kernelShape.Rank (), 1 ),
7114
7122
ceilOutDim, /* transpose=*/ !isPooling);
@@ -8661,7 +8669,8 @@ onnxruntime::Node* CNTKToONNXHelper::CreatePoolingNode(const FunctionPtr& src,
8661
8669
vector<int64_t > newDimOutputFromPooling = ToINTS (*outputs[0 ]->TypeAsProto ());
8662
8670
onnxruntime::Node* postReshape = AddReshapeNode (*poolingOutputArg, newDimOutputFromPooling, outputs[0 ]->Name (), graph);
8663
8671
8664
- return poolingNode;
8672
+ functionNodes.emplace (src, poolingNode);
8673
+ return postReshape;
8665
8674
}
8666
8675
8667
8676
onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionNode (const FunctionPtr& src,
0 commit comments