@@ -36,20 +36,33 @@ def sort_pack_padded_sequence(input, lengths):
3636 tmp = pack_padded_sequence (input [indices ], sorted_lengths , batch_first = True )
3737 inv_ix = indices .clone ()
3838 inv_ix [indices ] = torch .arange (0 ,len (indices )).type_as (inv_ix )
39+ # inv_ix = torch.arange(0, len(indices)).type_as(indices)[indices]
3940 return tmp , inv_ix
4041
def pad_unsort_packed_sequence(input, inv_ix):
    """Pad a packed sequence and reorder its rows with ``inv_ix``.

    Args:
        input: Packed sequence accepted by ``pad_packed_sequence``.
        inv_ix: Index used to reorder the first dimension of the padded
            tensor (presumably the inverse of the sort permutation applied
            before packing; verify against the caller).

    Returns:
        The batch-first padded tensor, indexed by ``inv_ix``.
    """
    # pad_packed_sequence returns (padded, lengths); only the data is needed.
    padded = pad_packed_sequence(input, batch_first=True)[0]
    return padded[inv_ix]
4546
46- def pack_wrapper (module , att_feats , att_masks ):
def pack_wrapper_old(module, att_feats, att_masks):
    """Run ``module`` on ``att_feats``, packing by mask lengths when masked.

    Legacy variant that sorts the batch manually via
    ``sort_pack_padded_sequence`` and restores the order afterwards with
    ``pad_unsort_packed_sequence``.

    Args:
        module: Callable applied to the (packed or full) features.
        att_feats: Feature tensor passed to the packing helper or to
            ``module`` directly.
        att_masks: Mask tensor whose row sums are used as sequence lengths,
            or ``None`` to apply ``module`` to ``att_feats`` unchanged.

    Returns:
        The output of ``module``; on the masked path it is re-padded and
        restored to the original batch order.
    """
    # No mask: nothing to pack, apply the module to the full tensor.
    if att_masks is None:
        return module(att_feats)

    # Per-row valid length = number of mask entries set in that row.
    lengths = att_masks.data.long().sum(1)
    packed, inv_ix = sort_pack_padded_sequence(att_feats, lengths)

    # packed[0] is the flattened data, packed[1] the batch sizes; only the
    # data goes through the module, the batch sizes are reused as-is.
    transformed = PackedSequence(module(packed[0]), packed[1])
    return pad_unsort_packed_sequence(transformed, inv_ix)
5253
def pack_wrapper(module, att_feats, att_masks):
    """Apply ``module`` only to the unmasked positions of ``att_feats``.

    When ``att_masks`` is given, each row's length is the sum of its mask
    entries. The features are packed with those lengths, ``module`` runs on
    the packed data, and the result is padded back to a batch-first tensor.
    Sorting and unsorting of the batch is delegated to PyTorch via
    ``enforce_sorted=False``.

    Args:
        module: Callable applied to the packed data (masked path) or to
            ``att_feats`` itself (unmasked path).
        att_feats: Batch-first feature tensor.
            NOTE(review): assumed to be (batch, seq, dim) -- confirm.
        att_masks: Mask tensor whose row sums give the valid lengths, or
            ``None``. NOTE(review): packing keeps the first ``length``
            positions of each row, so the mask is assumed to mark a leading
            prefix -- confirm against the caller.

    Returns:
        The output of ``module``. On the masked path the second dimension
        equals the longest length in the batch and padded positions are zero.
    """
    if att_masks is None:
        return module(att_feats)

    # pack_padded_sequence requires tensor lengths to live on the CPU; a mask
    # on a CUDA device would otherwise raise at this call.
    lengths = att_masks.data.long().sum(1).cpu()
    packed = pack_padded_sequence(
        att_feats, lengths, enforce_sorted=False, batch_first=True
    )

    # Rebuild the packed sequence around the transformed data, carrying the
    # bookkeeping over so padding restores the original batch order.
    transformed = PackedSequence(
        data=module(packed.data),
        batch_sizes=packed.batch_sizes,
        sorted_indices=packed.sorted_indices,
        unsorted_indices=packed.unsorted_indices,
    )
    padded, _ = pad_packed_sequence(transformed, batch_first=True)
    return padded
