@@ -10,9 +10,15 @@ def __init__(self, embed_size):
1010 """Load pretrained ResNet-152 and replace top fc layer."""
1111 super (EncoderCNN , self ).__init__ ()
1212 self .resnet = models .resnet152 (pretrained = True )
13- self . resnet . fc = nn . Linear ( self . resnet . fc . in_features , embed_size )
13+ # For efficient memory usage.
1414 for param in self .resnet .parameters ():
1515 param .requires_grad = False
16+ self .resnet .fc = nn .Linear (self .resnet .fc .in_features , embed_size )
17+ self .init_weights ()
18+
def init_weights(self):
    """Initialize the parameters of the ResNet's ``fc`` layer in place.

    The weight matrix is overwritten with samples drawn uniformly from
    [-0.1, 0.1] and the bias vector is filled with zeros.
    """
    fc_layer = self.resnet.fc
    fc_layer.weight.data.uniform_(-0.1, 0.1)
    fc_layer.bias.data.fill_(0)
1622
1723 def forward (self , images ):
1824 """Extract image feature vectors."""
@@ -30,6 +36,11 @@ def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
3036 self .embed = nn .Embedding (vocab_size , embed_size )
3137 self .lstm = nn .LSTM (embed_size , hidden_size , num_layers )
3238 self .linear = nn .Linear (hidden_size , vocab_size )
39+
def init_weights(self):
    """Initialize the embedding and output-projection parameters in place.

    The embedding table and the weight of ``self.linear`` are overwritten
    with samples drawn uniformly from [-0.1, 0.1]; the bias of
    ``self.linear`` is filled with zeros. ``self.lstm`` is not touched.
    """
    self.embed.weight.data.uniform_(-0.1, 0.1)
    # The attribute is ``weight``; a misspelled name raises AttributeError
    # on nn.Linear and leaves the layer uninitialized.
    self.linear.weight.data.uniform_(-0.1, 0.1)
    self.linear.bias.data.fill_(0)
3344
3445 def forward (self , features , captions , lengths ):
3546 """Decode image feature vectors and generate caption."""
0 commit comments