Skip to content

Commit 22a86bf

Browse files
committed
onnx export: fix cases when scan input needs to be broadcasted
1 parent 0e172db commit 22a86bf

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ class CNTKToONNXHelper
388388
static onnxruntime::Node *AddReshapeNodeImpl(Graph *graph, const string &nodeName, NodeArg *input, NodeArg *output, const std::vector<int64_t>& newShape);
389389

390390
static NodeArg* GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src, const Variable &input, int inputIndex,
391-
onnx::TypeProto &inputArgType);
391+
onnx::TypeProto &inputArgType, const std::string& scanInputName = "");
392392

393393
static int BatchSizeOverride(const FunctionPtr src, const std::vector<onnxruntime::NodeArg*>& inputs,
394394
onnx::TypeProto& outputArgType);
@@ -5596,8 +5596,12 @@ bool TryMatchNodeArgType(onnx::TypeProto &argType, onnxruntime::Graph* graph, co
55965596
// z = C.sequence.input_variable((2,)) + C.input_variable((3,2))
55975597
//
55985598
// input is not necessarily an input to src. It may be obtained via skipping of batch/sequence pack/unpack wrappers.
5599-
NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src,
5600-
const Variable &input, int inputIndex, onnx::TypeProto &inputArgType)
5599+
//
5600+
// Special note for scan input cases:
5601+
// when input is also an input to a scan op, we have to keep its name so that
5602+
// the main graph and subgraph is well connected.
5603+
NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src,
5604+
const Variable &input, int inputIndex, onnx::TypeProto &inputArgType, const std::string& scanInputName)
56015605
{
56025606
// TODO: do we need to get blockroot if it is a block function?
56035607
if (!Operators::SupportBroadcast(src->OpName()))
@@ -5653,11 +5657,16 @@ NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* gr
56535657
//inputArgType.mutable_tensor_type()->set_elem_type(inputArgType.tensor_type().elem_type());
56545658
//UpdateONNXType(input.GetDataType(), inputArgType);
56555659
std::string inputNodeArgName;
5656-
auto inputItr = compositeOutputsMap.find(input);
5657-
if (inputItr != compositeOutputsMap.end())
5658-
inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(inputItr->second);
5660+
if (scanInputName != "")
5661+
inputNodeArgName = scanInputName;
56595662
else
5660-
inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(input);
5663+
{
5664+
auto inputItr = compositeOutputsMap.find(input);
5665+
if (inputItr != compositeOutputsMap.end())
5666+
inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(inputItr->second);
5667+
else
5668+
inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(input);
5669+
}
56615670

56625671
std::string outputArgName = UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(inputNodeArgName + "_reshaped_for_broadcast");
56635672
onnxruntime::NodeArg &nodeArg = graph->GetOrCreateNodeArg(inputNodeArgName, &inputArgType);
@@ -6095,11 +6104,18 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
60956104
}
60966105
}
60976106

6098-
onnxruntime::NodeArg *adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType);
6099-
6107+
onnxruntime::NodeArg *adjusted = nullptr;
61006108
if ((isOutputOfStepFunction && isInSubGraph) || isScanInputInSubgraph)
61016109
{
61026110
inputName = MakeScanInputOutputNodeArgName(inputName);
6111+
6112+
// in case of broadcast, we want the input name unchanged.
6113+
// The inserted reshape op is treated as being inside of the scan subgraph.
6114+
adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType, inputName);
6115+
}
6116+
else
6117+
{
6118+
adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType);
61036119
}
61046120

61056121
onnxruntime::NodeArg &inputArg = adjusted == nullptr ? graph->GetOrCreateNodeArg(inputName, &inputArgType) : *adjusted;

0 commit comments

Comments
 (0)