@@ -10,9 +10,15 @@ def __init__(self, embed_size):
1010        """Load pretrained ResNet-152 and replace top fc layer.""" 
1111        super (EncoderCNN , self ).__init__ ()
1212        self .resnet  =  models .resnet152 (pretrained = True )
13-         self . resnet . fc   =   nn . Linear ( self . resnet . fc . in_features ,  embed_size ) 
13+         # For efficient memory usage. 
1414        for  param  in  self .resnet .parameters ():
1515            param .requires_grad  =  False 
16+         self .resnet .fc  =  nn .Linear (self .resnet .fc .in_features , embed_size )
17+         self .init_weights ()
18+     
19+     def  init_weights (self ):
20+         self .resnet .fc .weight .data .uniform_ (- 0.1 , 0.1 )
21+         self .resnet .fc .bias .data .fill_ (0 )
1622
1723    def  forward (self , images ):
1824        """Extract image feature vectors.""" 
@@ -30,6 +36,11 @@ def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
3036        self .embed  =  nn .Embedding (vocab_size , embed_size )
3137        self .lstm  =  nn .LSTM (embed_size , hidden_size , num_layers )
3238        self .linear  =  nn .Linear (hidden_size , vocab_size )
39+     
40+     def  init_weights (self ):
41+         self .embed .weight .data .uniform_ (- 0.1 , 0.1 )
42+         self .linear .weigth .data .uniform_ (- 0.1 , 0.1 )
43+         self .linear .bias .data .fill_ (0 )
3344
3445    def  forward (self , features , captions , lengths ):
3546        """Decode image feature vectors and generate caption.""" 
0 commit comments