Skip to content

Commit f416def

Browse files
committed
DOC Better documentation
1 parent 0331467 commit f416def

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

ch08/corrneighbours.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,35 @@
99
from all_correlations import all_correlations
1010
import numpy as np
1111
from load_ml100k import load
12-
def estimate_user(user, rest, num_neigbors=100):
    '''Estimate ratings for user based on the binary rating matrix

    Users are compared only on *which* movies they rated (entries ``> 0``),
    not on the rating values themselves; the ratings of the best-matching
    users are then averaged to produce the estimate.

    Parameters
    ----------
    user : ndarray
        Ratings of the target user; an entry ``> 0`` counts as "rated".
        (assumed to hold one entry per movie -- confirm against caller)
    rest : ndarray
        Ratings of the other users. Users are indexed along the first axis:
        rows are picked with ``rest[selected]`` and averaged over axis 0.
    num_neigbors : int, optional
        Number of highest-scoring users to average over (default 100).
        Must be at least 1. The misspelt name is kept so that existing
        keyword callers continue to work.

    Returns
    -------
    estimates: ndarray
        Returns a rating estimate for each movie

    Raises
    ------
    ValueError
        If `num_neigbors` is smaller than 1.
    '''
    # A count of 0 would turn the slice below into `[-0:]`, which selects
    # *every* user, and a negative count would be read as a start offset
    # rather than a neighbour count. Reject both explicitly.
    if num_neigbors < 1:
        raise ValueError(
            'num_neigbors must be at least 1, got %r' % (num_neigbors,))

    # Compute binary matrix correlations:
    bu = user > 0
    br = rest > 0
    # NOTE(review): presumably one score per row of `br`, higher meaning
    # more similar to `bu` -- confirm against all_correlations.
    ws = all_correlations(bu, br)

    # Select top `num_neigbors`:
    # argsort is ascending, so the last entries carry the largest scores.
    selected = ws.argsort()[-num_neigbors:]

    # Use these to compute estimates:
    estimates = rest[selected].mean(0)
    # Divide by the fraction of selected users who rated each movie; the
    # added .1 keeps the denominator non-zero when none of them did.
    estimates /= (.1 + br[selected].mean(0))
    return estimates
2034

2135

2236
def train_test(user, rest):
37+
'''Train & test on a single user
38+
39+
Returns both the prediction error and the null error
40+
'''
2341
estimates = estimate_user(user, rest)
2442
bu = user > 0
2543
br = rest > 0
@@ -49,7 +67,10 @@ def main():
4967
revs = (reviews > 0).sum(1)
5068
err = np.array(err)
5169
rmse = np.sqrt(err / revs[:, None])
70+
print("Average of RMSE / Null-model RMSE")
5271
print(np.mean(rmse, 0))
72+
print()
73+
print("Average of RMSE / Null-model RMSE (users with more than 60 reviewed movies)")
5374
print(np.mean(rmse[revs > 60], 0))
5475

5576
if __name__ == '__main__':

0 commit comments

Comments
 (0)