@@ -7070,21 +7070,27 @@ void CNTKToONNXHelper::SetReduceElementsAttributes(const FunctionPtr src, Node *
7070
7070
reductionOpName = src->Attributes ()[L" reductionOpName" ].Value <wstring>();
7071
7071
}
7072
7072
7073
+ std::vector<Axis> reductionAxes;
7074
+ if (src->Attributes ().Contains (L" axisVec" ))
7075
+ reductionAxes = AsVector<Axis>(src->Attributes ()[L" axisVec" ].Value <std::vector<DictionaryValue>>());
7076
+ else if (src->Attributes ().Contains (L" axis" ))
7077
+ reductionAxes.push_back ((Axis)(src->Attributes ()[L" axis" ].Value <Axis>()));
7078
+
7073
7079
//
7074
7080
int64_t keepReducedDimensions = 0 ;
7075
7081
if (src->Attributes ().Contains (L" reductionKeepDimensions" ))
7076
7082
keepReducedDimensions = (int64_t )((bool )src->Attributes ()[L" reductionKeepDimensions" ].Value <bool >() ? 1 : 0 );
7077
- else if (src->Inputs ()[0 ].DynamicAxes ().size () == src->Outputs ()[0 ].DynamicAxes ().size () &&
7083
+
7084
+ // there are cases where reductionKeepDimensions attribute does not take effect.
7085
+ if ((reductionAxes.size () == 1 && (reductionAxes[0 ] == Axis::DefaultBatchAxis () || reductionAxes[0 ] == Axis::AllStaticAxes () || reductionAxes[0 ] == Axis::AllAxes ()))) // Reduction on batch axis in CNTK removes the batch axis, even if keepdims is true.
7086
+ // For ONNX export we need to make sure we export keepdims as 0 (false).
7087
+ // The same applies for AllStaticAxes.
7088
+ keepReducedDimensions = 0 ;
7089
+
7090
+ if (src->Inputs ()[0 ].DynamicAxes ().size () == src->Outputs ()[0 ].DynamicAxes ().size () &&
7078
7091
src->Inputs ()[0 ].Shape ().Rank () == src->Outputs ()[0 ].Shape ().Rank ())
7079
7092
keepReducedDimensions = 1 ;
7080
7093
7081
-
7082
- std::vector<Axis> reductionAxes;
7083
- if (src->Attributes ().Contains (L" axisVec" ))
7084
- reductionAxes = AsVector<Axis>(src->Attributes ()[L" axisVec" ].Value <std::vector<DictionaryValue>>());
7085
- else if (src->Attributes ().Contains (L" axis" ))
7086
- reductionAxes.push_back ((Axis)(src->Attributes ()[L" axis" ].Value <Axis>()));
7087
-
7088
7094
std::vector<int64_t > axes = ConvertAxesToOnnx (reductionAxes, src->Inputs ()[0 ]);
7089
7095
7090
7096
if (isSequenceReduceElement && axes.size () == 1 && axes[0 ] == 1 && src->Inputs ()[0 ].DynamicAxes ().size () == 1 )
0 commit comments