Skip to content

Commit 8ed7050

Browse files
committed
ONNX convolution with more than 1 non-feature/channel dimensions
1 parent 22a86bf commit 8ed7050

File tree

1 file changed

+67
-1
lines changed

1 file changed

+67
-1
lines changed

Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,12 @@ class CNTKToONNXHelper
819819
std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
820820
std::vector<ScanLoop>& scanLoops, int createLoopIndex);
821821

822+
static onnxruntime::Node* CreateConvolutionNode(const FunctionPtr& src,
823+
onnxruntime::Graph* graph,
824+
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
825+
std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
826+
std::vector<ScanLoop>& scanLoops, int createLoopIndex);
827+
822828
static onnxruntime::Node* CreateConvolutionBlockNode(const FunctionPtr& src,
823829
onnxruntime::Graph* graph,
824830
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
@@ -5511,9 +5517,12 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
55115517
{
55125518
return CreateONNXNodesForFlatten(src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
55135519
}
5514-
else if (cntkOpName == "Convolution" && src->IsBlock())
5520+
else if (cntkOpName == "Convolution")
55155521
{
5522+
if (src->IsBlock())
55165523
return CreateConvolutionBlockNode(src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
5524+
else
5525+
return CreateConvolutionNode(src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
55175526
}
55185527

55195528
//
@@ -8391,6 +8400,63 @@ onnxruntime::Node* ApplyActivationToSequenceConvolution(Node* convNode, const Fu
83918400
return activationNode;
83928401
}
83938402

8403+
onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionNode(const FunctionPtr& src,
8404+
onnxruntime::Graph* graph,
8405+
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
8406+
std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
8407+
std::vector<ScanLoop>& scanLoops, int createLoopIndex)
8408+
{
8409+
std::vector<onnxruntime::NodeArg *> inputs;
8410+
ProcessInputs(src, graph, functionNodes, variableNodes, inputs,
8411+
scanLoops, createLoopIndex);
8412+
8413+
std::vector<onnxruntime::NodeArg *> outputs;
8414+
ProcessOutputs(src, inputs, outputs, graph);
8415+
8416+
int featureRank = inputs[0]->Shape()->dim_size() - 2;
8417+
int inputRank = inputs[1]->Shape()->dim_size();
8418+
8419+
int extra_dims = inputRank - featureRank - 1;
8420+
onnxruntime::Node* convNode = nullptr;
8421+
if (extra_dims > 1)
8422+
{
8423+
// need to reshape input to fit ONNX spec (N, C, Features),
8424+
// applying Conv, then reshape back to the real output shape.
8425+
vector<int64_t> newDimInputToConv;
8426+
// collapse extra dims into one axis as N for ONNX Conv
8427+
newDimInputToConv.push_back(-1);
8428+
for (int i = extra_dims; i < inputRank; i++)
8429+
{
8430+
// copy channel and feature dimensions
8431+
if (!inputs[1]->Shape()->dim(i).has_dim_value())
8432+
LogicError("Convolution: feature dimensions need to have dim value.");
8433+
newDimInputToConv.push_back(inputs[1]->Shape()->dim(i).dim_value());
8434+
}
8435+
8436+
onnxruntime::Node* preReshape = AddReshapeNode(*inputs[1], newDimInputToConv, inputs[1]->Name() + "_reshaped_for_conv", graph);
8437+
std::vector<onnxruntime::NodeArg *> conv_inputs = inputs;
8438+
conv_inputs[1] = const_cast<NodeArg *>(preReshape->OutputDefs()[0]);
8439+
TypeProto convOutputTypeProto;
8440+
UpdateONNXType(src->Outputs()[0].GetDataType(), convOutputTypeProto);
8441+
8442+
NodeArg *convOutputArg = &graph->GetOrCreateNodeArg(outputs[0]->Name() + "_conv_of_reshaped", &convOutputTypeProto);
8443+
8444+
convNode = AddNode(src, graph, conv_inputs, { convOutputArg });
8445+
8446+
vector<int64_t> newDimOutputFromConv = ToINTS(*outputs[0]->TypeAsProto());
8447+
onnxruntime::Node* postReshape = AddReshapeNode(*convOutputArg, newDimOutputFromConv, outputs[0]->Name(), graph);
8448+
}
8449+
else
8450+
{
8451+
if (extra_dims != 1)
8452+
LogicError("Convolution op with incorrect input dumensions.");
8453+
convNode = AddNode(src, graph, inputs, outputs);
8454+
}
8455+
8456+
functionNodes.emplace(src, convNode);
8457+
return convNode;
8458+
}
8459+
83948460
onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionBlockNode(const FunctionPtr &src,
83958461
onnxruntime::Graph* graph,
83968462
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,

0 commit comments

Comments
 (0)