"""
Apriori Algorithm is an Association rule mining technique, also known as market basket
analysis, aims to discover interesting relationships or associations among a set of
items in a transactional or relational database.

Example use case and data:
https://www.kaggle.com/code/earthian/apriori-association-rules-mining
"""

from collections import defaultdict
from itertools import combinations


def load_data() -> list[list[str]]:
    """
    Return a small sample transaction dataset; each transaction is a list of items.

    >>> load_data()
    [['milk'], ['milk', 'butter'], ['milk', 'bread'], ['milk', 'bread', 'chips']]
    """
    return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]]


def prune(
    frequent_itemsets: list[list[str]], candidates: list[list[str]]
) -> list[list[str]]:
    """
    Prunes candidate itemsets by ensuring all (k-1)-subsets exist in
    previous frequent itemsets.

    A candidate of size k can only be frequent if every one of its
    (k-1)-element subsets was frequent in the previous pass (the Apriori
    anti-monotonicity property), so anything else is discarded here.

    >>> frequent_itemsets = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
    >>> candidates = [['X', 'Y', 'Z'], ['X', 'Y', 'W']]
    >>> prune(frequent_itemsets, candidates)
    [['X', 'Y', 'Z']]
    """
    # frozenset membership makes each subset check O(1) instead of a list scan.
    previous_frequents = {frozenset(itemset) for itemset in frequent_itemsets}

    pruned_candidates = []
    for candidate in candidates:
        all_subsets_frequent = all(
            frozenset(subset) in previous_frequents
            for subset in combinations(candidate, len(candidate) - 1)
        )
        if all_subsets_frequent:
            pruned_candidates.append(candidate)

    return pruned_candidates


def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]:
    """
    Returns a list of frequent itemsets and their support counts.

    :param data: transactions, each a list of item names
    :param min_support: minimum number of transactions an itemset must
        appear in to be kept
    :return: (itemset, support) pairs, itemsets sorted, ordered by
        itemset size then lexicographically

    >>> data = [['A', 'B'], ['A', 'C'], ['B', 'C']]
    >>> apriori(data, 2)
    [(['A'], 2), (['B'], 2), (['C'], 2)]

    >>> data = [['1', '2', '3'], ['1', '2'], ['1', '3'], ['1', '4'], ['2', '3']]
    >>> apriori(data, 3)
    [(['1'], 4), (['2'], 3), (['3'], 3)]
    """
    # Pass 1: support counts for single items.
    item_counts: defaultdict[str, int] = defaultdict(int)
    for transaction in data:
        for item in transaction:
            item_counts[item] += 1

    current_frequents = [
        [item] for item, count in item_counts.items() if count >= min_support
    ]
    frequent_itemsets = [
        ([item], count) for item, count in item_counts.items() if count >= min_support
    ]

    k = 2
    while current_frequents:
        # Join step: union pairs of (k-1)-frequents; the set comprehension
        # builds each union once and deduplicates in the same pass.
        merged = {
            frozenset(first) | frozenset(second)
            for first in current_frequents
            for second in current_frequents
        }
        candidates = [sorted(union) for union in merged if len(union) == k]

        # Prune step: drop candidates with an infrequent (k-1)-subset.
        candidates = prune(current_frequents, candidates)

        # Count support of the surviving candidates.
        candidate_counts: defaultdict[tuple[str, ...], int] = defaultdict(int)
        for transaction in data:
            transaction_set = set(transaction)
            for candidate in candidates:
                # candidates are already sorted lists, so the tuple key is canonical
                if transaction_set.issuperset(candidate):
                    candidate_counts[tuple(candidate)] += 1

        current_frequents = [
            list(key) for key, count in candidate_counts.items() if count >= min_support
        ]
        frequent_itemsets.extend(
            (list(key), count)
            for key, count in candidate_counts.items()
            if count >= min_support
        )

        k += 1

    # Deterministic output order: smaller itemsets first, then lexicographic.
    return sorted(frequent_itemsets, key=lambda entry: (len(entry[0]), entry[0]))


if __name__ == "__main__":
    """
    Apriori algorithm for finding frequent itemsets.

    This script loads sample transaction data and runs the Apriori algorithm
    with a user-defined minimum support threshold.

    The result is a list of frequent itemsets along with their support counts.
    """
    import doctest

    doctest.testmod()

    transactions = load_data()
    min_support_threshold = 2

    frequent_itemsets = apriori(transactions, min_support=min_support_threshold)

    print("Frequent Itemsets:")
    for itemset, support in frequent_itemsets:
        print(f"{itemset}: {support}")