@@ -819,6 +819,12 @@ class CNTKToONNXHelper
819
819
std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
820
820
std::vector<ScanLoop>& scanLoops, int createLoopIndex);
821
821
822
+ static onnxruntime::Node* CreateConvolutionNode (const FunctionPtr& src,
823
+ onnxruntime::Graph* graph,
824
+ std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
825
+ std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
826
+ std::vector<ScanLoop>& scanLoops, int createLoopIndex);
827
+
822
828
static onnxruntime::Node* CreateConvolutionBlockNode (const FunctionPtr& src,
823
829
onnxruntime::Graph* graph,
824
830
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
@@ -5511,9 +5517,12 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
5511
5517
{
5512
5518
return CreateONNXNodesForFlatten (src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
5513
5519
}
5514
- else if (cntkOpName == " Convolution" && src-> IsBlock () )
5520
+ else if (cntkOpName == " Convolution" )
5515
5521
{
5522
+ if (src->IsBlock ())
5516
5523
return CreateConvolutionBlockNode (src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
5524
+ else
5525
+ return CreateConvolutionNode (src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
5517
5526
}
5518
5527
5519
5528
//
@@ -8391,6 +8400,63 @@ onnxruntime::Node* ApplyActivationToSequenceConvolution(Node* convNode, const Fu
8391
8400
return activationNode;
8392
8401
}
8393
8402
8403
+ onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionNode (const FunctionPtr& src,
8404
+ onnxruntime::Graph* graph,
8405
+ std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
8406
+ std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
8407
+ std::vector<ScanLoop>& scanLoops, int createLoopIndex)
8408
+ {
8409
+ std::vector<onnxruntime::NodeArg *> inputs;
8410
+ ProcessInputs (src, graph, functionNodes, variableNodes, inputs,
8411
+ scanLoops, createLoopIndex);
8412
+
8413
+ std::vector<onnxruntime::NodeArg *> outputs;
8414
+ ProcessOutputs (src, inputs, outputs, graph);
8415
+
8416
+ int featureRank = inputs[0 ]->Shape ()->dim_size () - 2 ;
8417
+ int inputRank = inputs[1 ]->Shape ()->dim_size ();
8418
+
8419
+ int extra_dims = inputRank - featureRank - 1 ;
8420
+ onnxruntime::Node* convNode = nullptr ;
8421
+ if (extra_dims > 1 )
8422
+ {
8423
+ // need to reshape input to fit ONNX spec (N, C, Features),
8424
+ // applying Conv, then reshape back to the real output shape.
8425
+ vector<int64_t > newDimInputToConv;
8426
+ // collapse extra dims into one axis as N for ONNX Conv
8427
+ newDimInputToConv.push_back (-1 );
8428
+ for (int i = extra_dims; i < inputRank; i++)
8429
+ {
8430
+ // copy channel and feature dimensions
8431
+ if (!inputs[1 ]->Shape ()->dim (i).has_dim_value ())
8432
+ LogicError (" Convolution: feature dimensions need to have dim value." );
8433
+ newDimInputToConv.push_back (inputs[1 ]->Shape ()->dim (i).dim_value ());
8434
+ }
8435
+
8436
+ onnxruntime::Node* preReshape = AddReshapeNode (*inputs[1 ], newDimInputToConv, inputs[1 ]->Name () + " _reshaped_for_conv" , graph);
8437
+ std::vector<onnxruntime::NodeArg *> conv_inputs = inputs;
8438
+ conv_inputs[1 ] = const_cast <NodeArg *>(preReshape->OutputDefs ()[0 ]);
8439
+ TypeProto convOutputTypeProto;
8440
+ UpdateONNXType (src->Outputs ()[0 ].GetDataType (), convOutputTypeProto);
8441
+
8442
+ NodeArg *convOutputArg = &graph->GetOrCreateNodeArg (outputs[0 ]->Name () + " _conv_of_reshaped" , &convOutputTypeProto);
8443
+
8444
+ convNode = AddNode (src, graph, conv_inputs, { convOutputArg });
8445
+
8446
+ vector<int64_t > newDimOutputFromConv = ToINTS (*outputs[0 ]->TypeAsProto ());
8447
+ onnxruntime::Node* postReshape = AddReshapeNode (*convOutputArg, newDimOutputFromConv, outputs[0 ]->Name (), graph);
8448
+ }
8449
+ else
8450
+ {
8451
+ if (extra_dims != 1 )
8452
+ LogicError (" Convolution op with incorrect input dumensions." );
8453
+ convNode = AddNode (src, graph, inputs, outputs);
8454
+ }
8455
+
8456
+ functionNodes.emplace (src, convNode);
8457
+ return convNode;
8458
+ }
8459
+
8394
8460
onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionBlockNode (const FunctionPtr &src,
8395
8461
onnxruntime::Graph* graph,
8396
8462
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
0 commit comments