Skip to content

Commit b748b51

Browse files
committed
added kl divergences as loss.
1 parent 60fff4e commit b748b51

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

patchdata/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ def feat_std1_store(store, to_div, chunk=512, exclude=[None], cache=False):
617617

618618
print "No cache, writing to", name
619619
fstd1 = h5py.File(name, 'w')
620-
helpers.feat_div(store, fstd1, chunk=chunk, div=to_div)
620+
helpers.feat_div(store, fstd1, chunk=chunk, div=to_div, exclude=exclude)
621621
fstd1.attrs["Feat_std1"] = "from " + str(store.filename)
622622
return fstd1
623623

patchdata/evaluate.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,47 @@ def prod(v1, v2):
6868
return (n - np.sum(v1*v2))/2.
6969

7070

71+
def kl_g_g(v1, v2):
72+
"""
73+
Kl divergence between v1 and v2 gaussians: use jensen shannon for gaussians.
74+
"""
75+
_, d = v1.shape
76+
d = d//2
77+
v1_m = v1[:, :d]
78+
v1_lv = v1[:, d:] # log_var
79+
v2_m = v2[:, :d]
80+
v2_lv = v2[:, d:] # log_var
81+
82+
v1_v = np.exp(v1_lv) # var
83+
v2_v = np.exp(v2_lv) # var
84+
85+
# log(sig/sig) cancels if we add up
86+
# first part: kl(v1, v2), but without part that cancels!
87+
klv1v2 = (v1_v + (v1_m-v2_m)**2)/(2*v2_v + eps) - 0.5
88+
klv1v2 = klv1v2.sum()
89+
klv2v1 = (v2_v + (v1_m-v2_m)**2)/(2*v1_v + eps) - 0.5
90+
klv2v1 = klv2v1.sum()
91+
return klv1v2+klv2v1
92+
93+
94+
def kl_g_01(v1, v2):
95+
"""
96+
special case: v1 and v2 are handled by multiview model.
97+
In this case, v1 == v2 (the same latent representation).
98+
Only use v1 to compute the KL divergence to 0/1 Gaussian.
99+
"""
100+
_, d = v1.shape
101+
d = d//2
102+
mean = v1[:, :d]
103+
log_var = v1[:, d:]
104+
var = np.exp(log_var)
105+
106+
kl = (mean**2 + var - log_var - 1)
107+
kl = kl.sum()
108+
kl = kl/2.
109+
return kl
110+
111+
71112
_dist_table = {
72113
"L2": l2_dist
73114
,"L2H": l2_dist_half
@@ -77,6 +118,8 @@ def prod(v1, v2):
77118
,"CHI": chi_dist
78119
,"JSD": jsd
79120
,"PRD": prod
121+
,"KLG01": kl_g_01
122+
,"KL_G_G": kl_g_g
80123
}
81124

82125

0 commit comments

Comments
 (0)