Skip to content

Commit 280ec14

Browse files
committed
support Pooling ops with Sequence axis
1 parent c7bc93f commit 280ec14

File tree

2 files changed

+85
-3
lines changed

2 files changed

+85
-3
lines changed

Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,12 @@ class CNTKToONNXHelper
845845
std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
846846
std::vector<ScanLoop>& scanLoops, int createLoopIndex);
847847

848+
static onnxruntime::Node* CreatePoolingNode(const FunctionPtr& src,
849+
onnxruntime::Graph* graph,
850+
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
851+
std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
852+
std::vector<ScanLoop>& scanLoops, int createLoopIndex);
853+
848854
static onnxruntime::Node* CreateConvolutionNode(const FunctionPtr& src,
849855
onnxruntime::Graph* graph,
850856
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
@@ -5396,6 +5402,9 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
53965402
std::string cntkOpName = ToLegacyString(ToUTF8(src->OpName()));
53975403
std::string onnxOpName = ToOPName(src);
53985404

5405+
if (src->OpName() == L"Pooling")
5406+
std::cout << "";
5407+
53995408
// TODO: uncomment this code once bidirectional LSTM is supported.
54005409
//if (cntkOpName == "Splice")
54015410
//{
@@ -5629,7 +5638,10 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
56295638
else
56305639
return CreateConvolutionNode(src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
56315640
}
5632-
5641+
else if (src->OpName() == L"Pooling" && src->Inputs()[0].HasBatchAxis() && src->Inputs()[0].HasSequenceAxis())
5642+
{
5643+
return CreatePoolingNode(src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
5644+
}
56335645
//
56345646
// If this block node equivalent to a primitive ONNX OP, then treated as such.
56355647
// And just maps its argument to ONNX node.
@@ -7087,7 +7099,7 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
70877099
auto lowerPad = ToINTS(src->Attributes()[L"lowerPad"].Value<NDShape>());
70887100
auto upperPad = ToINTS(src->Attributes()[L"upperPad"].Value<NDShape>());
70897101

7090-
if (IsPadValueValid(lowerPad, upperPad, autoPadding, ceilOutDim))
7102+
if (IsPadValueValid(lowerPad, upperPad, autoPadding, ceilOutDim) && !(src->Inputs()[0].HasBatchAxis() && src->Inputs()[0].HasSequenceAxis()))
70917103
{
70927104
if (ceilOutDim)
70937105
ValidatePadValueForCeilOutDim(lowerPad, upperPad, autoPadding, kernelShape, inputShape, strides,
@@ -8605,6 +8617,52 @@ onnxruntime::Node* ApplyActivationToSequenceConvolution(Node* convNode, const Fu
86058617
return activationNode;
86068618
}
86078619

8620+
// Exports a CNTK Pooling op whose input carries both a batch (#) and a sequence (*)
// dynamic axis. ONNX Max/AveragePool expects [N, C, H, W] (or [N, C, D1..Dn]) input,
// so the two dynamic axes are collapsed into N with a Reshape before pooling and
// expanded back to [#, *][C_out, ...] afterwards.
// Returns the pooling node itself (not the trailing Reshape) so that the caller's
// attribute copying applies kernel/stride/pad attributes to the Pool node.
onnxruntime::Node* CNTKToONNXHelper::CreatePoolingNode(const FunctionPtr& src,
    onnxruntime::Graph* graph,
    std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
    std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
    std::vector<ScanLoop>& scanLoops, int createLoopIndex)
{
    if (!src->Inputs()[0].HasBatchAxis() || !src->Inputs()[0].HasSequenceAxis())
        // Message fixed: this path handles the generic Pooling op (MAX and AVG), not just MaxPool.
        LogicError("CreatePoolingNode is only to handle Pooling ops with batch and sequence dimensions.");

    std::vector<onnxruntime::NodeArg *> inputs;
    ProcessInputs(src, graph, functionNodes, variableNodes, inputs,
                  scanLoops, createLoopIndex);

    std::vector<onnxruntime::NodeArg *> outputs;
    ProcessOutputs(src, inputs, outputs, graph);

    // Max/AveragePool takes input of shape [N, C, H, W] or [N, C, D1, D2, ..., Dn]. CNTK input needs to be reshaped to match it.
    // reshape [#, *][C, H, W] to [-1, C, H, W]
    // onnx Max/AveragePool
    // reshape [-1, C_out, H_out, W_out] to [#, *][C_out, H_out, W_out]
    vector<int64_t> newDimInputToPooling;
    // collapse extra dims into one axis as N for ONNX Max/AveragePool
    // (comment fixed: was "ONNX Conv", a copy-paste from the convolution path)
    newDimInputToPooling.push_back(-1);
    for (int i = 2; i < inputs[0]->Shape()->dim_size(); i++)
    {
        // copy C, H, W - the static feature dims must be concrete to build the Reshape target.
        if (!inputs[0]->Shape()->dim(i).has_dim_value())
            LogicError("Max/AveragePool: feature dimensions need to have dim value.");
        newDimInputToPooling.push_back(inputs[0]->Shape()->dim(i).dim_value());
    }

    onnxruntime::Node* preReshape = AddReshapeNode(*inputs[0], newDimInputToPooling, inputs[0]->Name() + "_reshaped_for_max_pool", graph);
    const std::vector<onnxruntime::NodeArg *> pooling_inputs({const_cast<NodeArg *>(preReshape->OutputDefs()[0])});
    TypeProto poolingOutputTypeProto;
    UpdateONNXType(src->Outputs()[0].GetDataType(), poolingOutputTypeProto);

    // Intermediate output of the pooling op, before the dynamic axes are restored.
    NodeArg *poolingOutputArg = &graph->GetOrCreateNodeArg(outputs[0]->Name() + "_pooling_of_reshaped", &poolingOutputTypeProto);

    onnxruntime::Node* poolingNode = AddNode(src, graph, pooling_inputs, { poolingOutputArg });

    // Restore [#, *][C_out, H_out, W_out]. The Reshape writes the graph's final output
    // arg (outputs[0]->Name()); adding it to the graph is the only effect needed, so the
    // returned node pointer is intentionally not stored (was an unused local).
    vector<int64_t> newDimOutputFromPooling = ToINTS(*outputs[0]->TypeAsProto());
    AddReshapeNode(*poolingOutputArg, newDimOutputFromPooling, outputs[0]->Name(), graph);

    return poolingNode;
}
8665+
86088666
onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionNode(const FunctionPtr& src,
86098667
onnxruntime::Graph* graph,
86108668
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,

bindings/python/cntk/tests/onnx_op_test.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,18 @@ def test_AveragePool(tmpdir, dtype, device_id):
423423

424424
verify_one_input(model, img, tmpdir, 'AveragePool_2', device)
425425

426+
#AveragePool
@pytest.mark.parametrize("dtype", DType_Config)
def test_AvergaePoolWithSequenceAxis(tmpdir, dtype, device_id):
    # NOTE(review): "Avergae" is a typo for "Average" — kept because the function
    # name is the public pytest id; renaming would change which test IDs exist.
    # Exercises ONNX export of AveragePool when the input has a sequence axis:
    # the exporter collapses the [#, *] dynamic axes into N before pooling.
    if device_id == -1 and dtype == np.float16:
        pytest.skip('Test is skipped on CPU with float16 data')
    device = cntk_device(device_id)
    with C.default_options(dtype=dtype):
        # Single 4x4 single-channel image; sequence input variable adds [#, *] axes.
        img = np.reshape(np.arange(16, dtype = dtype), [1, 4, 4])
        x = C.sequence.input_variable(img.shape)
        model = C.pooling(x, C.AVG_POOLING, (2,2), (2,2))
        # Data reshaped to [batch=1, seq=1, 1, 4, 4] to feed the sequence model.
        # bypass_load_into_cntk / resave=False: presumably the exported graph (with
        # collapsed dynamic axes) cannot round-trip back into CNTK — TODO confirm.
        verify_sequence_model(model, np.reshape(img, [1, 1, 1, 4, 4]), tmpdir, "AveragePoolWithSeq_1", resave = False, bypass_load_into_cntk = True)
437+
426438
#BatchNormalization
427439
def verify_BN(x, init_scale, init_bias, mean, var, epsilon, spatial, tmpdir, dtype):
428440
with C.default_options(dtype = dtype):
@@ -1311,7 +1323,7 @@ def test_Max(tmpdir, dtype):
13111323

13121324
#MaxPool
13131325
@pytest.mark.parametrize("dtype", DType_Config)
1314-
def test_MaxPool(tmpdir, dtype, device_id):
1326+
def test_MaxPool(tmpdir, dtype, device_id):
13151327
if device_id == -1 and dtype == np.float16:
13161328
pytest.skip('Test is skipped on CPU with float16 data')
13171329
device = cntk_device(device_id)
@@ -1327,6 +1339,18 @@ def test_MaxPool(tmpdir, dtype, device_id):
13271339
model = C.pooling(x, C.MAX_POOLING, (3, 3), (2, 2), auto_padding=[False, False, False], ceil_out_dim=True)
13281340
verify_one_input(model, img, tmpdir, 'MaxPool_2', device)
13291341

1342+
#MaxPool
@pytest.mark.parametrize("dtype", DType_Config)
def test_MaxPoolWithSequenceAxis(tmpdir, dtype, device_id):
    # Exercises ONNX export of MaxPool when the input has a sequence axis:
    # the exporter collapses the [#, *] dynamic axes into N before pooling.
    if device_id == -1 and dtype == np.float16:
        pytest.skip('Test is skipped on CPU with float16 data')
    device = cntk_device(device_id)
    with C.default_options(dtype=dtype):
        # Single 4x4 single-channel image; sequence input variable adds [#, *] axes.
        img = np.reshape(np.arange(16, dtype = dtype), [1, 4, 4])
        x = C.sequence.input_variable(img.shape)
        model = C.pooling(x, C.MAX_POOLING, (2,2), (2,2))
        # Data reshaped to [batch=1, seq=1, 1, 4, 4] to feed the sequence model.
        # bypass_load_into_cntk / resave=False: presumably the exported graph (with
        # collapsed dynamic axes) cannot round-trip back into CNTK — TODO confirm.
        verify_sequence_model(model, np.reshape(img, [1, 1, 1, 4, 4]), tmpdir, "MaxPoolWithSeq_1", resave = False, bypass_load_into_cntk = True)
1353+
13301354
#MaxRoiPool
13311355
@pytest.mark.parametrize("dtype", DType_Config)
13321356
def test_MaxRoiPool(tmpdir, dtype):

0 commit comments

Comments
 (0)