@@ -7062,7 +7062,7 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
7062
7062
}
7063
7063
7064
7064
void CNTKToONNXHelper::SetReduceElementsAttributes (const FunctionPtr src, Node *node,
7065
- bool isSequenceReduceElement)
7065
+ bool isSequenceReduceElement)
7066
7066
{
7067
7067
std::wstring reductionOpName = src->OpName ();
7068
7068
if (reductionOpName == L" ReduceElements" )
@@ -7071,23 +7071,20 @@ void CNTKToONNXHelper::SetReduceElementsAttributes(const FunctionPtr src, Node *
7071
7071
}
7072
7072
7073
7073
//
7074
- int64_t keepReducedDimensions = 1 ;
7074
+ int64_t keepReducedDimensions = 0 ;
7075
7075
if (src->Attributes ().Contains (L" reductionKeepDimensions" ))
7076
7076
keepReducedDimensions = (int64_t )((bool )src->Attributes ()[L" reductionKeepDimensions" ].Value <bool >() ? 1 : 0 );
7077
- bool forceKeepReducedDimensions = false ;
7077
+ else if (src->Inputs ()[0 ].DynamicAxes ().size () == src->Outputs ()[0 ].DynamicAxes ().size () &&
7078
+ src->Inputs ()[0 ].Shape ().Rank () == src->Outputs ()[0 ].Shape ().Rank ())
7079
+ keepReducedDimensions = 1 ;
7080
+
7078
7081
7079
7082
std::vector<Axis> reductionAxes;
7080
7083
if (src->Attributes ().Contains (L" axisVec" ))
7081
7084
reductionAxes = AsVector<Axis>(src->Attributes ()[L" axisVec" ].Value <std::vector<DictionaryValue>>());
7082
7085
else if (src->Attributes ().Contains (L" axis" ))
7083
7086
reductionAxes.push_back ((Axis)(src->Attributes ()[L" axis" ].Value <Axis>()));
7084
7087
7085
- // Reduction on batch axis in CNTK removes the batch axis, even if keepdims is true.
7086
- // For ONNX export we need to make sure we export keepdims as 0 (false).
7087
- // The same applies for AllStaticAxes.
7088
- if (!forceKeepReducedDimensions &&
7089
- (reductionAxes.size () == 1 && (reductionAxes[0 ] == Axis::DefaultBatchAxis () || reductionAxes[0 ] == Axis::AllStaticAxes () || reductionAxes[0 ] == Axis::AllAxes ())))
7090
- keepReducedDimensions = 0 ;
7091
7088
std::vector<int64_t > axes = ConvertAxesToOnnx (reductionAxes, src->Inputs ()[0 ]);
7092
7089
7093
7090
if (isSequenceReduceElement && axes.size () == 1 && axes[0 ] == 1 && src->Inputs ()[0 ].DynamicAxes ().size () == 1 )
@@ -7100,8 +7097,8 @@ void CNTKToONNXHelper::SetReduceElementsAttributes(const FunctionPtr src, Node *
7100
7097
7101
7098
if (reductionOpName == L" Argmax" || reductionOpName == L" Argmin" )
7102
7099
node->AddAttribute (" axis" , axes[0 ]);
7103
- else
7104
- if (reductionAxes[0 ] != Axis::AllAxes ())
7100
+ else
7101
+ if (reductionAxes[0 ] != Axis::AllAxes ())
7105
7102
node->AddAttribute (" axes" , axes);
7106
7103
7107
7104
node->AddAttribute (" keepdims" , keepReducedDimensions);
0 commit comments