Skip to content

Commit 5fd9592

Browse files
author
Jaden Travnik
committed
Use torch.max instead of torch.nn.MaxPool1d
According to the original paper, there should be no restriction on the number of points within a point cloud. This PR updates PointNet to allow a variable number of points per point cloud, which may also differ from mini-batch to mini-batch. If a dataset contains point clouds with different numbers of points, one way to use this change is to upsample them to a common size (similar to how images of different sizes are padded with 0s). However, one wants to minimize the number of added points because, unlike with image data, adding 0s to a point cloud changes the structure of the data. Instead, one should duplicate the fewest points necessary so that every sample in a mini-batch has the same number of points, while different mini-batches may have different numbers of points per sample. To do this, sort the dataset by the number of points in each point cloud, then group consecutive point clouds into mini-batches of the desired size. For example, one mini-batch may contain point clouds of sizes [901, 905, 905, ..., 945]. To upsample every point cloud P_i in mini-batch_j, randomly duplicate K points of P_i, where N is the number of points in P_i and K is the difference between the maximum point cloud size in mini-batch_j and N (K = max(P_i.size() for P_i in mini-batch_j) - N). The example mini-batch above will then have sizes [945, 945, 945, ..., 945]; the first point cloud, with 901 points, has 945 - 901 = 44 points duplicated, and so on. Now that every point cloud within a mini-batch has the same number of points, one can train the network by randomly sampling these mini-batches.
1 parent 75dd61c commit 5fd9592

File tree

2 files changed

+14
-19
lines changed

2 files changed

+14
-19
lines changed

datasets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import json
99
import codecs
1010
import numpy as np
11-
import progressbar
1211
import sys
1312
import torchvision.transforms as transforms
1413
import argparse

pointnet.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,11 @@
1919

2020

2121
class STN3d(nn.Module):
22-
def __init__(self, num_points = 2500):
22+
def __init__(self):
2323
super(STN3d, self).__init__()
24-
self.num_points = num_points
2524
self.conv1 = torch.nn.Conv1d(3, 64, 1)
2625
self.conv2 = torch.nn.Conv1d(64, 128, 1)
2726
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
28-
self.mp1 = torch.nn.MaxPool1d(num_points)
2927
self.fc1 = nn.Linear(1024, 512)
3028
self.fc2 = nn.Linear(512, 256)
3129
self.fc3 = nn.Linear(256, 9)
@@ -43,7 +41,7 @@ def forward(self, x):
4341
x = F.relu(self.bn1(self.conv1(x)))
4442
x = F.relu(self.bn2(self.conv2(x)))
4543
x = F.relu(self.bn3(self.conv3(x)))
46-
x = self.mp1(x)
44+
x = torch.max(x, 2, keepdim=True)[0]
4745
x = x.view(-1, 1024)
4846

4947
x = F.relu(self.bn4(self.fc1(x)))
@@ -59,20 +57,19 @@ def forward(self, x):
5957

6058

6159
class PointNetfeat(nn.Module):
62-
def __init__(self, num_points = 2500, global_feat = True):
60+
def __init__(self, global_feat = True):
6361
super(PointNetfeat, self).__init__()
64-
self.stn = STN3d(num_points = num_points)
62+
self.stn = STN3d()
6563
self.conv1 = torch.nn.Conv1d(3, 64, 1)
6664
self.conv2 = torch.nn.Conv1d(64, 128, 1)
6765
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
6866
self.bn1 = nn.BatchNorm1d(64)
6967
self.bn2 = nn.BatchNorm1d(128)
7068
self.bn3 = nn.BatchNorm1d(1024)
71-
self.mp1 = torch.nn.MaxPool1d(num_points)
72-
self.num_points = num_points
7369
self.global_feat = global_feat
7470
def forward(self, x):
7571
batchsize = x.size()[0]
72+
n_pts = x.size()[2]
7673
trans = self.stn(x)
7774
x = x.transpose(2,1)
7875
x = torch.bmm(x, trans)
@@ -81,19 +78,18 @@ def forward(self, x):
8178
pointfeat = x
8279
x = F.relu(self.bn2(self.conv2(x)))
8380
x = self.bn3(self.conv3(x))
84-
x = self.mp1(x)
81+
x = torch.max(x, 2, keepdim=True)[0]
8582
x = x.view(-1, 1024)
8683
if self.global_feat:
8784
return x, trans
8885
else:
89-
x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
86+
x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
9087
return torch.cat([x, pointfeat], 1), trans
9188

9289
class PointNetCls(nn.Module):
93-
def __init__(self, num_points = 2500, k = 2):
90+
def __init__(self, k = 2):
9491
super(PointNetCls, self).__init__()
95-
self.num_points = num_points
96-
self.feat = PointNetfeat(num_points, global_feat=True)
92+
self.feat = PointNetfeat(global_feat=True)
9793
self.fc1 = nn.Linear(1024, 512)
9894
self.fc2 = nn.Linear(512, 256)
9995
self.fc3 = nn.Linear(256, k)
@@ -105,14 +101,13 @@ def forward(self, x):
105101
x = F.relu(self.bn1(self.fc1(x)))
106102
x = F.relu(self.bn2(self.fc2(x)))
107103
x = self.fc3(x)
108-
return F.log_softmax(x, dim=-1), trans
104+
return F.log_softmax(x, dim=0), trans
109105

110106
class PointNetDenseCls(nn.Module):
111-
def __init__(self, num_points = 2500, k = 2):
107+
def __init__(self, k = 2):
112108
super(PointNetDenseCls, self).__init__()
113-
self.num_points = num_points
114109
self.k = k
115-
self.feat = PointNetfeat(num_points, global_feat=False)
110+
self.feat = PointNetfeat(global_feat=False)
116111
self.conv1 = torch.nn.Conv1d(1088, 512, 1)
117112
self.conv2 = torch.nn.Conv1d(512, 256, 1)
118113
self.conv3 = torch.nn.Conv1d(256, 128, 1)
@@ -123,14 +118,15 @@ def __init__(self, num_points = 2500, k = 2):
123118

124119
def forward(self, x):
125120
batchsize = x.size()[0]
121+
n_pts = x.size()[2]
126122
x, trans = self.feat(x)
127123
x = F.relu(self.bn1(self.conv1(x)))
128124
x = F.relu(self.bn2(self.conv2(x)))
129125
x = F.relu(self.bn3(self.conv3(x)))
130126
x = self.conv4(x)
131127
x = x.transpose(2,1).contiguous()
132128
x = F.log_softmax(x.view(-1,self.k), dim=-1)
133-
x = x.view(batchsize, self.num_points, self.k)
129+
x = x.view(batchsize, n_pts, self.k)
134130
return x, trans
135131

136132

0 commit comments

Comments
 (0)