Demo entry 6682397

Apriori

   

Submitted by anonymous on Dec 08, 2017 at 11:09
Language: Python. Code size: 3.7 kB.

'''
用一个数组(list)存储k-项集,用一个映射(dict)存储对应项集的支持度
存储k-项集的列表中表示项集的数据结构是集合
在词典中存储支持度的时候,key是项集,value是支持度,存储的项集的数据结构是元组(turple),因为dict的key只能是不可变元素
'''

# 读取数据
# path:数据文件存储的路径
# result:一个ip集合的列表,数据中的每一个解析路径中的ip都被提取出来,作为集合中的元素
def getTotalData(path):
    import os
    result=list()
    files=os.listdir(path)
    for file_name in files:
        print("read data "+file_name)
        file=open(path+"\\"+file_name, 'r')
        for line in file:
            ss=line[:-1].split("|")
            tempSet=set()
            tempSet.add(ss[1])
            for i in range(3, len(ss)-1):
                tempSet.add(ss[i])
            result.append(tempSet)
    return result

# 初始化,扫描数据库,得到所有的一元项集
# inputData:读入的原始数据 [{ip}]
# d:项集-支持度映射的dict
def init(inputData):
    d=dict() # 一元项集的支持度
    for i in inputData:
        for j in i:
            tempSet=set()
            tempSet.add(j)
            if tuple(tempSet) in d:
                d[tuple(tempSet)]+=1
            else:
                d[tuple(tempSet)]=1
    return d

# 删掉支持度小于阈值的k项集
# d:项集支持度映射(dict) supp:支持度阈值
# l:满足最小支持度的k-频繁集
def dropItemSet(d, supp):
    l=list()
    for t in d: # 查看每一个项集的支持度
        if d[t]>=supp:
            l.append(set(t))
    return l

# 产生k+1项集,如果没有k+1项集,那么算法停止
# tempList:k频繁集
def merge(tempList):
    l=list() # k+1项集
    for i in range(len(tempList)):
        for j in range(i+1, len(tempList)):
            # 两两组合k频繁集,得到k+1项集
            if len(tempList[i]-tempList[j])==1 and len(tempList[j]-tempList[i])==1:
                # 被merge的k频繁集只能有一个不一样的元素
                tempSet=tempList[i]|tempList[j] # 集合取并
                l.append(tempSet)
    return l

# 遍历数据库,算每个项集的支持度
# DB:存储的原始数据 l:k项集
# d:k项集的支持度映射(dict)
def scanData(DB, l):
    d=dict() # 存储k项集支持度的dict
    for i in l: # 初始化d
        d[tuple(i)]=0
    for i in DB: # 扫描数据库
        for j in l: # 扫描k项集,耗时间
            if j<=i: # j是i的子集
                d[tuple(j)]+=1
    return d

# 由频繁集得到关联规则
# l: k频繁集 DB:数据 conf:置信度
def getRules(l, DB, conf):
    tempResult=list()
    result=list() # 关联规则结果
    d=dict() # 关联规则的置信度
    for k in l:
        # 遍历频繁集,生成关联规则
        # !!!!!【关联规则不全】
        k=list(k)
        for i in range(len(k)):
            for j in range(len(k)):
                if i!=j:
                    t=set()
                    t.add(k[i])
                    t.add(k[j])
                    tempResult.append(t)
                    d[tuple(t)]=0
    for i in DB:
        # 遍历数据库,计算置信度
        for j in tempResult:
            if j<i:
                d[tuple(j)]+=1
    for t in d:
        # 大于置信度阈值的关联规则返回
        if d[t]>conf:
            result.append(set(t))
    return result

# apriori算法
def apriori(DB, conf, supp):
    frequentSet_supp_dict=init(DB)  # 初始化一元项集
    frequentSet_list=dropItemSet(frequentSet_supp_dict, supp*len(DB)) # 删除小于支持度的项集
    k=1
    while True:
        Ck=frequentSet_list.copy()
        print("%d频繁集的个数:%d"%(k, len(frequentSet_list)))
        frequentSet_list=merge(frequentSet_list)  # 产生k+1项集
        if len(frequentSet_list)==0:  # 判断是否产生了k+1项集
            break
        frequentSet_supp_dict=scanData(DB, frequentSet_list)  # 遍历数据库,计算支持度
        frequentSet_list=dropItemSet(frequentSet_supp_dict, supp*len(data))  # 删除支持度小于阈值的项集
        if len(frequentSet_list)==0: # 判断是不是有k+1频繁集
            break
        k+=1

    return Ck
    # 用k频繁集生成关联规则
    # result=getRules(Ck, DB, conf*len(DB))
    # return result

data=getTotalData("data") # 读入数据
print("dataLength:"+str(len(data)))
rules=apriori(data, conf=0.01, supp=0.01)
for i in rules:
    i=list(i)
    print(i[0]+"->"+i[1])
    print(i[1]+"->"+i[0])

This snippet took 0.01 seconds to highlight.

Back to the Entry List or Home.

Delete this entry (admin only).