@@ -388,7 +388,7 @@ class CNTKToONNXHelper
     static onnxruntime::Node *AddReshapeNodeImpl(Graph *graph, const string &nodeName, NodeArg *input, NodeArg *output, const std::vector<int64_t>& newShape);

     static NodeArg* GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src, const Variable &input, int inputIndex,
-        onnx::TypeProto &inputArgType);
+        onnx::TypeProto &inputArgType, const std::string& scanInputName = "");

     static int BatchSizeOverride(const FunctionPtr src, const std::vector<onnxruntime::NodeArg*>& inputs,
         onnx::TypeProto& outputArgType);
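A side note on the signature change above: because the new `scanInputName` parameter defaults to an empty string, existing call sites keep compiling and behaving as before; only scan-aware callers need to pass a name. A minimal standalone sketch of that pattern (hypothetical names, not the actual exporter API):

    #include <iostream>
    #include <string>

    // Hypothetical stand-in for the exporter helper: the defaulted parameter
    // leaves pre-existing callers untouched; only scan-aware callers opt in.
    std::string PickNodeArgName(const std::string& uniqueName,
                                const std::string& scanInputName = "")
    {
        // A non-empty scanInputName wins, so the main graph and the scan
        // subgraph keep referring to the same NodeArg name.
        return scanInputName.empty() ? uniqueName : scanInputName;
    }

    int main()
    {
        std::cout << PickNodeArgName("Plus123_Output_0") << "\n";                 // old call sites
        std::cout << PickNodeArgName("Plus123_Output_0", "scan_input_0") << "\n"; // scan path
    }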
@@ -5596,8 +5596,12 @@ bool TryMatchNodeArgType(onnx::TypeProto &argType, onnxruntime::Graph* graph, co
 // z = C.sequence.input_variable((2,)) + C.input_variable((3,2))
 //
 // input is not necessarily an input to src. It may be obtained via skipping of batch/sequence pack/unpack wrappers.
-NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src,
-    const Variable &input, int inputIndex, onnx::TypeProto &inputArgType)
+//
+// Special note for scan input cases:
+// when the input is also an input to a scan op, we have to keep its name so that
+// the main graph and the subgraph are well connected.
+NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src,
+    const Variable &input, int inputIndex, onnx::TypeProto &inputArgType, const std::string& scanInputName)
 {
     // TODO: do we need to get blockroot if it is a block function?
     if (!Operators::SupportBroadcast(src->OpName()))
@@ -5653,11 +5657,16 @@ NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* gr
     // inputArgType.mutable_tensor_type()->set_elem_type(inputArgType.tensor_type().elem_type());
     // UpdateONNXType(input.GetDataType(), inputArgType);
     std::string inputNodeArgName;
-    auto inputItr = compositeOutputsMap.find(input);
-    if (inputItr != compositeOutputsMap.end())
-        inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(inputItr->second);
+    if (scanInputName != "")
+        inputNodeArgName = scanInputName;
     else
-        inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(input);
+    {
+        auto inputItr = compositeOutputsMap.find(input);
+        if (inputItr != compositeOutputsMap.end())
+            inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(inputItr->second);
+        else
+            inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(input);
+    }

     std::string outputArgName = UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(inputNodeArgName + "_reshaped_for_broadcast");
     onnxruntime::NodeArg &nodeArg = graph->GetOrCreateNodeArg(inputNodeArgName, &inputArgType);
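The hunk above establishes a three-level precedence for the input NodeArg name: an explicitly supplied scan input name, then the compositeOutputsMap remapping, then the variable's own unique name. A self-contained sketch of that fallback chain, with a plain string map standing in for the Variable-keyed compositeOutputsMap and for UniqueNodeNameStorage (both stand-ins are hypothetical simplifications):

    #include <map>
    #include <string>

    // Hypothetical simplification: the real map is keyed by CNTK Variables and
    // unique names come from UniqueNodeNameStorage; strings suffice to show the order.
    std::string ResolveInputNodeArgName(const std::string& scanInputName,
                                        const std::map<std::string, std::string>& compositeOutputs,
                                        const std::string& inputUid)
    {
        if (!scanInputName.empty())
            return scanInputName;              // 1. keep the scan subgraph's input name

        auto it = compositeOutputs.find(inputUid);
        if (it != compositeOutputs.end())
            return it->second;                 // 2. remapped composite output

        return inputUid;                       // 3. the variable's own unique name
    }

With this order, the reshape inserted for broadcasting is attached to the NodeArg the scan subgraph already consumes, rather than to a freshly generated unique name.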
@@ -6095,11 +6104,18 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
             }
         }

-        onnxruntime::NodeArg *adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType);
-
+        onnxruntime::NodeArg *adjusted = nullptr;
         if ((isOutputOfStepFunction && isInSubGraph) || isScanInputInSubgraph)
         {
             inputName = MakeScanInputOutputNodeArgName(inputName);
+
+            // In case of broadcast, we want the input name unchanged.
+            // The inserted reshape op is treated as being inside the scan subgraph.
+            adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType, inputName);
+        }
+        else
+        {
+            adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType);
         }

         onnxruntime::NodeArg &inputArg = adjusted == nullptr ? graph->GetOrCreateNodeArg(inputName, &inputArgType) : *adjusted;
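A compact restatement of the caller-side branch added above (the bool flag and helper names are hypothetical stand-ins; the real code passes graph objects and a CNTK Variable): inputs bound to a scan subgraph are renamed first and then forward that exact name, while all other inputs rely on the defaulted empty argument.

    #include <string>

    // Hypothetical stub standing in for GetInputAdjustmentForBroadcast: it returns
    // the name the inserted reshape output would get for a given input name.
    static std::string AdjustForBroadcast(const std::string& inputName,
                                          const std::string& scanInputName = "")
    {
        const std::string& base = scanInputName.empty() ? inputName : scanInputName;
        return base + "_reshaped_for_broadcast";
    }

    // Mirrors the ProcessInputs branch: scan-bound inputs keep their scan-adjusted
    // name so the reshape lives in the scan subgraph under the same arg name.
    std::string ProcessOneInput(std::string inputName, bool isScanBoundInSubgraph)
    {
        if (isScanBoundInSubgraph)
        {
            inputName += "_scan";  // stand-in for MakeScanInputOutputNodeArgName
            return AdjustForBroadcast(inputName, /*scanInputName=*/inputName);
        }
        return AdjustForBroadcast(inputName);
    }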