
This script takes in a dataset, a minimum support value and a minimum confidence value as its options, and returns the association rules.

I'm looking for pointers on optimization, documentation and code quality.

"""
Description : A Python implementation of the Apriori Algorithm
Usage:
 $python apriori.py -f DATASET.csv -s minSupport -c minConfidence
 $python apriori.py -f DATASET.csv -s 0.15 -c 0.6
"""
import sys
from itertools import chain, combinations
from collections import defaultdict
from optparse import OptionParser
def subsets(arr):
 """ 
 Returns non empty subsets of arr
 enumerate(arr) <= returns the following format "<index>, <array element>"
 combinations(arr, i) <= returns all i-length combinations of the array.
 chain(arr) <= unpackas a list of lists
 """
 return chain(*[combinations(arr, i + 1) for i, a in enumerate(arr)])
def returnItemsWithMinSupport(itemSet, transactionList, minSupport, freqSet):
 """calculates the support for items in the itemSet and returns a subset
 of the itemSet each of whose elements satisfies the minimum support
 """
 _itemSet = set()
 localSet = defaultdict(int)
 for item in itemSet:
 for transaction in transactionList:
 if item.issubset(transaction):
 freqSet[item] += 1
 localSet[item] += 1
 for item, count in localSet.items():
 support = float(count)/len(transactionList)
 if support >= minSupport:
 _itemSet.add(item)
 return _itemSet
def joinSet(itemSet, length):
 """Join a set with itself and returns the n-element itemsets"""
 return set([i.union(j) for i in itemSet for j in itemSet if len(i.union(j)) == length])
def getItemSetTransactionList(data_iterator):
 """
 Takes data from dataFromFile() and returns list of items and a list of transactions
 and generate two seperate sets of items and transactions.
 The item list would be: 
 ([frozenset(['apple']), frozenset(['beer']), frozenset(['chicken']), etc
 The transaction list would be:
 frozenset(['beer', 'rice', 'apple', 'chicken']), frozenset(['beer', 'rice', 'apple']), etc
 """
 transactionList = list()
 itemSet = set()
 for record in data_iterator:
 transaction = frozenset(record)
 transactionList.append(transaction)
 for item in transaction:
 itemSet.add(frozenset([item])) # Generate 1-itemSets
 return itemSet, transactionList
def runApriori(data_iter, minSupport, minConfidence):
 """
 run the apriori algorithm. data_iter is a record iterator
 Return both:
 - items (tuple, support)
 - rules ((pretuple, posttuple), confidence)
 """
 itemSet, transactionList = getItemSetTransactionList(data_iter)
 freqSet = defaultdict(int)
 largeSet = dict()
 # Global dictionary which stores (key=n-itemSets,value=support)
 # which satisfy minSupport
 assocRules = dict()
 # Dictionary which stores Association Rules
 oneCSet = returnItemsWithMinSupport(itemSet,
 transactionList,
 minSupport,
 freqSet)
 currentLSet = oneCSet
 k = 2
 while(currentLSet != set([])):
 largeSet[k-1] = currentLSet
 currentLSet = joinSet(currentLSet, k)
 currentCSet = returnItemsWithMinSupport(currentLSet,
 transactionList,
 minSupport,
 freqSet)
 currentLSet = currentCSet
 k = k + 1
 def getSupport(item):
 """local function which Returns the support of an item"""
 return float(freqSet[item])/len(transactionList)
 toRetItems = []
 for key, value in largeSet.items():
 toRetItems.extend([(tuple(item), getSupport(item))
 for item in value])
 toRetRules = []
 for key, value in largeSet.items()[1:]:
 for item in value:
 _subsets = map(frozenset, [x for x in subsets(item)])
 for element in _subsets:
 remain = item.difference(element)
 if len(remain) > 0:
 confidence = getSupport(item)/getSupport(element)
 if confidence >= minConfidence:
 toRetRules.append(((tuple(element), tuple(remain)),
 confidence))
 return toRetItems, toRetRules
def printResults(items, rules):
 """prints the generated itemsets sorted by support and the confidence rules sorted by confidence"""
 for item, support in sorted(items, key=lambda (item, support): support):
 print "item: %s , %.3f" % (str(item), support)
 print "\n------------------------ RULES:"
 for rule, confidence in sorted(rules, key=lambda (rule, confidence): confidence):
 pre, post = rule
 print "Rule: %s ==> %s , %.3f" % (str(pre), str(post), confidence)
def dataFromFile(fname):
 """
 Function which reads from the file and yields a generator of frozen sets of each line in the csv
 The first line of tesco.csv file returns the following output:
 frozenset(['beer', 'rice', 'apple', 'chicken'])
 """
 file_iter = open(fname, 'rU')
 for line in file_iter:
 line = line.strip().rstrip(',') # Remove trailing comma
 record = frozenset(line.split(','))
 yield record
if __name__ == "__main__":
 optparser = OptionParser()
 optparser.add_option('-f', '--inputFile',
 dest='input',
 help='filename containing csv',
 default=None)
 optparser.add_option('-s', '--minSupport',
 dest='minS',
 help='minimum support value',
 default=0.15,
 type='float')
 optparser.add_option('-c', '--minConfidence',
 dest='minC',
 help='minimum confidence value',
 default=0.6,
 type='float')
 (options, args) = optparser.parse_args()
 inFile = None
 if options.input is None:
 inFile = sys.stdin
 elif options.input is not None:
 inFile = dataFromFile(options.input)
 else:
 print 'No dataset filename specified, system with exit\n'
 sys.exit('System will exit')
 minSupport = options.minS
 minConfidence = options.minC
 items, rules = runApriori(inFile, minSupport, minConfidence)
 printResults(items, rules)

Sample data is the following csv file:

apple,beer,beer,rice,chicken
apple,beer,beer,rice
apple,beer,beer
apple,mango
milk,beer,beer,rice,chicken
milk,beer,rice
milk,beer
milk,mango
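
For completeness, a minimal way to exercise the code on this file; the filenames apriori.py and tesco.csv are only assumptions here (tesco.csv is the name already mentioned in the dataFromFile docstring):

# Hypothetical driver, assuming the code above is saved as apriori.py and
# the sample data as tesco.csv in the same directory.
from apriori import dataFromFile, runApriori, printResults

items, rules = runApriori(dataFromFile('tesco.csv'), minSupport=0.15, minConfidence=0.6)
printResults(items, rules)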
asked Jul 28, 2016 at 10:40
  • Brace yourselves, antiCamelCaseRecommendationsAreComing :) - Commented Jul 28, 2016 at 11:05
  • It would be really useful if you provided a sample of the data you are reading from. - Commented Jul 28, 2016 at 13:04
  • @OscarSmith Added the sample data (csv file in the question) :) - Commented Jul 28, 2016 at 14:33
  • @MathiasEttinger No. Thanks for pointing it out. Would use argparse then :) - Commented Jul 28, 2016 at 16:46
  • @Dawny33 As regard to your module docstring, I can also suggest docopt - Commented Jul 28, 2016 at 16:48

1 Answer


My biggest piece of advice would be to replace freqSet = defaultdict(int) with a Counter. Counters are a datatype designed to do exactly what you are doing with defaultdicts, and they have some specialized methods.

for item in itemSet:
    for transaction in transactionList:
        if item.issubset(transaction):
            freqSet[item] += 1

Could be replaced with

freqSet.update(item for item in itemSet for transaction in transactionList if item.issubset(transaction))

This should be a pretty big speed increase. Also, set([i.union(j) for i in itemSet for j in itemSet if len(i.union(j)) == length]) could be written using a set comprehension, which would lower memory usage and increase speed.
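
Not tested against your data, but a rough sketch of how both suggestions could look, assuming freqSet is created as Counter() rather than defaultdict(int) in runApriori:

from collections import Counter

def joinSet(itemSet, length):
    """joinSet rewritten with a set comprehension: same result, no intermediate list."""
    return {i.union(j) for i in itemSet for j in itemSet if len(i.union(j)) == length}

def returnItemsWithMinSupport(itemSet, transactionList, minSupport, freqSet):
    """Counter-based version: tally the local counts once, then fold them into
    the running freqSet (which must itself be a Counter)."""
    localSet = Counter(item
                       for item in itemSet
                       for transaction in transactionList
                       if item.issubset(transaction))
    freqSet.update(localSet)  # Counter.update adds counts instead of replacing values
    return {item for item, count in localSet.items()
            if float(count) / len(transactionList) >= minSupport}

Note that calling .update() with a bare iterable only performs counting when freqSet is a Counter; on a plain dict or defaultdict(int) it either raises a ValueError or silently inserts bogus key/value pairs, so the Counter swap has to happen first.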

Dan Oberlam
answered Jul 28, 2016 at 12:46
  • I really should learn how to spell comprehension. - Commented Jul 28, 2016 at 14:41
  • Possible to explain how do I write the line set([i.union(j) for i in itemSet for j in itemSet if len(i.union(j)) == length]) as a set comprehension? And the freqSet.update() is returning some weird error :( - Commented Aug 1, 2016 at 8:37
