Skip to content

Commit a29ce54

Browse files
committed
some examples are edited
1 parent a7f2f7f commit a29ce54

File tree

1 file changed

+14
-7
lines changed
  • tutorials/00 - PyTorch Basics

1 file changed

+14
-7
lines changed

tutorials/00 - PyTorch Basics/main.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,8 @@
7979

8080
#======================== Loading data from numpy ========================#
8181
a = np.array([[1,2], [3,4]])
82-
b = torch.from_numpy(a)
83-
print (b)
84-
82+
b = torch.from_numpy(a) # convert numpy array to torch tensor
83+
c = b.numpy() # convert torch tensor to numpy array
8584

8685

8786
#===================== Implementing the input pipline =====================#
@@ -113,6 +112,7 @@
113112
# Your training code will be written here
114113
pass
115114

115+
116116
#===================== Input pipline for custom dataset =====================#
117117
# You should build custom dataset as below.
118118
class CustomDataset(data.Dataset):
@@ -123,14 +123,16 @@ def __init__(self):
123123
def __getitem__(self, index):
124124
# TODO
125125
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
126-
# 2. Return a data pair (e.g. image and label).
126+
# 2. Preprocess the data (e.g. torchvision.Transform).
127+
# 3. Return a data pair (e.g. image and label).
127128
pass
128129
def __len__(self):
129130
# You should change 0 to the total size of your dataset.
130131
return 0
131132

132133
# Then, you can just use prebuilt torch's data loader.
133-
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
134+
custom_dataset = CustomDataset()
135+
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
134136
batch_size=100,
135137
shuffle=True,
136138
num_workers=2)
@@ -153,6 +155,11 @@ def __len__(self):
153155
print (outputs.size()) # (10, 100)
154156

155157

156-
#============================ Save and load model ============================#
158+
#============================ Save and load the model ============================#
159+
# Save and load the entire model.
157160
torch.save(resnet, 'model.pkl')
158-
model = torch.load('model.pkl')
161+
model = torch.load('model.pkl')
162+
163+
# Save and load only the model parameters(recommended).
164+
torch.save(resnet.state_dict(), 'params.pkl')
165+
resnet.load_state_dict(torch.load('params.pkl'))

0 commit comments

Comments
 (0)