-
Notifications
You must be signed in to change notification settings - Fork 129
/
Copy pathUNet.cpp
26 lines (22 loc) · 1.11 KB
/
UNet.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#include "UNet.h"
/// Assembles the network from three sub-modules: an encoder created by
/// pretrained_resnet(), a UNetDecoder, and a SegmentationHead, and registers
/// each of them under the names "encoder", "decoder" and "segmentation_head".
///
/// @param _num_classes     Stored in num_classes and forwarded to SegmentationHead.
/// @param encoder_name     Must be a key of name2layers; per the error message the
///                         accepted names are resnet18, resnet34, resnet50, resnet101.
/// @param pretrained_path  Forwarded unchanged to pretrained_resnet().
/// @param encoder_depth    Forwarded unchanged to UNetDecoder.
/// @param decoder_channels Forwarded to UNetDecoder; its last entry is the first
///                         argument of SegmentationHead. Must not be empty.
/// @param use_attention    Forwarded unchanged to UNetDecoder.
/// @throws const char*     If encoder_name is not in name2layers, or if
///                         decoder_channels is empty. (String literals are thrown to
///                         stay consistent with the existing error style of this file.)
UNetImpl::UNetImpl(int _num_classes, std::string encoder_name, std::string pretrained_path, int encoder_depth,
                   std::vector<int> decoder_channels, bool use_attention){
    // Validate all arguments up front so nothing is constructed or loaded for a bad request.
    if(!name2layers.count(encoder_name)) throw "encoder name must in {resnet18, resnet34, resnet50, resnet101}";
    // Without this guard, taking the last element of an empty vector below would be
    // an out-of-bounds access (undefined behaviour).
    if(decoder_channels.empty()) throw "decoder_channels must not be empty";

    num_classes = _num_classes;

    // resnet18 / resnet34 keep BasicChannels; every other accepted encoder uses BottleChannels.
    std::vector<int> encoder_channels = BasicChannels;
    if(encoder_name!="resnet18" && encoder_name!="resnet34"){
        encoder_channels = BottleChannels;
    }

    // NOTE(review): 1000 is presumably the class count of the pretrained checkpoint - confirm
    // against pretrained_resnet().
    encoder = pretrained_resnet(1000, encoder_name, pretrained_path);
    // NOTE(review): meaning of the trailing `false` is defined by UNetDecoder, not visible here.
    decoder = UNetDecoder(encoder_channels, decoder_channels, encoder_depth, use_attention, false);
    // The head consumes the decoder's final channel count.
    // NOTE(review): the two trailing 1s are SegmentationHead parameters not visible here.
    segmentation_head = SegmentationHead(decoder_channels.back(), num_classes, 1, 1);

    register_module("encoder",encoder);
    register_module("decoder",decoder);
    register_module("segmentation_head",segmentation_head);
}
/// Runs one forward pass through the three registered sub-modules.
///
/// The input is handed to the encoder's features() call, the resulting list of
/// tensors goes to the decoder, and the decoder output goes to the segmentation head.
///
/// @param x  Input tensor passed to the encoder.
/// @return   The tensor produced by segmentation_head->forward().
torch::Tensor UNetImpl::forward(torch::Tensor x){
    // Stage 1: encoder produces a vector of tensors.
    std::vector<torch::Tensor> encoder_outputs = encoder->features(x);

    // Stage 2: decoder turns that vector into a single tensor.
    x = decoder->forward(encoder_outputs);

    // Stage 3: segmentation head produces the final result.
    return segmentation_head->forward(x);
}