Skip to content

Commit 1fe0bf1

Browse files
committed
reduction keep dim edge cases
1 parent ad1709c commit 1fe0bf1

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7070,21 +7070,27 @@ void CNTKToONNXHelper::SetReduceElementsAttributes(const FunctionPtr src, Node *
70707070
reductionOpName = src->Attributes()[L"reductionOpName"].Value<wstring>();
70717071
}
70727072

7073+
std::vector<Axis> reductionAxes;
7074+
if (src->Attributes().Contains(L"axisVec"))
7075+
reductionAxes = AsVector<Axis>(src->Attributes()[L"axisVec"].Value<std::vector<DictionaryValue>>());
7076+
else if (src->Attributes().Contains(L"axis"))
7077+
reductionAxes.push_back((Axis)(src->Attributes()[L"axis"].Value<Axis>()));
7078+
70737079
//
70747080
int64_t keepReducedDimensions = 0;
70757081
if (src->Attributes().Contains(L"reductionKeepDimensions"))
70767082
keepReducedDimensions = (int64_t)((bool)src->Attributes()[L"reductionKeepDimensions"].Value<bool>() ? 1 : 0);
7077-
else if (src->Inputs()[0].DynamicAxes().size() == src->Outputs()[0].DynamicAxes().size() &&
7083+
7084+
// there are cases where reductionKeepDimensions attribute does not take effect.
7085+
if((reductionAxes.size() == 1 && (reductionAxes[0] == Axis::DefaultBatchAxis() || reductionAxes[0] == Axis::AllStaticAxes() || reductionAxes[0] == Axis::AllAxes()))) // Reduction on batch axis in CNTK removes the batch axis, even if keepdims is true.
7086+
// For ONNX export we need to make sure we export keepdims as 0 (false).
7087+
// The same applies for AllStaticAxes.
7088+
keepReducedDimensions = 0;
7089+
7090+
if (src->Inputs()[0].DynamicAxes().size() == src->Outputs()[0].DynamicAxes().size() &&
70787091
src->Inputs()[0].Shape().Rank() == src->Outputs()[0].Shape().Rank())
70797092
keepReducedDimensions = 1;
70807093

7081-
7082-
std::vector<Axis> reductionAxes;
7083-
if (src->Attributes().Contains(L"axisVec"))
7084-
reductionAxes = AsVector<Axis>(src->Attributes()[L"axisVec"].Value<std::vector<DictionaryValue>>());
7085-
else if (src->Attributes().Contains(L"axis"))
7086-
reductionAxes.push_back((Axis)(src->Attributes()[L"axis"].Value<Axis>()));
7087-
70887094
std::vector<int64_t> axes = ConvertAxesToOnnx(reductionAxes, src->Inputs()[0]);
70897095

70907096
if (isSequenceReduceElement && axes.size() == 1 && axes[0] == 1 && src->Inputs()[0].DynamicAxes().size() == 1)

0 commit comments

Comments
 (0)