回顾Aprioir算法

Aprioir算法

Apriori算法是经典的数据关联规则的算法,接下来让我们一起回顾Apriori算法吧!

1. 算法背景

在沃尔玛超市中,商店人员发现购买尿布的人往往也会购买啤酒,这两种毫无关系的商品为什么会被一起购买呢?

沃尔玛经过分析后发现,美国的家庭主妇们经常会让她们的丈夫在回家的路上顺道买一些尿布给孩子用,而这些丈夫们辛苦工作了一天也想犒劳一下自己,于是在买尿布之余也给捎带上了自己最爱的啤酒。这就是著名的啤酒与尿布的故事。

研究“啤酒与尿布”的关联分析方法就是购物篮分析,购物篮分析可以帮助商店在销售过程中,找到具有关联关系的商品,以此获得销售量的增长。

Apriori算法就是用来挖掘关联规则的最经典的算法。

2. 算法介绍

在介绍算法之前,首先介绍几个概念

  1. 支持度(Support)

$$ Support(X, Y) = \frac{num(XY)}{num(All)} = \frac{含有X,Y物品的数据数量}{数据总数} $$

通过支持度我们可以计算频繁项目集。

  1. 置信度(confidence)

$$ Conf(X => Y) = P(Y | X) = \frac{XY}{X}$$

形象的可以理解为购买了商品X之后,购买商品Y的概率。

通过置信度我们可以来计算关联规则。

  1. 频繁k项目集

满足支持度的商品的集合,这个集合中正好有k件商品

3. 算法示例

Apriori算法的目的就是最大频繁项目集,举个例子,如果我们有下面的购物记录。

ID 购买商品
1 棒棒糖,啤酒,雪碧
2 啤酒,尿布,可乐
3 棒棒糖,尿布,啤酒,可乐
4 尿布,可乐

我们假设支持度为50%。

我们取所有商品的种类可以得到一阶频繁项目集候选集合,一阶频繁项目集候选集合如下所示:

[['棒棒糖'], ['可乐'], ['尿布'], ['啤酒'], ['雪碧']]

接下来我们统计所有候选集合出现的次数:

购买商品 出现次数
棒棒糖 2
可乐 3
尿布 3
啤酒 3
雪碧 1

我们删去小于支持度的项目集合,由于支持度为50%,所以我们删除出现次数小于向下取整[数据总数×支持度]=int(5*0.5)=2的数据,最后我们可以得到我们的频繁1项目集,如下所示:

购买商品 出现次数
棒棒糖 2
可乐 3
尿布 3
啤酒 3

然后我们生成频繁2项集合的候选集,生成的方式很简单,就是将1阶频繁项目集合的商品两两合并即可,频繁2项集候选集如下所示。

[['可乐', '尿布'], ['可乐', '棒棒糖'], ['可乐', '啤酒'], ['尿布', '棒棒 糖'], ['啤酒', '尿布'], ['啤酒', '棒棒糖']]

接下来我们统计频繁2项集候选集中每个商品组合出现的次数,结果如下所示:

购买商品 出现次数
可乐,尿布 3
可乐,棒棒糖 1
可乐,啤酒 2
尿布,棒棒糖 1
啤酒,尿布 2
啤酒,棒棒糖 2

我们删去小于支持度的项目集合,由于支持度为50%,所以我们删除出现次数2的数据,最后我们可以得到我们的频繁2项目集,如下所示:

购买商品 出现次数
可乐,尿布 3
可乐,啤酒 2
啤酒,尿布 2
啤酒,棒棒糖 2

然后我们生成频繁3项集合的候选集,就是将2阶频繁项目集合的商品两两合并即可,频繁3项集候选集如下所示。

Tip

这里需要保证两个2阶频繁项目集合的商品,合成之后刚好为3个商品。

例如[可乐,尿布]和[可乐,啤酒]可以合成为[可乐,尿布,,啤酒]

但是[可乐,尿布]和[啤酒,棒棒糖]不可以合成为[可乐,尿布,啤酒,棒棒糖]

所以频繁3项集合的候选集有

[['可乐', '啤酒', '尿布'], ['可乐', '啤酒', '棒棒糖'], ['啤酒', '尿布', '棒棒糖']]

接下来我们统计频繁3项集候选集中每个商品组合出现的次数,结果如下所示:

购买商品 出现次数
可乐,啤酒,尿布 2
可乐,啤酒,棒棒糖 1
啤酒,尿布,棒棒糖 1

通过支持度过滤,最后可得频繁3项集合为:

购买商品 出现次数
可乐,啤酒,尿布 2

4. 算法核心

  1. 任何一个频繁项集的子集必是频繁项集

{西瓜、冬瓜、南瓜}是频繁项集 $=\gt$ {西瓜、冬瓜}是频繁项集

  1. 如果一个集合是不频繁的,其超集必然是不频繁的

{牛奶}不是频繁项集 $=\gt$ {牛奶、鸡蛋}不是频繁项集

5. 算法实现

算法流程

  1. 初始化i = 1
  2. 计算数据的频繁i项集的候选集
  3. 计算候选集出现的次数
  4. 通过支持度得到频繁i项集,i++
  5. 如果频繁i项集为空,算法结束,否则重复2-5步

本次算法实现采用Python和Go语言分别进行实现。(由于笔者最近在学习Go语言,所以也使用Go语言进行实现了一遍)

5.1 Python实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
## 得到所有种类
def get_set(data):
    s = set()
    for row in data:
        s = s | row
    return [[k] for k in s]

