k-Nearest Neighbors Optimized with a kd-Tree (Building and Searching the kd-Tree), in Python

乖乖的函数 2020-11-13 02:53:31
Tags: algorithm, optimization, nearest neighbor, kd


Introduction

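For the principles behind the kd-tree, see my earlier blog post on the kd-tree-optimized k-nearest-neighbor algorithm.
Reference: wenffe, "python实现KD树" (Implementing a KD-tree in Python)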

1. Constructing the kd-tree

```python
import numpy as np


class Node(object):
    """
    Node class:
    val:    the instance point stored in the node
    label:  the class of that instance
    dim:    the splitting dimension of the node
    left:   the left subtree
    right:  the right subtree
    parent: the parent node
    """
    def __init__(self, val=None, label=None, dim=None, left=None, right=None, parent=None):
        self.val = val
        self.label = label
        self.dim = dim
        self.left = left
        self.right = right
        self.parent = parent


class kdTree(object):
    """
    Tree class:
    dataNum: number of samples in the training set
    root:    root node of the constructed kd-tree
    """
    def __init__(self, dataSet, labelList):
        self.dataNum = 0
        self.root = self.buildKdTree(dataSet, labelList)  # note how the parent is passed: the root gets parentNode=None

    def buildKdTree(self, dataSet, labelList, parentNode=None):
        data = np.array(dataSet)
        if len(data) == 0:  # an empty training set (or empty half) produces no node
            return None
        dataNum, dimNum = data.shape  # number of samples, dimensionality of one sample
        label = np.array(labelList).reshape(dataNum, 1)
        varList = self.getVar(data)  # variance of every dimension
        mid = dataNum // 2  # position of the median
        maxVarDimIndex = varList.index(max(varList))  # dimension with the largest variance
        sortedDataIndex = data[:, maxVarDimIndex].argsort()  # sort along that dimension
        midDataIndex = sortedDataIndex[mid]  # the median sample along that dimension becomes this node
        if dataNum == 1:  # a single sample is simply a leaf
            self.dataNum = dataNum
            return Node(val=data[midDataIndex], label=label[midDataIndex], dim=maxVarDimIndex, parent=parentNode)
        root = Node(data[midDataIndex], label[midDataIndex], maxVarDimIndex, parent=parentNode)
        # Split into left and right subsets, then recurse
        leftDataSet = data[sortedDataIndex[:mid]]  # note: slice by mid, not by midDataIndex
        leftLabel = label[sortedDataIndex[:mid]]
        rightDataSet = data[sortedDataIndex[mid + 1:]]
        rightLabel = label[sortedDataIndex[mid + 1:]]
        root.left = self.buildKdTree(leftDataSet, leftLabel, parentNode=root)
        root.right = self.buildKdTree(rightDataSet, rightLabel, parentNode=root)
        self.dataNum = dataNum  # record the training-set size (the outermost call finishes last)
        return root

    def getVar(self, data):  # variance of each column
        rowLen, colLen = data.shape
        varList = []
        for i in range(colLen):
            varList.append(np.var(data[:, i]))
        return varList
```
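To see the construction rule on the example data used in Section 4, here is a minimal standalone sketch. The function `chooseSplit` is hypothetical and not part of the class above; it applies the same rule `buildKdTree` uses for one node: split on the dimension with the largest variance, and take the point at position `len // 2` after sorting along it.

```python
import numpy as np


def chooseSplit(points):
    """Hypothetical helper, not part of kdTree.

    Returns (split dimension, median index, left indices, right indices)
    for one kd-tree node, using the max-variance / median rule.
    """
    data = np.asarray(points)
    dim = int(np.argmax(np.var(data, axis=0)))  # first dimension with the largest variance
    order = data[:, dim].argsort()               # sample indices sorted along that dimension
    mid = len(data) // 2
    return dim, int(order[mid]), order[:mid], order[mid + 1:]


points = [[7, 2], [5, 4], [2, 3], [4, 7], [9, 6], [8, 1]]
dim, medianIndex, leftIdx, rightIdx = chooseSplit(points)
print(dim, points[medianIndex])      # 0 [7, 2]
print([points[i] for i in leftIdx])  # [[2, 3], [4, 7], [5, 4]]
print([points[i] for i in rightIdx]) # [[8, 1], [9, 6]]
```

The variance along dimension 0 (about 5.8) exceeds that along dimension 1 (about 4.5), which is why the root of the example tree is [7, 2] with `dim` 0.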

2. Converting the kd-tree to a list and a dict

2.1 Converting to a list

Each element of the list is a dict describing one node; its keys are the node's value, the node's dimension, the node's class, its left and right subtrees, and its parent.

```python
    def transferTreeToList(self, root, rootList=None):
        if rootList is None:  # a fresh list per call; a mutable default [] would be shared between calls
            rootList = []
        if root is None:
            return rootList
        tempDict = {}
        tempDict["data"] = root.val
        tempDict["left"] = root.left.val if root.left else None
        tempDict["right"] = root.right.val if root.right else None
        tempDict["parent"] = root.parent.val if root.parent else None
        tempDict["label"] = root.label[0]
        tempDict["dim"] = root.dim
        rootList.append(tempDict)  # preorder: node, then left subtree, then right subtree
        self.transferTreeToList(root.left, rootList)
        self.transferTreeToList(root.right, rootList)
        return rootList
```

2.2 Converting to a dict

```python
    def transferTreeToDict(self, root):
        if root is None:
            return None
        # Note: dict keys must be immutable (hashable). A NumPy array or a list
        # cannot be a key, so the node's coordinates are converted to a tuple.
        key = tuple(root.val)
        nodeDict = {key: {}}
        nodeDict[key]["label"] = root.label[0]  # root.label is a NumPy array; index it to get the value
        nodeDict[key]["dim"] = root.dim
        nodeDict[key]["parent"] = root.parent.val if root.parent else None
        nodeDict[key]["left"] = self.transferTreeToDict(root.left)
        nodeDict[key]["right"] = self.transferTreeToDict(root.right)
        return nodeDict
```
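The note about immutable keys is easy to check in isolation. In this small standalone sketch, a NumPy array is rejected as a dict key because it is unhashable, while a tuple of the same coordinates works, and it can even be looked up with plain Python ints.

```python
import numpy as np

point = np.array([7, 2])

try:
    table = {point: "node"}  # NumPy arrays are unhashable, so this raises TypeError
    arrayKeyAccepted = True
except TypeError as err:
    arrayKeyAccepted = False
    print("array key rejected:", err)

nodeDict = {tuple(point): {"label": 0}}  # a tuple of the coordinates is hashable
print(nodeDict[(7, 2)])                  # {'label': 0}
```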

3. Searching the kd-tree

3.1 Finding the leaf node that contains the target point x

```python
    def findtheNearestLeafNode(self, root, x):
        if root is None:  # alternatively, simply check whether self.dataNum == 0
            return None
        if root.left is None and root.right is None:
            return root
        node = root
        while True:  # descend until reaching a leaf, or a node missing the child on x's side
            curDim = node.dim
            if x[curDim] < node.val[curDim]:
                if not node.left:
                    return node
                node = node.left
            else:
                if not node.right:
                    return node
                node = node.right
```
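As the loop comment says, the descent can stop at a node that is not a true leaf: the child on x's side is missing while the other child exists. The standalone sketch below uses a hypothetical `KdNode` dataclass and `descend` function (not the `Node` class above) to rebuild, by hand, the right subtree of the Section 4 example: [9, 6] split on dimension 1 with left child [8, 1]. The query (9, 6.5) stops at (9, 6) even though (8, 1) still hangs below it, so a correct search must also consider that remaining subtree.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class KdNode:  # hypothetical minimal node, for illustration only
    point: Tuple[float, ...]
    dim: int
    left: Optional["KdNode"] = None
    right: Optional["KdNode"] = None


def descend(node, x):
    """Hypothetical helper: apply the splitting rule downward, return where it stops."""
    while True:
        child = node.left if x[node.dim] < node.point[node.dim] else node.right
        if child is None:  # the child on x's side does not exist
            return node
        node = child


subtree = KdNode((9, 6), dim=1, left=KdNode((8, 1), dim=0))
stop = descend(subtree, (9, 6.5))
print(stop.point)             # (9, 6)
print(stop.left is not None)  # True: a subtree below the stopping node is still unexamined
```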

3.2 Searching for the k nearest neighbors

Searching for k neighbors differs from nearest-neighbor search in only one way: a list holds the current k best candidates, and the test is no longer against the nearest distance but against the k-th smallest distance (the "gatekeeper" of the result). A node enters the result list only if the list holds fewer than k nodes or the node is closer to the input instance than the k-th smallest distance.

```python
    def knnSearch(self, x, k):
        if self.root is None:  # empty tree: nothing to search
            return None
        # If the training set has no more than k samples, every sample is a neighbor:
        # count the classes in a dict and decide by majority vote.
        if self.dataNum <= k:
            labelDict = {}
            for element in self.transferTreeToList(self.root):
                if element["label"] not in labelDict:
                    labelDict[element["label"]] = 0
                labelDict[element["label"]] += 1
            sortedLabelList = sorted(labelDict.items(), key=lambda item: item[1], reverse=True)  # sorting dict items gives a list of tuples
            return sortedLabelList[0][0]
        # Otherwise find the nearest leaf first, then work back up toward the root
        x = np.array(x)
        node = self.findtheNearestLeafNode(self.root, x)
        distance = np.sqrt(np.sum((x - node.val) ** 2))  # distance between that leaf and the input instance
        nodeList = [[distance, tuple(node.val), node.label[0]]]  # each entry: [distance, point, class]
        # The start node may still have a child on the far side of its splitting plane
        onlyChild = node.left if node.left else node.right
        if onlyChild and (k > len(nodeList) or abs(x[node.dim] - node.val[node.dim]) < distance):
            self.search(nodeList, onlyChild, x, k)
        while node is not self.root:  # stop once we are back at the root
            parentNode = node.parent
            # Current gatekeeper: the k-th smallest distance, or the largest one while fewer than k are stored
            distance = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
            parentDis = np.sqrt(np.sum((x - parentNode.val) ** 2))  # distance between x and the parent
            if k > len(nodeList) or parentDis < distance:
                # fewer than k candidates, or the parent beats the gatekeeper: add it
                nodeList.append([parentDis, tuple(parentNode.val), parentNode.label[0]])
                nodeList.sort()
                distance = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
            # The other child's region can only hold a closer point if the sphere around x
            # with radius `distance` crosses the parent's splitting plane
            if k > len(nodeList) or abs(x[parentNode.dim] - parentNode.val[parentNode.dim]) < distance:
                if x[parentNode.dim] < parentNode.val[parentNode.dim]:
                    otherChild = parentNode.right  # x lies in the left subtree, so search the right one
                else:
                    otherChild = parentNode.left  # otherwise search the left one
                self.search(nodeList, otherChild, x, k)  # recursive neighbor search
            node = parentNode
        # Keep the k nearest, count their classes, and classify x by majority vote
        nodeList.sort()
        labelDict = {}
        for element in nodeList[:k]:
            if element[2] not in labelDict:
                labelDict[element[2]] = 0
            labelDict[element[2]] += 1
        sortedLabel = sorted(labelDict.items(), key=lambda item: item[1], reverse=True)
        return sortedLabel[0][0]

    def search(self, nodeList, root, x, k):
        # Recursive k-NN search inside the subtree rooted at `root`. Almost identical to
        # knnSearch, but without the class vote; nodeList is updated in place.
        if root is None:
            return nodeList
        x = np.array(x)
        nodeList.sort()
        dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
        node = self.findtheNearestLeafNode(root, x)
        distance = np.sqrt(np.sum((x - node.val) ** 2))
        if k > len(nodeList) or distance < dis:
            nodeList.append([distance, tuple(node.val), node.label[0]])
            nodeList.sort()
            dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
        onlyChild = node.left if node.left else node.right
        if onlyChild and (k > len(nodeList) or abs(x[node.dim] - node.val[node.dim]) < dis):
            self.search(nodeList, onlyChild, x, k)
        while node is not root:
            parentNode = node.parent
            dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]  # recursive calls may have changed nodeList
            parentDis = np.sqrt(np.sum((x - parentNode.val) ** 2))
            if k > len(nodeList) or parentDis < dis:
                nodeList.append([parentDis, tuple(parentNode.val), parentNode.label[0]])
                nodeList.sort()
                dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
            if k > len(nodeList) or abs(x[parentNode.dim] - parentNode.val[parentNode.dim]) < dis:
                if x[parentNode.dim] < parentNode.val[parentNode.dim]:
                    otherChild = parentNode.right
                else:
                    otherChild = parentNode.left
                self.search(nodeList, otherChild, x, k)
            node = parentNode
        return nodeList
```
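Keeping a sorted list and reading off the k-th distance is one way to implement the "gatekeeper". A common alternative, shown here only as a standalone sketch, is a bounded max-heap built with Python's `heapq` module, storing negated distances so that the top of the heap is the farthest candidate still kept. The class `KBest` is hypothetical and is not used by the kd-tree code above.

```python
import heapq
import math


class KBest:
    """Hypothetical helper: keep the k nearest candidates seen so far (assumes k >= 1).

    The heap stores (-distance, counter, item), so heap[0] is always the
    farthest candidate still kept, i.e. the gatekeeper.
    """

    def __init__(self, k):
        self.k = k
        self.heap = []
        self.counter = 0  # unique tie-breaker so items themselves are never compared

    def gatekeeper(self):
        """Distance a new candidate must beat; infinite while fewer than k are held."""
        return -self.heap[0][0] if len(self.heap) == self.k else math.inf

    def offer(self, distance, item):
        """Keep the candidate if it belongs to the current k best; return True if kept."""
        if len(self.heap) < self.k:
            heapq.heappush(self.heap, (-distance, self.counter, item))
        elif distance < self.gatekeeper():
            heapq.heapreplace(self.heap, (-distance, self.counter, item))  # evicts the farthest
        else:
            return False
        self.counter += 1
        return True

    def items(self):
        """Kept candidates from nearest to farthest (ties in insertion order)."""
        return [item for _, _, item in sorted(self.heap, key=lambda e: (-e[0], e[1]))]


# Usage on the example data from Section 4: query (6, 3), k = 3
points = [(7, 2), (5, 4), (2, 3), (4, 7), (9, 6), (8, 1)]
best = KBest(3)
for p in points:
    best.offer(math.dist(p, (6, 3)), p)
print(best.items())  # [(7, 2), (5, 4), (8, 1)]
```

Both approaches admit a candidate only while fewer than k are held or when it is strictly closer than the gatekeeper; the heap simply avoids re-sorting the whole list after every insertion.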

4. Example

```python
if __name__ == "__main__":
    dataArray = [[7, 2], [5, 4], [2, 3], [4, 7], [9, 6], [8, 1]]
    label = [[0], [1], [0], [1], [1], [1]]
    kd = kdTree(dataArray, label)
    tree = kd.root  # tree is the root node
    treeList = kd.transferTreeToList(tree)
    treeDict = kd.transferTreeToDict(tree)
    node = kd.findtheNearestLeafNode(tree, [6, 3])
    result = kd.knnSearch([6, 3], 1)
    print(treeList)
    print(result)
    # Output (NumPy 1.x; NumPy 2.x prints the labels as np.int64(0) and np.int64(1)):
    # [{'data': array([7, 2]), 'left': array([5, 4]), 'right': array([9, 6]), 'parent': None, 'label': 0, 'dim': 0},
    #  {'data': array([5, 4]), 'left': array([2, 3]), 'right': array([4, 7]), 'parent': array([7, 2]), 'label': 1, 'dim': 1},
    #  {'data': array([2, 3]), 'left': None, 'right': None, 'parent': array([5, 4]), 'label': 0, 'dim': 0},
    #  {'data': array([4, 7]), 'left': None, 'right': None, 'parent': array([5, 4]), 'label': 1, 'dim': 0},
    #  {'data': array([9, 6]), 'left': array([8, 1]), 'right': None, 'parent': array([7, 2]), 'label': 1, 'dim': 1},
    #  {'data': array([8, 1]), 'left': None, 'right': None, 'parent': array([9, 6]), 'label': 1, 'dim': 0}]
    # 1   (the predicted class)
```
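To cross-check the kd-tree result, an exhaustive k-NN that computes every distance is handy. The function `bruteForceKnn` below is a hypothetical standalone sketch, not part of the post's code: it ranks all training points with NumPy and takes a majority vote over the k nearest. Note that for the query [6, 3], the points [7, 2] and [5, 4] are both at distance √2, so with k = 1 the answer depends on tie-breaking. In the example run, the kd-tree search reaches [5, 4] first and only admits strictly closer points afterwards, so it answers 1; with k = 3 there is no such ambiguity.

```python
import numpy as np
from collections import Counter


def bruteForceKnn(dataSet, labels, x, k):
    """Hypothetical helper: classify x by majority vote over its k nearest points (exhaustive)."""
    data = np.asarray(dataSet, dtype=float)
    labels = np.asarray(labels).ravel()                 # accepts [[0], [1], ...] as well as [0, 1, ...]
    distances = np.linalg.norm(data - np.asarray(x, dtype=float), axis=1)
    nearest = np.argsort(distances, kind="stable")[:k]  # ties keep training-set order
    votes = Counter(labels[nearest].tolist())
    return votes.most_common(1)[0][0]


dataArray = [[7, 2], [5, 4], [2, 3], [4, 7], [9, 6], [8, 1]]
label = [[0], [1], [0], [1], [1], [1]]
print(bruteForceKnn(dataArray, label, [6, 3], 3))       # 1: neighbors [7, 2], [5, 4], [8, 1]
```

With k = 1 this version returns 0, because its stable sort keeps [7, 2] (listed first) ahead of the equally distant [5, 4]; the disagreement with the kd-tree comes purely from tie-breaking.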

5. Complete code


```python
import numpy as np


class Node(object):
    def __init__(self, val=None, label=None, dim=None, left=None, right=None, parent=None):
        self.val = val        # instance point stored in the node
        self.label = label    # class of that point
        self.dim = dim        # splitting dimension
        self.left = left
        self.right = right
        self.parent = parent


class kdTree(object):
    def __init__(self, dataSet, labelList):
        self.dataNum = 0
        self.root = self.buildKdTree(dataSet, labelList)  # note how the parent is passed: the root gets parentNode=None

    def buildKdTree(self, dataSet, labelList, parentNode=None):
        data = np.array(dataSet)
        if len(data) == 0:
            return None
        dataNum, dimNum = data.shape
        label = np.array(labelList).reshape(dataNum, 1)
        varList = self.getVar(data)
        mid = dataNum // 2
        maxVarDimIndex = varList.index(max(varList))
        sortedDataIndex = data[:, maxVarDimIndex].argsort()
        midDataIndex = sortedDataIndex[mid]
        if dataNum == 1:
            self.dataNum = dataNum
            return Node(val=data[midDataIndex], label=label[midDataIndex], dim=maxVarDimIndex, parent=parentNode)
        root = Node(data[midDataIndex], label[midDataIndex], maxVarDimIndex, parent=parentNode)
        leftDataSet = data[sortedDataIndex[:mid]]  # note: slice by mid, not by midDataIndex
        leftLabel = label[sortedDataIndex[:mid]]
        rightDataSet = data[sortedDataIndex[mid + 1:]]
        rightLabel = label[sortedDataIndex[mid + 1:]]
        root.left = self.buildKdTree(leftDataSet, leftLabel, parentNode=root)
        root.right = self.buildKdTree(rightDataSet, rightLabel, parentNode=root)
        self.dataNum = dataNum
        return root

    def transferTreeToDict(self, root):
        if root is None:
            return None
        key = tuple(root.val)  # dict keys must be immutable (hashable)
        nodeDict = {key: {}}
        nodeDict[key]["label"] = root.label[0]  # root.label is an array; index it to get the value
        nodeDict[key]["dim"] = root.dim
        nodeDict[key]["parent"] = root.parent.val if root.parent else None
        nodeDict[key]["left"] = self.transferTreeToDict(root.left)
        nodeDict[key]["right"] = self.transferTreeToDict(root.right)
        return nodeDict

    def transferTreeToList(self, root, rootList=None):
        if rootList is None:  # avoid a shared mutable default argument
            rootList = []
        if root is None:
            return rootList
        tempDict = {}
        tempDict["data"] = root.val
        tempDict["left"] = root.left.val if root.left else None
        tempDict["right"] = root.right.val if root.right else None
        tempDict["parent"] = root.parent.val if root.parent else None
        tempDict["label"] = root.label[0]
        tempDict["dim"] = root.dim
        rootList.append(tempDict)
        self.transferTreeToList(root.left, rootList)
        self.transferTreeToList(root.right, rootList)
        return rootList

    def getVar(self, data):
        rowLen, colLen = data.shape
        varList = []
        for i in range(colLen):
            varList.append(np.var(data[:, i]))
        return varList

    def findtheNearestLeafNode(self, root, x):
        if root is None:  # alternatively, simply check whether self.dataNum == 0
            return None
        if root.left is None and root.right is None:
            return root
        node = root
        while True:
            curDim = node.dim
            if x[curDim] < node.val[curDim]:
                if not node.left:
                    return node
                node = node.left
            else:
                if not node.right:
                    return node
                node = node.right

    def knnSearch(self, x, k):
        if self.root is None:
            return None
        if self.dataNum <= k:
            labelDict = {}
            for element in self.transferTreeToList(self.root):
                if element["label"] not in labelDict:
                    labelDict[element["label"]] = 0
                labelDict[element["label"]] += 1
            sortedLabelList = sorted(labelDict.items(), key=lambda item: item[1], reverse=True)  # sorting dict items gives a list of tuples
            return sortedLabelList[0][0]
        x = np.array(x)
        node = self.findtheNearestLeafNode(self.root, x)
        distance = np.sqrt(np.sum((x - node.val) ** 2))
        nodeList = [[distance, tuple(node.val), node.label[0]]]
        onlyChild = node.left if node.left else node.right  # the start node may still have a child on the far side
        if onlyChild and (k > len(nodeList) or abs(x[node.dim] - node.val[node.dim]) < distance):
            self.search(nodeList, onlyChild, x, k)
        while node is not self.root:
            parentNode = node.parent
            distance = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]  # current gatekeeper
            parentDis = np.sqrt(np.sum((x - parentNode.val) ** 2))
            if k > len(nodeList) or parentDis < distance:
                nodeList.append([parentDis, tuple(parentNode.val), parentNode.label[0]])
                nodeList.sort()
                distance = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
            if k > len(nodeList) or abs(x[parentNode.dim] - parentNode.val[parentNode.dim]) < distance:
                if x[parentNode.dim] < parentNode.val[parentNode.dim]:
                    otherChild = parentNode.right
                else:
                    otherChild = parentNode.left
                self.search(nodeList, otherChild, x, k)
            node = parentNode
        nodeList.sort()
        labelDict = {}
        for element in nodeList[:k]:
            if element[2] not in labelDict:
                labelDict[element[2]] = 0
            labelDict[element[2]] += 1
        sortedLabel = sorted(labelDict.items(), key=lambda item: item[1], reverse=True)
        return sortedLabel[0][0]

    def search(self, nodeList, root, x, k):
        if root is None:
            return nodeList
        x = np.array(x)
        nodeList.sort()
        dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
        node = self.findtheNearestLeafNode(root, x)
        distance = np.sqrt(np.sum((x - node.val) ** 2))
        if k > len(nodeList) or distance < dis:
            nodeList.append([distance, tuple(node.val), node.label[0]])
            nodeList.sort()
            dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
        onlyChild = node.left if node.left else node.right
        if onlyChild and (k > len(nodeList) or abs(x[node.dim] - node.val[node.dim]) < dis):
            self.search(nodeList, onlyChild, x, k)
        while node is not root:
            parentNode = node.parent
            dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]  # recursive calls may have changed nodeList
            parentDis = np.sqrt(np.sum((x - parentNode.val) ** 2))
            if k > len(nodeList) or parentDis < dis:
                nodeList.append([parentDis, tuple(parentNode.val), parentNode.label[0]])
                nodeList.sort()
                dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
            if k > len(nodeList) or abs(x[parentNode.dim] - parentNode.val[parentNode.dim]) < dis:
                if x[parentNode.dim] < parentNode.val[parentNode.dim]:
                    otherChild = parentNode.right
                else:
                    otherChild = parentNode.left
                self.search(nodeList, otherChild, x, k)
            node = parentNode
        return nodeList


if __name__ == "__main__":
    dataArray = [[7, 2], [5, 4], [2, 3], [4, 7], [9, 6], [8, 1]]
    label = [[0], [1], [0], [1], [1], [1]]
    kd = kdTree(dataArray, label)
    tree = kd.root  # tree is the root node
    treeList = kd.transferTreeToList(tree)
    treeDict = kd.transferTreeToDict(tree)
    node = kd.findtheNearestLeafNode(tree, [6, 3])
    result = kd.knnSearch([6, 3], 1)
    print(treeList)
    print(treeDict)
    print(result)
    print(node.val)
```

Copyright notice
This article was written by [乖乖的函数]; please include a link to the original when reposting. Thanks.
https://blog.csdn.net/ggdhs/article/details/93738324
