Commit c8e644a
ENH 20newsgroups example for FeatureHasher
1 parent cbb9b01 commit c8e644a
1 file changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
"""Compares FeatureHasher and DictVectorizer by using both to vectorize
text documents.

The example demonstrates syntax and speed only; it doesn't actually do
anything useful with the extracted vectors. See the example scripts
{document_classification_20newsgroups,clustering}.py for actual learning
on text documents.

A discrepancy between the number of tokens reported for DictVectorizer and
for FeatureHasher is to be expected due to hash collisions.
"""

# Author: Lars Buitinck <[email protected]>
# License: 3-clause BSD

from __future__ import print_function
from collections import defaultdict
import re
import sys
from time import time

import numpy as np

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction import DictVectorizer, FeatureHasher


def n_nonzero_columns(X):
    """Returns the number of non-zero columns in a CSR matrix X."""
    return len(np.unique(X.nonzero()[1]))


def tokens(doc):
    """Extract tokens from doc.

    This uses a simple regex to break strings into tokens. For a more
    principled approach, see CountVectorizer or TfidfVectorizer.
    """
    return (tok.lower() for tok in re.findall(r"\w+", doc))


def token_freqs(doc):
    """Extract a dict mapping tokens from doc to their frequencies."""
    freq = defaultdict(int)
    for tok in tokens(doc):
        freq[tok] += 1
    return freq


categories = [
    'alt.atheism',
    'talk.religion.misc',
    'comp.graphics',
    'sci.space',
]
# Uncomment the following line to use a larger set (11k+ documents)
# categories = None

print(__doc__)
print("Usage: %s [n_features_for_hashing]" % sys.argv[0])
print("    The default number of features is 2**18.")
print()

try:
    n_features = int(sys.argv[1])
except IndexError:
    n_features = 2 ** 18
except ValueError:
    print("not a valid number of features: %r" % sys.argv[1])
    sys.exit(1)

print("Loading 20 newsgroups training data")
raw_data = fetch_20newsgroups(subset='train', categories=categories).data
print("%d documents" % len(raw_data))
print()

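# Benchmark 1: DictVectorizer builds an explicit in-memory vocabulary
# (term -> column index), so the count it reports is the exact number
# of unique terms in the corpus.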
print("DictVectorizer")
78+
t0 = time()
79+
vectorizer = DictVectorizer()
80+
vectorizer.fit_transform(token_freqs(d) for d in raw_data)
81+
print("done in %fs" % (time() - t0))
82+
print("Found %d unique terms" % len(vectorizer.get_feature_names()))
83+
print()
84+
85+
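# Benchmark 2: FeatureHasher on the same frequency dicts. No vocabulary
# is stored; terms are hashed directly to column indices, so distinct
# terms can collide in a single column.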
print("FeatureHasher on frequency dicts")
86+
t0 = time()
87+
hasher = FeatureHasher(n_features=n_features)
88+
X = hasher.transform(token_freqs(d).iteritems() for d in raw_data)
89+
print("done in %fs" % (time() - t0))
90+
print("Found %d unique terms" % n_nonzero_columns(X))
91+
print()
92+
93+
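# Benchmark 3: FeatureHasher on raw token streams; each token occurrence
# contributes an implicit value of 1, summed per hashed column.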
print("FeatureHasher on raw tokens")
94+
t0 = time()
95+
hasher = FeatureHasher(n_features=n_features, input_type="strings")
96+
X = hasher.transform(tokens(d) for d in raw_data)
97+
print("done in %fs" % (time() - t0))
98+
print("Found %d unique terms" % n_nonzero_columns(X))
