7979
8080#======================== Loading data from numpy ========================#
8181a = np .array ([[1 ,2 ], [3 ,4 ]])
82- b = torch .from_numpy (a )
83- print (b )
84-
82+ b = torch .from_numpy (a ) # convert numpy array to torch tensor
83+ c = b .numpy () # convert torch tensor to numpy array
8584
8685
8786#===================== Implementing the input pipline =====================#
113112 # Your training code will be written here
114113 pass
115114
115+
116116#===================== Input pipline for custom dataset =====================#
117117# You should build custom dataset as below.
118118class CustomDataset (data .Dataset ):
@@ -123,14 +123,16 @@ def __init__(self):
123123 def __getitem__ (self , index ):
124124 # TODO
125125 # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
126- # 2. Return a data pair (e.g. image and label).
126+ # 2. Preprocess the data (e.g. torchvision.Transform).
127+ # 3. Return a data pair (e.g. image and label).
127128 pass
128129 def __len__ (self ):
129130 # You should change 0 to the total size of your dataset.
130131 return 0
131132
132133# Then, you can just use prebuilt torch's data loader.
133- train_loader = torch .utils .data .DataLoader (dataset = train_dataset ,
134+ custom_dataset = CustomDataset ()
135+ train_loader = torch .utils .data .DataLoader (dataset = custom_dataset ,
134136 batch_size = 100 ,
135137 shuffle = True ,
136138 num_workers = 2 )
@@ -153,6 +155,11 @@ def __len__(self):
153155print (outputs .size ()) # (10, 100)
154156
155157
156- #============================ Save and load model ============================#
158+ #============================ Save and load the model ============================#
159+ # Save and load the entire model.
157160torch .save (resnet , 'model.pkl' )
158- model = torch .load ('model.pkl' )
161+ model = torch .load ('model.pkl' )
162+
163+ # Save and load only the model parameters(recommended).
164+ torch .save (resnet .state_dict (), 'params.pkl' )
165+ resnet .load_state_dict (torch .load ('params.pkl' ))
0 commit comments