Skip to content

Commit b833f3e

Browse files
authored
Add files via upload
1 parent 1669121 commit b833f3e

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed

TDRC.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import numpy as np
2+
import tensorly as tl
3+
4+
5+
def TDRC(X, S_d,S_m, r=4, alpha=0.125, beta=0.25, lam=0.001, tol=1e-6, max_iter=500):
6+
7+
m = X.shape[0]
8+
d = X.shape[1]
9+
t = X.shape[2]
10+
11+
# initialization
12+
rho_1 = 1
13+
rho_2 = 1
14+
np.random.seed(0)
15+
C = np.mat(np.random.rand(m, r))
16+
P = np.mat(np.random.rand(d, r))
17+
D = np.mat(np.random.rand(t, r))
18+
19+
Y_1 = 0
20+
Y_2 = 0
21+
22+
X_1 = np.mat(tl.unfold(X, 0))
23+
X_2 = np.mat(tl.unfold(X, 1))
24+
X_3 = np.mat(tl.unfold(X, 2))
25+
26+
for i in range(max_iter):
27+
G = np.mat(tl.tenalg.khatri_rao([P, D]))
28+
output_X_old = tl.fold(np.array(C * G.T), 0, X.shape)
29+
30+
O_1 = C.T * C
31+
O_2 = P.T * P
32+
33+
M_2 = CG(0, alpha * O_1, O_1, alpha * C.T * S_m * C, lam, 0.01, 200)
34+
35+
M_3 = CG(0, beta * O_2, O_2, beta * P.T * S_d * P, lam, 0.01, 200)
36+
37+
K = np.mat(np.eye(r))
38+
39+
F = C * M_2
40+
J = (alpha * S_m.T * F + rho_1 * C + Y_1) * np.linalg.inv(alpha * F.T * F + rho_1 * np.eye(r))
41+
Q = M_2 * J.T
42+
C = (X_1 * G + alpha * S_m * Q.T + rho_1 * J - Y_1) * np.linalg.inv(
43+
G.T * G + alpha * Q * Q.T + rho_1 * np.eye(r))
44+
45+
R = P * M_3
46+
W = (beta * S_d.T * R + rho_2 * P + Y_2) * np.linalg.inv(beta * R.T * R + rho_2 * np.eye(r))
47+
Y_1 = Y_1 + rho_1 * (C - J)
48+
rho_1 = rho_1 * 1.1
49+
50+
51+
U = np.mat(tl.tenalg.khatri_rao([C, D]))
52+
Z = M_3 * W.T
53+
P = (X_2 * U + beta * S_d * Z.T + rho_2 * W - Y_2) * np.linalg.inv(U.T * U + beta * Z * Z.T + rho_2 * np.eye(r))
54+
Y_2 = Y_2 + rho_2 * (P - W)
55+
rho_2 = rho_2 * 1.1
56+
57+
58+
B = np.mat(tl.tenalg.khatri_rao([C, P]))
59+
D = X_3 * B * np.linalg.inv(B.T * B + lam * np.eye(r))
60+
61+
output_X = tl.fold(np.array(D * B.T), 2, X.shape)
62+
err = np.linalg.norm(output_X - output_X_old) / np.linalg.norm(output_X_old)
63+
# print(err)
64+
if err < tol:
65+
# print(i)
66+
break
67+
68+
predict_X = np.array(tl.fold(np.array(C * np.mat(tl.tenalg.khatri_rao([P, D])).T), 0, X.shape))
69+
70+
return predict_X
71+
72+
73+
def CG(X_initial, A, B, D, mu, tol, max_iter):
74+
75+
X = X_initial
76+
R = D - A * X * B - mu * X
77+
P = np.array(R, copy=True)
78+
79+
for i in range(max_iter):
80+
R_norm = np.trace(R * R.T)
81+
Q = A * P * B + mu * P
82+
alpha = R_norm / np.trace(Q * P.T)
83+
X = X + alpha * P
84+
R = R - alpha * Q
85+
err = np.linalg.norm(R)
86+
if err < tol:
87+
88+
print("CG convergence: iter = %d" % i)
89+
90+
beta = np.trace(R * R.T) / R_norm
91+
P = R + beta * P
92+
93+
return X
94+

0 commit comments

Comments
 (0)