Skip to content

Commit 2bbe5f7

Browse files
committed
Add GetEndPoints of Reader.
We can get endpoints of a reader chain.
1 parent d3a4848 commit 2bbe5f7

File tree

4 files changed

+98
-0
lines changed

4 files changed

+98
-0
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
2727
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
2828

2929
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
30+
cc_test(reader_test SRCS reader_test.cc DEPS reader)
3031

3132
cc_test(variable_test SRCS variable_test.cc)
3233

paddle/fluid/framework/reader.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,40 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/reader.h"
16+
#include <deque>
1617

1718
namespace paddle {
1819
namespace framework {
1920
ReaderBase::~ReaderBase() {}
2021

22+
void ReaderBase::InsertDecoratedReader(ReaderBase *decorated_reader) {
23+
decorated_readers_.emplace(decorated_reader);
24+
}
25+
void ReaderBase::EraseDecoratedReader(ReaderBase *decorated_reader) {
26+
auto it = decorated_readers_.find(decorated_reader);
27+
PADDLE_ENFORCE(it != decorated_readers_.end(),
28+
"Cannot find the decorated reader to erase");
29+
decorated_readers_.erase(it);
30+
}
31+
std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
32+
std::unordered_set<ReaderBase *> result;
33+
std::deque<ReaderBase *> queue;
34+
queue.emplace_back(this);
35+
while (!queue.empty()) { // BFS search
36+
auto *front = queue.front();
37+
queue.pop_front();
38+
if (front->decorated_readers_.empty()) {
39+
result.emplace(front);
40+
} else {
41+
for (ReaderBase *reader : front->decorated_readers_) {
42+
queue.emplace_back(reader);
43+
}
44+
}
45+
}
46+
47+
return result;
48+
}
49+
2150
FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
2251

2352
void FileReader::ReadNext(std::vector<LoDTensor> *out) {
@@ -37,5 +66,6 @@ void FileReader::ReadNext(std::vector<LoDTensor> *out) {
3766
}
3867
}
3968
}
69+
DecoratedReader::~DecoratedReader() { reader_->EraseDecoratedReader(this); }
4070
} // namespace framework
4171
} // namespace paddle

paddle/fluid/framework/reader.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include <memory>
18+
#include <unordered_set>
1819
#include <vector>
1920

2021
#include "paddle/fluid/framework/ddim.h"
@@ -31,15 +32,31 @@ class ReaderBase {
3132
virtual void ReInit() = 0;
3233

3334
virtual ~ReaderBase();
35+
36+
// Return the readers which are the end of decorating chain. Basically
37+
// they are readers just before read op.
38+
std::unordered_set<ReaderBase*> GetEndPoints();
39+
40+
private:
41+
friend class DecoratedReader;
42+
// These methods can be only invoked inside DecoratedReader to record the
43+
// decorating chain.
44+
void InsertDecoratedReader(ReaderBase* decorated_reader);
45+
void EraseDecoratedReader(ReaderBase* decorated_reader);
46+
// A set of which readers that decorated this reader.
47+
std::unordered_set<ReaderBase*> decorated_readers_;
3448
};
3549

3650
class DecoratedReader : public ReaderBase {
3751
public:
3852
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
3953
: ReaderBase(), reader_(reader) {
4054
PADDLE_ENFORCE_NOT_NULL(reader_);
55+
reader_->InsertDecoratedReader(this);
4156
}
4257

58+
~DecoratedReader();
59+
4360
void ReInit() override { reader_->ReInit(); }
4461

4562
protected:
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/reader.h"
16+
#include <memory>
17+
#include "gtest/gtest.h"
18+
19+
class StubDecoratedReader : public paddle::framework::DecoratedReader {
20+
public:
21+
explicit StubDecoratedReader(const std::shared_ptr<ReaderBase> &reader)
22+
: DecoratedReader(reader) {}
23+
24+
void ReadNext(std::vector<paddle::framework::LoDTensor> *out) override {}
25+
};
26+
27+
class StubRootReader : public paddle::framework::ReaderBase {
28+
public:
29+
void ReadNext(std::vector<paddle::framework::LoDTensor> *out) override {}
30+
void ReInit() override {}
31+
};
32+
33+
TEST(READER, decorate_chain) {
34+
auto root = std::make_shared<StubRootReader>();
35+
auto end_point1 = StubDecoratedReader(root);
36+
auto end_point2 = StubDecoratedReader(root);
37+
38+
{
39+
auto endpoints = root->GetEndPoints();
40+
ASSERT_EQ(endpoints.size(), 2U);
41+
ASSERT_NE(endpoints.count(&end_point1), 0);
42+
ASSERT_NE(endpoints.count(&end_point2), 0);
43+
}
44+
45+
{
46+
auto end_point3 = StubDecoratedReader(root);
47+
ASSERT_EQ(root->GetEndPoints().size(), 3U);
48+
}
49+
{ ASSERT_EQ(root->GetEndPoints().size(), 2U); }
50+
}

0 commit comments

Comments
 (0)