## 判断项集是否在数据库data中
def isin(items, row):
    for x in items:
        if x not in row:
            return False
    return True

## 发现频繁项目集
def find_frequent(data, compose_data):
    res = {}
    ## 初始化频繁项集
    for t in compose_data:
        res[",".join(t)] = 0
    ## 项集去重

    ## 遍历数据库计算所有项集出现的次数
    for row in data:
        for items in compose_data:
            if isin(items,row):
                res[",".join(items)] += 1

    print(res)
    ## 删除小于支持度的项集
    for k in list(res.keys()):
        if res[k] < conf_length:
            del res[k]
    return res

def merge(a, b):
    res = {}
    return_list = []
    for x in a:
        res[x] = True
    for x in b:
        res[x] = True
    for k in res.keys():
        return_list.append(k)
    sorted(return_list)
    return return_list

## 交叉生成下一阶段备选的频繁项集
def compose(set_data,level = 2):
    set_key_data = [k.split(",") for k in set_data.keys()]
    compose_data = list()
    for i in range(len(set_key_data)):
        for j in range(i+1,len(set_key_data)):
            tmp = merge(set_key_data[i],set_key_data[j])
            if len(tmp) == level:
                tmp.sort()
                compose_data.append(tmp)
                
    ## 去重复
    res = {}
    re_compose_data = []
    for x in compose_data:
        res[",".join(x)] = True
    compose_data = [k.split(",") for k in res.keys()]
    return list(compose_data)


## 数据库数据
data = [
    {"棒棒糖","啤酒","雪碧"},
    {"尿布","啤酒","可乐"},
    {"棒棒糖","尿布","啤酒","可乐"},
    {"尿布","可乐"},
]

## 支持度
CONF = 0.5

conf_length = int(len(data) * CONF)
compose_data = get_set(data)
i = 2
while compose_data:
    print(compose_data)
    frequent = find_frequent(data,compose_data)
    print(f"第{i-1}频繁项集为:",frequent)
    compose_data = compose(frequent,i)
    i+=1

5.2 Go语言实现

package main

import (
    "fmt"
    "strings"
    "sort"
)

func get_set(data [][]string) [][]string {
    set := make(map[string]bool)
    for _, row := range data {
        for _, x := range row {
            set[x] = true
        }
    }
    var return_list [][]string
    for index, _ := range set {
        var index_list []string
        index_list = append(index_list,index)
        return_list = append(return_list,index_list)
    }
    return return_list
}


func isin(items []string, row []string) bool {
    // items 一个项集,row 一行数据
    for _, x := range items{
        flag := 1
        for _, x_row := range row{
            if x == x_row{
                flag = 0
                break
            }
        }
        if flag == 1 {
            return false
        }
    }
    return true
}

 //发现频繁项集
func find_frequent(data [][]string, compose_data [][]string, conf_length int) map[string]int {
    res := make(map[string]int)
    flag := make(map[string]int)
    var key_compose [][]string

    // 项集去重
    for _, items := range compose_data {
        str_itmes := strings.Join(items,",")
        res[str_itmes] = 0
        flag[str_itmes] += 1
        if flag[str_itmes] == 1{
            key_compose = append(key_compose,items)
        }
    }

    // 计算每个项集的次数
    for _, row := range data {
        for _, items := range key_compose {
            if isin(items,row) {
                res[strings.Join(items,",")] += 1
            }
        }
    }

    res2 := make(map[string]int)
    for k, v := range res{
        if v >= conf_length {
            res2[k] = v
        }
    }
    return res2
}

func merge(a []string, b []string) []string  {
    res := make(map[string]bool)
    var return_list []string
    for _, x := range a {
        res[x] = true
    }
    for _, x := range b {
        res[x] = true
    }
    for k, _ := range res {
        return_list = append(return_list,k)
    }
    sort.Strings(return_list)
    return return_list
}

func compose(set_data map[string]int, level int) [][]string{
    var set_key_data, compose_data [][]string
    for k, _ := range set_data {
        parts := strings.Split(k,",")
        set_key_data = append(set_key_data, parts)
    }
    // 接下来进合并操作
    for i1 , i := range set_key_data{
        for i2 , j := range set_key_data {
            if i2 > i1 {
                // 合并两个集合
                tmp := merge(i,j)
                if len(tmp) == level {
                    compose_data = append(compose_data,tmp)
                }
            }
        }
    }
    return compose_data
}


func main() {
    data := [][]string{
        []string{"棒棒糖","啤酒","雪碧"},
        []string{"尿布","啤酒","可乐"},
        []string{"棒棒糖","尿布","啤酒","可乐"},
        []string{"尿布","可乐"},
    }
    // 支持度
    var CONF float32 = 0.5
    var conf_length int = int(CONF * float32(len(data)))

    // 获得所有样本的种类
    compose_data := get_set(data)

    for i := 2; ; i++{
        // 发现频繁项目集
        frequent := find_frequent(data,compose_data,conf_length)
        compose_data = compose(frequent,i)
        if len(frequent) != 0 {
            fmt.Printf("第%d阶频繁项目集为%v\n",i - 1,frequent)
        } else {
            break
        }
    }
}

Tip

Go语言实现实在是太费劲了,竟然不允许float和int数据之间做加法,呜呜呜~

欢迎大家吐嘈代码。

参考视频

  1. 理论讲解
  2. 手写过程
updatedupdated2022-05-072022-05-07