Skip to content

Commit 80e9f79

Browse files
committed
update with reviewers' comments
1 parent 9a7dd4c commit 80e9f79

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5637,7 +5637,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
56375637
}
56385638
else if (src->OpName() == L"Pooling" && src->Inputs()[0].HasBatchAxis() && src->Inputs()[0].HasSequenceAxis())
56395639
{
5640-
// in case a Pooling op is created with bother batch and sequence axes, we need to reshape its input and output to match
5640+
// in case a Pooling op is created with both batch and sequence axes, we need to reshape its input and output to match
56415641
// ONNX spec of [N, C, H, W] shape requirement.
56425642
return CreatePoolingNode(src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
56435643
}
@@ -7109,6 +7109,14 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
71097109
}
71107110
else
71117111
{
7112+
if (src->Inputs()[0].HasBatchAxis() && src->Inputs()[0].HasSequenceAxis())
7113+
{
7114+
if (!std::all_of(lowerPad.begin(), lowerPad.end(), [](int64_t pad) {return pad == 0; }) ||
7115+
!std::all_of(upperPad.begin(), upperPad.end(), [](int64_t pad) {return pad == 0; }))
7116+
{
7117+
fprintf(stderr, "Warning: Cannot set upperPad and lowerPad with pooling ops. Padding values will be computed according to kernel and input shapes.");
7118+
}
7119+
}
71127120
if (isPooling)
71137121
PutPadAttrInNode(node, autoPadding, kernelShape, inputShape, strides, /*dilation=*/std::vector<size_t>(kernelShape.Rank(), 1),
71147122
ceilOutDim, /*transpose=*/!isPooling);
@@ -8661,7 +8669,8 @@ onnxruntime::Node* CNTKToONNXHelper::CreatePoolingNode(const FunctionPtr& src,
86618669
vector<int64_t> newDimOutputFromPooling = ToINTS(*outputs[0]->TypeAsProto());
86628670
onnxruntime::Node* postReshape = AddReshapeNode(*poolingOutputArg, newDimOutputFromPooling, outputs[0]->Name(), graph);
86638671

8664-
return poolingNode;
8672+
functionNodes.emplace(src, poolingNode);
8673+
return postReshape;
86658674
}
86668675

86678676
onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionNode(const FunctionPtr& src,

0 commit comments

Comments
 (0)