65+
5366class AttModel (CaptionModel ):
5467 def __init__ (self , opt ):
5568 super (AttModel , self ).__init__ ()
@@ -117,12 +130,12 @@ def _prepare_feature(self, fc_feats, att_feats, att_masks):
117130
118131 # embed fc and att feats
119132 fc_feats = self .fc_embed (fc_feats )
120- att_feats = pack_wrapper (self .att_embed , att_feats , att_masks )
133+ att_feats_wrapped = pack_wrapper (self .att_embed , att_feats , att_masks )
121134
122135 # Project the attention feats first to reduce memory and computation consumption.
123- p_att_feats = self .ctx2att (att_feats )
136+ p_att_feats = self .ctx2att (att_feats_wrapped )
124137
125- return fc_feats , att_feats , p_att_feats , att_masks
138+ return fc_feats , att_feats_wrapped , p_att_feats , att_masks
126139
127140 def _forward (self , fc_feats , att_feats , seq , att_masks = None ):
128141 batch_size = fc_feats .size (0 )
@@ -353,13 +366,54 @@ def unflat_view(tensor):
353366 assert tensor .size (0 ) == batch_size * per_image_dim
354367 return tensor .view ((batch_size , per_image_dim ) + tensor .size ()[1 :])
355368
356- # p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
369+ prepped_non_neighbor_feats = self ._prepare_feature (fc_feats , att_feats , att_masks )
357370 # p_fc_feats_a, p_att_feats_a, pp_att_feats_a, p_att_masks_a =
358371 prepped_feats = self ._prepare_feature (
359372 flat_view (neighbor_batch ['fc_feats' ].to (device )),
360373 flat_view (neighbor_batch ['att_feats' ].to (device )),
361374 flat_view (neighbor_batch ['att_masks' ].to (device )) if neighbor_batch ['att_masks' ] is not None else None ,
362375 )
376+ prepped_feats_trunc = self ._prepare_feature (
377+ flat_view (neighbor_batch ['fc_feats' ].to (device )),
378+ flat_view (neighbor_batch ['att_feats' ][:,:,:prepped_non_neighbor_feats [1 ].size (1 )].to (device )),
379+ flat_view (neighbor_batch ['att_masks' ][:,:,:prepped_non_neighbor_feats [1 ].size (1 )].to (device )) if neighbor_batch ['att_masks' ] is not None else None ,
380+ )
381+
382+ clipped = self .clip_att (
383+ att_feats ,
384+ att_masks ,
385+ )
386+ clipped_wrapped = pack_wrapper (self .att_embed , * clipped )
387+
388+ clipped_trunc = self .clip_att (
389+ flat_view (neighbor_batch ['att_feats' ][:,:,:prepped_non_neighbor_feats [1 ].size (1 )].to (device )),
390+ flat_view (neighbor_batch ['att_masks' ][:,:,:prepped_non_neighbor_feats [1 ].size (1 )].to (device )) if neighbor_batch ['att_masks' ] is not None else None ,
391+ )
392+ clipped_trunc_wrapped = pack_wrapper (self .att_embed , * clipped_trunc )
393+
394+ embedded = self .att_embed (clipped [0 ])
395+ embedded_trunc = self .att_embed (clipped_trunc [0 ])
396+
397+ # this passes:
398+ # torch.allclose(prepped_non_neighbor_feats[0].view(10, -1), prepped_feats_trunc[0].view(10, 9, -1)[:,0])
399+ # this fails:
400+ # torch.allclose(prepped_non_neighbor_feats[1].view(10, -1), prepped_feats_trunc[1].view(10, 9, -1)[:,0])
401+
402+ # these both pass
403+ # torch.allclose(clipped[0].view(10, -1), clipped_trunc[0].view(10, 9, -1)[:,0])
404+ # torch.allclose(clipped[1].view(10, -1), clipped_trunc[1].view(10, 9, -1)[:,0])
405+
406+ # this fails:
407+ # torch.allclose(clipped_trunc_wrapped.view(10, 9, -1)[:,0], clipped_wrapped.view(10, -1))
408+
409+ # torch.allclose(clipped[0].view(10, -1), clipped_trunc[0].view(10, 9, -1)[:,0])
410+
411+ # torch.allclose(self.att_embed(clipped[0]).view(10, -1), self.att_embed(clipped_trunc[0]).view(10, 9, -1)[:,0])
412+ # Out[8]: False
413+ # torch.allclose(self.att_embed(clipped[0]).view(10, -1), self.att_embed(clipped_trunc[0]).view(10, 9, -1)[:,0], atol=1e-4)
414+ # Out[22]: True
415+ # torch.allclose(self.att_embed(clipped[0]), self.att_embed(clipped_trunc[0].view(10, 9, *clipped[0].size()[1:])[:,0]))
416+ # Out[9]: True
363417
364418 assert beam_size <= self .vocab_size + 1 , 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
365419 seq = fc_feats .new_full ((batch_size * sample_n , self .seq_length ), self .pad_idx , dtype = torch .long )
@@ -372,16 +426,21 @@ def unflat_view(tensor):
372426
373427 # first step, feed bos
374428 it = fc_feats .new_full ([batch_size * per_image_dim ], self .bos_idx , dtype = torch .long )
429+ it_non_neighbor = fc_feats .new_full ([batch_size ], self .bos_idx , dtype = torch .long )
375430 # batch_size*per_image_dim x V
376431 logprobs , state = self .get_logprobs_state (it , * (prepped_feats + (state ,)))
432+ logprobs_non_neighbor , state_non_neighbor = self .get_logprobs_state (it_non_neighbor , * (prepped_non_neighbor_feats + (self .init_hidden (batch_size ),)))
377433 # logprobs, state = self.get_logprobs_state(it, p_fc_feats_a, p_att_feats_a, pp_att_feats_a, p_att_masks_a, state)
378434
379- # (batch_size*beam_size) x per_image_view x ...
435+ # (batch_size*beam_size* per_image_view) x ...
380436 repeated_feats = (combine_first_two (ten ) for ten in utils .repeat_tensors (
381437 beam_size ,
382438 [unflat_view (t ) for t in prepped_feats ]
383439 # [p_fc_feats_a, p_att_feats_a, pp_att_feats_a, p_att_masks_a]
384440 ))
441+ repeated_non_neighbor_feats = utils .repeat_tensors (beam_size ,
442+ prepped_non_neighbor_feats )
443+ done_beams_non_neighbor = self .beam_search (state_non_neighbor , logprobs_non_neighbor , * repeated_non_neighbor_feats , opt = opt )
385444 self .done_beams = self .contrastive_beam_search (
386445 state , logprobs , * repeated_feats , opt = opt
387446 )
0 commit comments