Skip to content

Refactor Apriori Algorithm: Fix Pruning Logic, Add Type Hints, and Pass Lint Checks #12700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
137 changes: 80 additions & 57 deletions machine_learning/apriori_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Apriori Algorithm is an association rule mining technique, also known as market basket
analysis, that aims to discover interesting relationships or associations among a set
of items in a transactional or relational database.

Expand All @@ -11,6 +11,7 @@
Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
"""

from collections import defaultdict
from itertools import combinations


Expand All @@ -24,90 +25,112 @@ def load_data() -> list[list[str]]:
return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]]


def prune(
    frequent_itemsets: list[list[str]], candidates: list[list[str]]
) -> list[list[str]]:
    """
    Prune candidate itemsets that cannot be frequent.

    By the downward-closure property, a k-itemset can only be frequent if every
    one of its (k-1)-subsets was frequent in the previous iteration. A candidate
    is kept only when all of its (k-1)-subsets appear in ``frequent_itemsets``.
    Item order inside an itemset is irrelevant for the comparison.

    Args:
        frequent_itemsets: Frequent (k-1)-itemsets from the previous iteration.
        candidates: Candidate k-itemsets to be checked.

    Returns:
        The candidates that survive pruning, in their original order.

    >>> frequent_itemsets = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
    >>> candidates = [['X', 'Y', 'Z'], ['X', 'Y', 'W']]
    >>> prune(frequent_itemsets, candidates)
    [['X', 'Y', 'Z']]

    >>> prune([['Y', 'X'], ['Z', 'X'], ['Z', 'Y']], [['X', 'Y', 'Z']])
    [['X', 'Y', 'Z']]

    >>> prune([['A']], [[]])
    []
    """
    # Frozensets give O(1), order-insensitive membership tests.
    previous_frequents = {frozenset(itemset) for itemset in frequent_itemsets}

    return [
        candidate
        for candidate in candidates
        # An empty candidate is never a meaningful itemset, and asking
        # combinations() for subsets of size -1 would raise ValueError.
        if candidate
        and all(
            frozenset(subset) in previous_frequents
            for subset in combinations(candidate, len(candidate) - 1)
        )
    ]


def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]:
    """
    Return the frequent itemsets of ``data`` together with their support counts.

    The support of an itemset is the number of transactions that contain every
    item of the itemset; an item repeated inside one transaction is counted
    once for that transaction.

    Args:
        data: A list of transactions, where each transaction is a list of items.
        min_support: Minimum number of transactions an itemset must appear in.

    Returns:
        A list of ``(itemset, support)`` tuples. Each itemset is a sorted list,
        and the result is ordered by itemset size, then lexicographically.

    >>> data = [['A', 'B'], ['A', 'C'], ['B', 'C']]
    >>> apriori(data, 2)
    [(['A'], 2), (['B'], 2), (['C'], 2)]

    >>> data = [['1', '2', '3'], ['1', '2'], ['1', '3'], ['1', '4'], ['2', '3']]
    >>> apriori(data, 3)
    [(['1'], 4), (['2'], 3), (['3'], 3)]

    >>> data = [['milk'], ['milk', 'butter'], ['milk', 'bread'],
    ...         ['milk', 'bread', 'chips']]
    >>> apriori(data, 2)
    [(['bread'], 2), (['milk'], 4), (['bread', 'milk'], 2)]

    >>> apriori([['A', 'A'], ['B']], 2)
    []
    """
    # Build each transaction's item set once. This also guarantees that a
    # repeated item inside a transaction contributes a single unit of support.
    transactions = [frozenset(transaction) for transaction in data]

    item_counts: defaultdict[str, int] = defaultdict(int)
    for transaction in transactions:
        for item in transaction:
            item_counts[item] += 1

    frequent_itemsets: list[tuple[list[str], int]] = [
        ([item], count) for item, count in item_counts.items() if count >= min_support
    ]
    current_frequents = [itemset for itemset, _ in frequent_itemsets]

    k = 2
    while current_frequents:
        # Join step: the union of two frequent (k-1)-itemsets is a candidate
        # when it has exactly k items. The set removes duplicate unions.
        previous = [frozenset(itemset) for itemset in current_frequents]
        joined = {first | second for first, second in combinations(previous, 2)}

        # Prune step: drop candidates with an infrequent (k-1)-subset.
        candidates = prune(
            current_frequents,
            [sorted(candidate) for candidate in joined if len(candidate) == k],
        )

        # Count step: support is the number of transactions containing the candidate.
        current_frequents = []
        for candidate in candidates:
            members = frozenset(candidate)
            support = sum(1 for transaction in transactions if members <= transaction)
            # Itemsets that occur in no transaction are never reported, even
            # when min_support is zero or negative.
            if support and support >= min_support:
                current_frequents.append(candidate)
                frequent_itemsets.append((candidate, support))

        k += 1

    # Set iteration order is arbitrary, so sort for a deterministic result.
    return sorted(frequent_itemsets, key=lambda entry: (len(entry[0]), entry[0]))


if __name__ == "__main__":
    # Script entry point: run the module's doctests, then load the sample
    # transaction data and print every frequent itemset with its support count.
    import doctest

    doctest.testmod()

    # user-defined threshold or minimum support level
    transactions = load_data()
    min_support_threshold = 2

    frequent_itemsets = apriori(transactions, min_support=min_support_threshold)

    print("Frequent Itemsets:")
    for itemset, support in frequent_itemsets:
        print(f"{itemset}: {support}")