Skip to content

Commit ad1709c

Browse files
committed
fix reduce op keepdim setting
1 parent 248ff78 commit ad1709c

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7062,7 +7062,7 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
70627062
}
70637063

70647064
void CNTKToONNXHelper::SetReduceElementsAttributes(const FunctionPtr src, Node *node,
7065-
bool isSequenceReduceElement)
7065+
bool isSequenceReduceElement)
70667066
{
70677067
std::wstring reductionOpName = src->OpName();
70687068
if (reductionOpName == L"ReduceElements")
@@ -7071,23 +7071,20 @@ void CNTKToONNXHelper::SetReduceElementsAttributes(const FunctionPtr src, Node *
70717071
}
70727072

70737073
//
7074-
int64_t keepReducedDimensions = 1;
7074+
int64_t keepReducedDimensions = 0;
70757075
if (src->Attributes().Contains(L"reductionKeepDimensions"))
70767076
keepReducedDimensions = (int64_t)((bool)src->Attributes()[L"reductionKeepDimensions"].Value<bool>() ? 1 : 0);
7077-
bool forceKeepReducedDimensions = false;
7077+
else if (src->Inputs()[0].DynamicAxes().size() == src->Outputs()[0].DynamicAxes().size() &&
7078+
src->Inputs()[0].Shape().Rank() == src->Outputs()[0].Shape().Rank())
7079+
keepReducedDimensions = 1;
7080+
70787081

70797082
std::vector<Axis> reductionAxes;
70807083
if (src->Attributes().Contains(L"axisVec"))
70817084
reductionAxes = AsVector<Axis>(src->Attributes()[L"axisVec"].Value<std::vector<DictionaryValue>>());
70827085
else if (src->Attributes().Contains(L"axis"))
70837086
reductionAxes.push_back((Axis)(src->Attributes()[L"axis"].Value<Axis>()));
70847087

7085-
// Reduction on batch axis in CNTK removes the batch axis, even if keepdims is true.
7086-
// For ONNX export we need to make sure we export keepdims as 0 (false).
7087-
// The same applies for AllStaticAxes.
7088-
if (!forceKeepReducedDimensions &&
7089-
(reductionAxes.size() == 1 && (reductionAxes[0] == Axis::DefaultBatchAxis() || reductionAxes[0] == Axis::AllStaticAxes() || reductionAxes[0] == Axis::AllAxes())))
7090-
keepReducedDimensions = 0;
70917088
std::vector<int64_t> axes = ConvertAxesToOnnx(reductionAxes, src->Inputs()[0]);
70927089

70937090
if (isSequenceReduceElement && axes.size() == 1 && axes[0] == 1 && src->Inputs()[0].DynamicAxes().size() == 1)
@@ -7100,8 +7097,8 @@ void CNTKToONNXHelper::SetReduceElementsAttributes(const FunctionPtr src, Node *
71007097

71017098
if (reductionOpName == L"Argmax" || reductionOpName == L"Argmin")
71027099
node->AddAttribute("axis", axes[0]);
7103-
else
7104-
if (reductionAxes[0] != Axis::AllAxes())
7100+
else
7101+
if (reductionAxes[0] != Axis::AllAxes())
71057102
node->AddAttribute("axes", axes);
71067103

71077104
node->AddAttribute("keepdims", keepReducedDimensions);

0 commit comments

Comments
 (0)