
Commit dcbe46f

add debugging
1 parent 7fbbd4a commit dcbe46f

5 files changed: +79 additions, -16 deletions

captioning/data/dataloader.py

Lines changed: 2 additions & 2 deletions
@@ -231,8 +231,8 @@ def collate_func(self, batch, split):
         # #sort by att_feat length
         # fc_batch, att_batch, label_batch, gts, infos = \
         #     zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
-        fc_batch, att_batch, label_batch, gts, infos = \
-            zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
+        # fc_batch, att_batch, label_batch, gts, infos = \
+        #     zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
         data = {}
         data['fc_feats'] = np.stack(fc_batch)
         # merge att_feats
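Note on this change (the identical edit is made in pth_loader.py below): Python's sorted() is stable even with reverse=True, so sorting by the constant key lambda x: 0 never reorders anything; the call's only real effect was converting the batch lists to tuples via zip(*...). A minimal standalone sketch of that no-op, using made-up toy batches:

    import numpy as np

    fc_batch = [np.zeros(3), np.ones(3)]
    att_batch = [np.zeros((2, 4)), np.ones((5, 4))]

    # Constant key + reverse=True: sorted() is stable, so order is unchanged.
    out = sorted(zip(fc_batch, att_batch), key=lambda x: 0, reverse=True)
    assert [id(f) for f, _ in out] == [id(f) for f in fc_batch]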

captioning/data/pth_loader.py

Lines changed: 2 additions & 2 deletions
@@ -218,8 +218,8 @@ def collate_func(self, batch):
         # #sort by att_feat length
         # fc_batch, att_batch, label_batch, gts, infos = \
         #     zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
-        fc_batch, att_batch, label_batch, gts, infos = \
-            zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
+        # fc_batch, att_batch, label_batch, gts, infos = \
+        #     zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
         data = {}
         data['fc_feats'] = np.stack(fc_batch)
         # merge att_feats

captioning/models/AttModel.py

Lines changed: 65 additions & 6 deletions
@@ -36,20 +36,33 @@ def sort_pack_padded_sequence(input, lengths):
     tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
     inv_ix = indices.clone()
     inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
+    # inv_ix = torch.arange(0, len(indices)).type_as(indices)[indices]
     return tmp, inv_ix

 def pad_unsort_packed_sequence(input, inv_ix):
     tmp, _ = pad_packed_sequence(input, batch_first=True)
     tmp = tmp[inv_ix]
     return tmp

-def pack_wrapper(module, att_feats, att_masks):
+def pack_wrapper_old(module, att_feats, att_masks):
     if att_masks is not None:
         packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
         return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
     else:
         return module(att_feats)

+def pack_wrapper(module, att_feats, att_masks):
+    if att_masks is not None:
+        packed = pack_padded_sequence(att_feats, att_masks.data.long().sum(1), enforce_sorted=False, batch_first=True)
+        padded = pad_packed_sequence(PackedSequence(
+                data=module(packed.data), sorted_indices=packed.sorted_indices,
+                unsorted_indices=packed.unsorted_indices, batch_sizes=packed.batch_sizes
+            ),
+            batch_first=True)[0]
+        return padded
+    else:
+        return module(att_feats)
+
 class AttModel(CaptionModel):
     def __init__(self, opt):
         super(AttModel, self).__init__()
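The rewritten pack_wrapper above leans on pack_padded_sequence(..., enforce_sorted=False), which records sorted_indices/unsorted_indices on the returned PackedSequence so pad_packed_sequence can restore the original batch order itself; wrapping module(packed.data) in a fresh PackedSequence applies the module only to the non-padded timesteps. A minimal standalone sketch of the same pattern, with a toy linear layer standing in for self.att_embed and made-up shapes:

    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import (PackedSequence, pack_padded_sequence,
                                    pad_packed_sequence)

    embed = nn.Linear(4, 8)             # stand-in for self.att_embed
    att_feats = torch.randn(3, 5, 4)    # batch of 3, padded to length 5
    lengths = torch.tensor([5, 2, 4])   # unsorted, as from att_masks.sum(1)

    packed = pack_padded_sequence(att_feats, lengths, enforce_sorted=False,
                                  batch_first=True)
    out = pad_packed_sequence(PackedSequence(
            data=embed(packed.data), sorted_indices=packed.sorted_indices,
            unsorted_indices=packed.unsorted_indices,
            batch_sizes=packed.batch_sizes),
        batch_first=True)[0]
    print(out.shape)                    # torch.Size([3, 5, 8]); padded slots stay zero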
@@ -117,12 +130,12 @@ def _prepare_feature(self, fc_feats, att_feats, att_masks):

         # embed fc and att feats
         fc_feats = self.fc_embed(fc_feats)
-        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
+        att_feats_wrapped = pack_wrapper(self.att_embed, att_feats, att_masks)

         # Project the attention feats first to reduce memory and computation comsumptions.
-        p_att_feats = self.ctx2att(att_feats)
+        p_att_feats = self.ctx2att(att_feats_wrapped)

-        return fc_feats, att_feats, p_att_feats, att_masks
+        return fc_feats, att_feats_wrapped, p_att_feats, att_masks

     def _forward(self, fc_feats, att_feats, seq, att_masks=None):
         batch_size = fc_feats.size(0)
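The att_feats_wrapped rename above looks like a debugging aid: it keeps the embedded features from shadowing the raw att_feats argument inside _prepare_feature, so the pre-embedding tensor stays inspectable alongside the embedded one.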
@@ -353,13 +366,54 @@ def unflat_view(tensor):
             assert tensor.size(0) == batch_size * per_image_dim
             return tensor.view((batch_size, per_image_dim) + tensor.size()[1:])

-        # p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
+        prepped_non_neighbor_feats = self._prepare_feature(fc_feats, att_feats, att_masks)
         # p_fc_feats_a, p_att_feats_a, pp_att_feats_a, p_att_masks_a =
         prepped_feats = self._prepare_feature(
             flat_view(neighbor_batch['fc_feats'].to(device)),
             flat_view(neighbor_batch['att_feats'].to(device)),
             flat_view(neighbor_batch['att_masks'].to(device)) if neighbor_batch['att_masks'] is not None else None,
         )
+        prepped_feats_trunc = self._prepare_feature(
+            flat_view(neighbor_batch['fc_feats'].to(device)),
+            flat_view(neighbor_batch['att_feats'][:,:,:prepped_non_neighbor_feats[1].size(1)].to(device)),
+            flat_view(neighbor_batch['att_masks'][:,:,:prepped_non_neighbor_feats[1].size(1)].to(device)) if neighbor_batch['att_masks'] is not None else None,
+        )
+
+        clipped = self.clip_att(
+            att_feats,
+            att_masks,
+        )
+        clipped_wrapped = pack_wrapper(self.att_embed, *clipped)
+
+        clipped_trunc = self.clip_att(
+            flat_view(neighbor_batch['att_feats'][:,:,:prepped_non_neighbor_feats[1].size(1)].to(device)),
+            flat_view(neighbor_batch['att_masks'][:,:,:prepped_non_neighbor_feats[1].size(1)].to(device)) if neighbor_batch['att_masks'] is not None else None,
+        )
+        clipped_trunc_wrapped = pack_wrapper(self.att_embed, *clipped_trunc)
+
+        embedded = self.att_embed(clipped[0])
+        embedded_trunc = self.att_embed(clipped_trunc[0])
+
+        # this passes:
+        # torch.allclose(prepped_non_neighbor_feats[0].view(10, -1), prepped_feats_trunc[0].view(10, 9, -1)[:,0])
+        # this fails:
+        # torch.allclose(prepped_non_neighbor_feats[1].view(10, -1), prepped_feats_trunc[1].view(10, 9, -1)[:,0])
+
+        # these both pass:
+        # torch.allclose(clipped[0].view(10, -1), clipped_trunc[0].view(10, 9, -1)[:,0])
+        # torch.allclose(clipped[1].view(10, -1), clipped_trunc[1].view(10, 9, -1)[:,0])
+
+        # this fails:
+        # torch.allclose(clipped_trunc_wrapped.view(10, 9, -1)[:,0], clipped_wrapped.view(10, -1))
+
+        # torch.allclose(clipped[0].view(10, -1), clipped_trunc[0].view(10, 9, -1)[:,0])
+
+        # torch.allclose(self.att_embed(clipped[0]).view(10, -1), self.att_embed(clipped_trunc[0]).view(10, 9, -1)[:,0])
+        # Out[8]: False
+        # torch.allclose(self.att_embed(clipped[0]).view(10, -1), self.att_embed(clipped_trunc[0]).view(10, 9, -1)[:,0], atol=1e-4)
+        # Out[22]: True
+        # torch.allclose(self.att_embed(clipped[0]), self.att_embed(clipped_trunc[0].view(10, 9, *clipped[0].size()[1:])[:,0]))
+        # Out[9]: True

         assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
         seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
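A plausible reading of the passing/failing checks above (an assumption, not verified against this run): clipped and clipped_trunc hold identical rows, so the pre-embedding comparisons match exactly, but pushing them through self.att_embed in batches of different shapes can dispatch differently tiled matmul kernels, so corresponding outputs agree only up to float32 rounding noise. That would explain why allclose fails at the default tolerances yet passes at atol=1e-4. A self-contained sketch of the effect, with an assumed linear embedding and the same 10-image, 9-per-image layout as the comments above:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    embed = nn.Linear(2048, 512)               # stand-in for self.att_embed
    x = torch.randn(10, 36, 2048)              # one batch of att_feats
    x_rep = x.unsqueeze(1).expand(10, 9, 36, 2048).reshape(90, 36, 2048)

    a = embed(x)                               # (10, 36, 512)
    b = embed(x_rep).view(10, 9, 36, 512)[:, 0]

    print(torch.allclose(a, b))                # may be False: kernel-dependent rounding
    print(torch.allclose(a, b, atol=1e-4))     # passes with a looser tolerance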
@@ -372,16 +426,21 @@ def unflat_view(tensor):

         # first step, feed bos
         it = fc_feats.new_full([batch_size*per_image_dim], self.bos_idx, dtype=torch.long)
+        it_non_neighbor = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
         # batch_size*per_image_dim x V
         logprobs, state = self.get_logprobs_state(it, *(prepped_feats + (state,)))
+        logprobs_non_neighbor, state_non_neighbor = self.get_logprobs_state(it_non_neighbor, *(prepped_non_neighbor_feats + (self.init_hidden(batch_size),)))
         # logprobs, state = self.get_logprobs_state(it, p_fc_feats_a, p_att_feats_a, pp_att_feats_a, p_att_masks_a, state)

-        # (batch_size*beam_size) x per_image_view x ...
+        # (batch_size*beam_size*per_image_view) x ...
         repeated_feats = (combine_first_two(ten) for ten in utils.repeat_tensors(
             beam_size,
             [unflat_view(t) for t in prepped_feats]
             # [p_fc_feats_a, p_att_feats_a, pp_att_feats_a, p_att_masks_a]
         ))
+        repeated_non_neighbor_feats = utils.repeat_tensors(beam_size,
+            prepped_non_neighbor_feats)
+        done_beams_non_neighbor = self.beam_search(state_non_neighbor, logprobs_non_neighbor, *repeated_non_neighbor_feats, opt=opt)
         self.done_beams = self.contrastive_beam_search(
             state, logprobs, *repeated_feats, opt=opt
         )
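The added beam_search call above appears to run a plain, non-contrastive decode over just the target images alongside contrastive_beam_search, presumably so the two sets of beams can be compared while debugging; done_beams_non_neighbor is not consumed further in this commit.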

captioning/models/CaptionModel.py

Lines changed: 8 additions & 5 deletions
@@ -241,10 +241,10 @@ def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):

     def contrastive_beam_search(self, init_state, init_logprobs, *args, **kwargs):
         # init_logprobs: batch_size*(num_distractors+1) x (vocab_size+1)
-        # args: each tensor is (batch_size*beam_size) x per_image_view x ...
+        # args: each tensor is (batch_size*beam_size*per_image_dim) x ...

-        # state: tuple of tensors (2 x batch_size*beam_size*per_image_view x d)
-        # init_state: (2 x batch_size*1*per_image_view x d) [beam_size initially like 0; call this "this_beam_size" later]
+        # state: tuple of tensors (2 x batch_size*beam_size*per_image_dim x d)
+        # init_state: (2 x batch_size*1*per_image_dim x d) [beam_size initially like 0; call this "this_beam_size" later]

         # does one step of classical beam search

@@ -253,6 +253,8 @@ def contrastive_beam_search(self, init_state, init_logprobs, *args, **kwargs):
         temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
         beam_size = opt.get('beam_size', 10)
         diversity_lambda = opt.get('diversity_lambda', 0.5)
+        if (diversity_lambda != 0.0):
+            raise NotImplementedError()
         decoding_constraint = opt.get('decoding_constraint', 0)
         remove_bad_endings = opt.get('remove_bad_endings', 0)
         suppress_UNK = opt.get('suppress_UNK', 0)
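Worth noting about the guard above: diversity_lambda defaults to 0.5 here, so any caller that does not pass the option explicitly will now raise. A minimal sketch of that interaction, with a hypothetical opt dict:

    opt = {'beam_size': 5}                               # diversity_lambda not supplied
    diversity_lambda = opt.get('diversity_lambda', 0.5)  # silently falls back to 0.5
    if diversity_lambda != 0.0:
        raise NotImplementedError()                      # callers must pass 0.0 explicitly

This is presumably what the eval_literal.sh change below accommodates by passing --diversity_lambda 0.0.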
@@ -264,7 +266,7 @@ def contrastive_beam_search(self, init_state, init_logprobs, *args, **kwargs):

         per_image_dim = (num_distractors+1)
         batch_size = init_logprobs.shape[0] // per_image_dim
-        assert args[0].size(0) == batch_size * per_image_dim * beam_size
+        assert args[0].size(0) == batch_size * beam_size * per_image_dim

         V = init_logprobs.size(-1)

@@ -311,7 +313,8 @@ def contrastive_beam_search(self, init_state, init_logprobs, *args, **kwargs):

             logprobs = log_s1[:,0]
             if t == 0:
-                assert torch.allclose(logprobs[:,0], logprobs[:,1])
+                if beam_size > 1:
+                    assert torch.allclose(logprobs[:,0], logprobs[:,1])
             logprobs = logprobs[:,0]
             logprobs = logprobs.contiguous().view(-1, V)
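(A likely motivation for the beam_size > 1 guard, stated as an assumption: at t == 0 the sanity assert compares the first two beam columns, and with beam_size == 1 there is no second column, so the unguarded logprobs[:,1] would raise an IndexError.)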

eval_literal.sh

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,7 @@ model_dir="models/updown"
 split=$1
 beam_size=$2

-id="literal_${split}_bs-${beam_size}"
+id="literal_${split}_bs-${beam_size}_dl-0.0"

 python -u tools/eval.py \
     --id $id \
@@ -20,4 +20,5 @@ python -u tools/eval.py \
     --infos_path ${model_dir}/infos_tds-best.pkl \
     --language_eval 1 \
     --beam_size $beam_size \
+    --diversity_lambda 0.0 \
     | tee expts/${id}
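Taken together, the two edits pin diversity off and record that in the run id: invoked as eval_literal.sh <split> <beam_size>, the script now passes --diversity_lambda 0.0 (satisfying the NotImplementedError guard added in CaptionModel.py above) and logs to expts/literal_<split>_bs-<beam_size>_dl-0.0.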
