@@ -4047,43 +4047,27 @@ onnxruntime::Node* CNTKToONNXHelper::CreatePastFutureValueNode(const FunctionPtr
4047
4047
bool past = src->OpName () == L" PastValue" ;
4048
4048
size_t offset = src->Attributes ()[L" offset" ].Value <size_t >();
4049
4049
4050
- // 1. slice off first or last timeframe from input[0] -> input_sliced_node
4051
- // 2. expand initial value input[1] to the shape of input[0] without sequence axis ( the first axis) -> init_value_expanded
4052
- // 3. concat input_sliced_node with init_value_expanded or other way around -> Past(Future)Value node
4050
+ // 1. expand initial value input[1] to the shape of input[0] with sequence dim equal to offset -> init_value_expanded
4051
+ // 2. concat input with init_value_expanded or the other way around -> concatNode
4052
+ // 3. slice concatNode -> sliceNode and delayedInitialValues
4053
4053
4054
- // 1. slice input
4055
- int64_t sliceAxis = 0 , sliceStart, sliceEnd;
4056
- if (past)
4057
- {
4058
- sliceStart = 0 ;
4059
- sliceEnd = -offset;
4060
- }
4061
- else
4062
- {
4063
- sliceStart = offset;
4064
- sliceEnd = std::numeric_limits<int64_t >::max ();
4065
- }
4066
-
4067
- const std::string sliceOutputArgName = UniqueNodeNameStorage::GetUniqueInputNodeName (src->Inputs ()[0 ]) +
4068
- " _slice_" + UniqueNodeNameStorage::GetUniqueNodeName (src);
4069
- Node* sliceNode = AddSliceNode (*inputs[0 ], {sliceAxis}, {sliceStart}, {sliceEnd}, sliceOutputArgName, graph);
4070
-
4071
- // 2. expand init_value
4054
+ // 1. expand initial value
4072
4055
std::vector<int64_t > expandShape = ToINTS (*inputs[0 ]->TypeAsProto ());
4073
4056
// sequence dimension is offset for init_value
4074
4057
expandShape[0 ] = offset;
4075
4058
const std::string expandOutputName = UniqueNodeNameStorage::GetUniqueInputNodeName (src->Inputs ()[1 ]) + " _expand_" +
4076
4059
UniqueNodeNameStorage::GetUniqueNodeName (src);
4077
4060
Node* initValueExpand = AddExpandNode (*inputs[1 ], expandShape, expandOutputName, graph);
4078
4061
4079
- // 3. concat
4080
- std::string outputNodeArgName = UniqueNodeNameStorage::GetUniqueOutputNodeName (src->Outputs ()[0 ]);
4062
+ // 2. concat
4063
+ std::string outputName = UniqueNodeNameStorage::GetUniqueOutputNodeName (src->Outputs ()[0 ]);
4064
+ std::string concatOutputNodeArgName = outputName + " _concated_preslice" ;
4081
4065
4082
4066
Node* concatNode;
4083
4067
// there are cases where input is a scalar and step function turns output to a dim 1 vector [#, *]() -> [#, *](1)
4084
4068
// in this case, Concat shall output with original shape (#, *), followed by a reshape to get the matching shape [*, #](1).
4085
4069
bool pastFutureValueOutputDoesNotMatchInputByOne =
4086
- sliceNode-> OutputDefs () [0 ]->Shape ()->dim_size () == 2 &&
4070
+ inputs [0 ]->Shape ()->dim_size () == 2 &&
4087
4071
outputs[0 ]->Shape ()->dim_size () == 3 &&
4088
4072
outputs[0 ]->Shape ()->dim (2 ).dim_value () == 1 ;
4089
4073
std::vector<onnxruntime::NodeArg *> concatOutputs;
@@ -4098,22 +4082,26 @@ onnxruntime::Node* CNTKToONNXHelper::CreatePastFutureValueNode(const FunctionPtr
4098
4082
typeProto.mutable_tensor_type ()->mutable_shape ()->add_dim ()->set_dim_value (outputs[0 ]->Shape ()->dim (i).dim_value ());
4099
4083
}
4100
4084
UpdateONNXType (src->Output ().GetDataType (), typeProto);
4101
- NodeArg *arg = &graph->GetOrCreateNodeArg (outputNodeArgName + " _before_expand" , &typeProto);
4085
+ NodeArg *arg = &graph->GetOrCreateNodeArg (concatOutputNodeArgName + " _before_expand" , &typeProto);
4102
4086
concatOutputs.push_back (arg);
4103
4087
}
4104
4088
else
4105
- concatOutputs = outputs;
4106
-
4089
+ {
4090
+ TypeProto concatOutputTypeProto;
4091
+ UpdateONNXType (src->Outputs ()[0 ].GetDataType (), concatOutputTypeProto);
4092
+ onnxruntime::NodeArg *concatOutputArg = &graph->GetOrCreateNodeArg (concatOutputNodeArgName, &concatOutputTypeProto);
4093
+ concatOutputs.push_back (concatOutputArg);
4094
+ }
4107
4095
if (past)
4108
4096
{
4109
4097
concatNode = &graph->AddNode (UniqueNodeNameStorage::GetUniqueNodeName (src), " Concat" , " " ,
4110
- {const_cast <NodeArg*>(initValueExpand->OutputDefs ()[0 ]), const_cast <NodeArg*>(sliceNode-> OutputDefs () [0 ])},
4098
+ {const_cast <NodeArg*>(initValueExpand->OutputDefs ()[0 ]), const_cast <NodeArg*>(inputs [0 ])},
4111
4099
concatOutputs);
4112
4100
}
4113
4101
else
4114
4102
{
4115
4103
concatNode = &graph->AddNode (UniqueNodeNameStorage::GetUniqueNodeName (src), " Concat" , " " ,
4116
- {const_cast <NodeArg*>(sliceNode-> OutputDefs () [0 ]), const_cast <NodeArg*>(initValueExpand->OutputDefs ()[0 ])},
4104
+ {const_cast <NodeArg*>(inputs [0 ]), const_cast <NodeArg*>(initValueExpand->OutputDefs ()[0 ])},
4117
4105
concatOutputs);
4118
4106
}
4119
4107
// concat on sequence axis
@@ -4122,16 +4110,53 @@ onnxruntime::Node* CNTKToONNXHelper::CreatePastFutureValueNode(const FunctionPtr
4122
4110
if (pastFutureValueOutputDoesNotMatchInputByOne)
4123
4111
{
4124
4112
std::vector<int64_t > newShape = ToINTS (*outputs[0 ]->TypeAsProto ());
4125
- Node* expandNode = AddReshapeNode (const_cast <NodeArg &>(*concatNode->OutputDefs ()[0 ]),
4113
+ concatNode = AddReshapeNode (const_cast <NodeArg &>(*concatNode->OutputDefs ()[0 ]),
4126
4114
newShape, outputs[0 ]->Name (), graph);
4127
- functionNodes.emplace (src, expandNode);
4128
- return expandNode;
4115
+ }
4116
+
4117
+ int64_t sliceAxis = 0 , sliceStart, sliceEnd;
4118
+ if (past)
4119
+ {
4120
+ sliceStart = 0 ;
4121
+ sliceEnd = -offset;
4122
+ }
4123
+ else
4124
+ {
4125
+ sliceStart = offset;
4126
+ sliceEnd = std::numeric_limits<int64_t >::max ();
4127
+ }
4128
+
4129
+ Node* sliceNode = AddSliceNode (const_cast <NodeArg &>(*concatNode->OutputDefs ()[0 ]),
4130
+ { sliceAxis }, { sliceStart }, { sliceEnd }, outputName, graph);
4131
+ functionNodes.emplace (src, sliceNode);
4132
+
4133
+ // for segmented sequence to run in a consecutive way - delayed values are taken from previous runs,
4134
+ // we need to output rest of sliced frames
4135
+ if (past)
4136
+ {
4137
+ sliceStart = -offset;
4138
+ sliceEnd = std::numeric_limits<int64_t >::max ();
4129
4139
}
4130
4140
else
4131
4141
{
4132
- functionNodes. emplace (src, concatNode) ;
4133
- return concatNode ;
4142
+ sliceStart = 0 ;
4143
+ sliceEnd = offset ;
4134
4144
}
4145
+
4146
+ // get the delayed initial value shape right
4147
+ std::vector<int64_t > delayedInitialValuesOutputShape = ToINTS (*sliceNode->OutputDefs ()[0 ]->TypeAsProto ());
4148
+ delayedInitialValuesOutputShape[0 ] = offset;
4149
+ TypeProto delayedInitialValuesNodeArgType = ToTypeProto (delayedInitialValuesOutputShape, false );
4150
+ UpdateONNXType (src->Output ().GetDataType (), delayedInitialValuesNodeArgType);
4151
+
4152
+ std::string delayedInitialValuesNodeArgName = outputName + " _delayed_initial_values" ;
4153
+ NodeArg* delayedInitialValuesNodeArg = &graph->GetOrCreateNodeArg (
4154
+ delayedInitialValuesNodeArgName, &delayedInitialValuesNodeArgType);
4155
+
4156
+ Node* delayedInitialValues = AddSliceNode (const_cast <NodeArg &>(*concatNode->OutputDefs ()[0 ]),
4157
+ { sliceAxis }, { sliceStart }, { sliceEnd }, delayedInitialValuesNodeArgName, graph);
4158
+
4159
+ return sliceNode;
4135
4160
}
4136
4161
4137
4162
// the idea is to create an EyeLike node and slice the first slice for IsFirst, the last slice for IsLast op.
0 commit comments