Skip to content

Commit 248ff78

Browse files
committed
onnx exporter for past/future value: fix for offset > sequence length cases; output delayed initial values
1 parent 8ed7050 commit 248ff78

File tree

1 file changed

+59
-34
lines changed

1 file changed

+59
-34
lines changed

Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4047,43 +4047,27 @@ onnxruntime::Node* CNTKToONNXHelper::CreatePastFutureValueNode(const FunctionPtr
40474047
bool past = src->OpName() == L"PastValue";
40484048
size_t offset = src->Attributes()[L"offset"].Value<size_t>();
40494049

4050-
// 1. slice off first or last timeframe from input[0] -> input_sliced_node
4051-
// 2. expand initial value input[1] to the shape of input[0] without sequence axis (the first axis) -> init_value_expanded
4052-
// 3. concat input_sliced_node with init_value_expanded or other way around -> Past(Future)Value node
4050+
// 1. expand initial value input[1] to the shape of input[0] with sequence dim equals offset -> init_value_expanded
4051+
// 2. concat input with init_value_expanded or the other way around -> concatNode
4052+
// 3. slice concatNode -> sliceNode and delayedInitialValues
40534053

4054-
// 1. slice input
4055-
int64_t sliceAxis = 0, sliceStart, sliceEnd;
4056-
if (past)
4057-
{
4058-
sliceStart = 0;
4059-
sliceEnd = -offset;
4060-
}
4061-
else
4062-
{
4063-
sliceStart = offset;
4064-
sliceEnd = std::numeric_limits<int64_t>::max();
4065-
}
4066-
4067-
const std::string sliceOutputArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(src->Inputs()[0]) +
4068-
"_slice_" + UniqueNodeNameStorage::GetUniqueNodeName(src);
4069-
Node* sliceNode = AddSliceNode(*inputs[0], {sliceAxis}, {sliceStart}, {sliceEnd}, sliceOutputArgName, graph);
4070-
4071-
// 2. expand init_value
4054+
// 1. expand initial value
40724055
std::vector<int64_t> expandShape = ToINTS(*inputs[0]->TypeAsProto());
40734056
// sequence dimension is offset for init_value
40744057
expandShape[0] = offset;
40754058
const std::string expandOutputName = UniqueNodeNameStorage::GetUniqueInputNodeName(src->Inputs()[1]) + "_expand_" +
40764059
UniqueNodeNameStorage::GetUniqueNodeName(src);
40774060
Node* initValueExpand = AddExpandNode(*inputs[1], expandShape, expandOutputName, graph);
40784061

4079-
// 3. concat
4080-
std::string outputNodeArgName = UniqueNodeNameStorage::GetUniqueOutputNodeName(src->Outputs()[0]);
4062+
// 2. concat
4063+
std::string outputName = UniqueNodeNameStorage::GetUniqueOutputNodeName(src->Outputs()[0]);
4064+
std::string concatOutputNodeArgName = outputName + "_concated_preslice";
40814065

40824066
Node* concatNode;
40834067
// there are cases where input is a scaler and step function turns output to a dim 1 vector [#, *]() -> [#, *](1)
40844068
// in this case, Concat shall output with original shape (#, *) and followed by an expand to get the matching shape [*, #](1).
40854069
bool pastFutureValueOutputDoesNotMatchInputByOne =
4086-
sliceNode->OutputDefs()[0]->Shape()->dim_size() == 2 &&
4070+
inputs[0]->Shape()->dim_size() == 2 &&
40874071
outputs[0]->Shape()->dim_size() == 3 &&
40884072
outputs[0]->Shape()->dim(2).dim_value() == 1;
40894073
std::vector<onnxruntime::NodeArg *> concatOutputs;
@@ -4098,22 +4082,26 @@ onnxruntime::Node* CNTKToONNXHelper::CreatePastFutureValueNode(const FunctionPtr
40984082
typeProto.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(outputs[0]->Shape()->dim(i).dim_value());
40994083
}
41004084
UpdateONNXType(src->Output().GetDataType(), typeProto);
4101-
NodeArg *arg = &graph->GetOrCreateNodeArg(outputNodeArgName + "_before_expand", &typeProto);
4085+
NodeArg *arg = &graph->GetOrCreateNodeArg(concatOutputNodeArgName + "_before_expand", &typeProto);
41024086
concatOutputs.push_back(arg);
41034087
}
41044088
else
4105-
concatOutputs = outputs;
4106-
4089+
{
4090+
TypeProto concatOutputTypeProto;
4091+
UpdateONNXType(src->Outputs()[0].GetDataType(), concatOutputTypeProto);
4092+
onnxruntime::NodeArg *concatOutputArg = &graph->GetOrCreateNodeArg(concatOutputNodeArgName, &concatOutputTypeProto);
4093+
concatOutputs.push_back(concatOutputArg);
4094+
}
41074095
if (past)
41084096
{
41094097
concatNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeName(src), "Concat", "",
4110-
{const_cast<NodeArg*>(initValueExpand->OutputDefs()[0]), const_cast<NodeArg*>(sliceNode->OutputDefs()[0])},
4098+
{const_cast<NodeArg*>(initValueExpand->OutputDefs()[0]), const_cast<NodeArg*>(inputs[0])},
41114099
concatOutputs);
41124100
}
41134101
else
41144102
{
41154103
concatNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeName(src), "Concat", "",
4116-
{const_cast<NodeArg*>(sliceNode->OutputDefs()[0]), const_cast<NodeArg*>(initValueExpand->OutputDefs()[0])},
4104+
{const_cast<NodeArg*>(inputs[0]), const_cast<NodeArg*>(initValueExpand->OutputDefs()[0])},
41174105
concatOutputs);
41184106
}
41194107
// concat on sequence axis
@@ -4122,16 +4110,53 @@ onnxruntime::Node* CNTKToONNXHelper::CreatePastFutureValueNode(const FunctionPtr
41224110
if (pastFutureValueOutputDoesNotMatchInputByOne)
41234111
{
41244112
std::vector<int64_t> newShape = ToINTS(*outputs[0]->TypeAsProto());
4125-
Node* expandNode = AddReshapeNode(const_cast<NodeArg &>(*concatNode->OutputDefs()[0]),
4113+
concatNode = AddReshapeNode(const_cast<NodeArg &>(*concatNode->OutputDefs()[0]),
41264114
newShape, outputs[0]->Name(), graph);
4127-
functionNodes.emplace(src, expandNode);
4128-
return expandNode;
4115+
}
4116+
4117+
int64_t sliceAxis = 0, sliceStart, sliceEnd;
4118+
if (past)
4119+
{
4120+
sliceStart = 0;
4121+
sliceEnd = -offset;
4122+
}
4123+
else
4124+
{
4125+
sliceStart = offset;
4126+
sliceEnd = std::numeric_limits<int64_t>::max();
4127+
}
4128+
4129+
Node* sliceNode = AddSliceNode(const_cast<NodeArg &>(*concatNode->OutputDefs()[0]),
4130+
{ sliceAxis }, { sliceStart }, { sliceEnd }, outputName, graph);
4131+
functionNodes.emplace(src, sliceNode);
4132+
4133+
// for segmented sequence to run in a consecutive way - delayed values are taken from previous runs,
4134+
// we need to output rest of sliced frames
4135+
if (past)
4136+
{
4137+
sliceStart = -offset;
4138+
sliceEnd = std::numeric_limits<int64_t>::max();
41294139
}
41304140
else
41314141
{
4132-
functionNodes.emplace(src, concatNode);
4133-
return concatNode;
4142+
sliceStart = 0;
4143+
sliceEnd = offset;
41344144
}
4145+
4146+
// get the delayed initial value shape right
4147+
std::vector<int64_t> delayedInitialValuesOutputShape = ToINTS(*sliceNode->OutputDefs()[0]->TypeAsProto());
4148+
delayedInitialValuesOutputShape[0] = offset;
4149+
TypeProto delayedInitialValuesNodeArgType = ToTypeProto(delayedInitialValuesOutputShape, false);
4150+
UpdateONNXType(src->Output().GetDataType(), delayedInitialValuesNodeArgType);
4151+
4152+
std::string delayedInitialValuesNodeArgName = outputName + "_delayed_initial_values";
4153+
NodeArg* delayedInitialValuesNodeArg = &graph->GetOrCreateNodeArg(
4154+
delayedInitialValuesNodeArgName, &delayedInitialValuesNodeArgType);
4155+
4156+
Node* delayedInitialValues = AddSliceNode(const_cast<NodeArg &>(*concatNode->OutputDefs()[0]),
4157+
{ sliceAxis }, { sliceStart }, { sliceEnd }, delayedInitialValuesNodeArgName, graph);
4158+
4159+
return sliceNode;
41354160
}
41364161

41374162
// the idea is to create an EyeLike node and slice the first slice for IsFirst, the last slice for IsLast op.

0 commit comments

Comments
 (0)