-
Notifications
You must be signed in to change notification settings - Fork 129
/
Copy pathSegDataset.h
32 lines (29 loc) · 1.1 KB
/
SegDataset.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#ifndef SEGDATASET_H
#define SEGDATASET_H
#include"util.h"
#include"fstream"
#include "json.hpp"
#include<opencv2/opencv.hpp>
// Free helper declared for use alongside SegDataset; defined out of view.
// @param json_path   path to a JSON file (NOTE(review): presumably a polygon/label
//                    annotation parsed with nlohmann::json from "json.hpp" — confirm).
// @param image_type  file extension of the companion image; defaults to ".jpg".
//                    NOTE(review): presumably used to locate the image next to the
//                    JSON file before displaying its mask — confirm in the .cpp.
void show_mask(std::string json_path, std::string image_type = ".jpg");
// Free helper that takes a JSON path and writes into `mask` (non-const reference,
// so `mask` is an in/out parameter). NOTE(review): presumably rasterizes the
// annotated shapes into `mask`; the expected size/type of `mask` is not visible
// here — confirm against the definition.
void draw_mask(std::string json_path, cv::Mat &mask);
/// @brief Segmentation dataset plugged into LibTorch's data API by deriving
///        from torch::data::Dataset<SegDataset> (CRTP).
///
/// Holds parallel lists of image and label entries together with the target
/// resize dimensions and per-class-name lookup tables. Method bodies other
/// than size() are defined out of view.
class SegDataset : public torch::data::Dataset<SegDataset>
{
public:
    /// @param resize_width   target width  (NOTE(review): presumably applied to images/masks in get()).
    /// @param resize_height  target height (NOTE(review): same as above — confirm).
    /// @param list_images    image entries (presumably file paths — confirm).
    /// @param list_labels    label entries, one per image (presumably JSON paths — confirm).
    /// @param name_list      class names (presumably used to fill name2index/name2color).
    SegDataset(int resize_width,
               int resize_height,
               std::vector<std::string> list_images,
               std::vector<std::string> list_labels,
               std::vector<std::string> name_list);

    /// @brief Produce the sample stored at position `index`.
    torch::data::Example<> get(size_t index) override;

    /// @brief Dataset length, taken as the number of label entries.
    torch::optional<size_t> size() const override
    {
        const size_t count = list_labels.size();
        return count;
    }

private:
    /// Member counterpart of the free draw_mask(): takes a JSON path and writes
    /// into `mask`. NOTE(review): presumably rasterizes annotations using the
    /// name2index / name2color tables — confirm in the .cpp.
    void draw_mask(std::string json_path, cv::Mat &mask);

    // Declaration order is kept as-is: it fixes member initialization order.
    int resize_width{512};
    int resize_height{512};
    std::vector<std::string> name_list{};
    // NOTE(review): presumably class name -> label id; confirm.
    std::map<std::string, int> name2index{};
    // NOTE(review): presumably class name -> colour used when drawing; confirm.
    std::map<std::string, cv::Scalar> name2color{};
    std::vector<std::string> list_images;
    std::vector<std::string> list_labels;
};
#endif // SEGDATASET_H