## KD-tree optimization of the k-nearest-neighbor algorithm (construction and search of a KD tree) — in Python

Posted 2020-11-13 02:53:31
Tags: kd tree, optimization, k-nearest neighbor

## Preface

For the implementation principle of kd trees, see my earlier post on the kd-tree-optimized k-nearest-neighbor algorithm.
Reference article: wenffe, "Implementing KD Trees in Python".

## 1. The structure of the kd tree

``````import numpy as np
class Node(object):
    """
    A single kd-tree node.

    Attributes:
        val:    the instance point stored at this node
        label:  the class of the instance at this node (kept as a length-1 array)
        dim:    the dimension this node splits on
        left:   the node's left subtree
        right:  the node's right subtree
        parent: the node's parent (None for the root)
    """
    def __init__(self, val=None, label=None, dim=None, left=None, right=None, parent=None):
        self.val = val
        self.label = label
        self.dim = dim
        self.left = left
        self.right = right
        self.parent = parent

    def __repr__(self):
        # Added for debuggability: show the point, its label and split dimension.
        return "Node(val=%r, label=%r, dim=%r)" % (self.val, self.label, self.dim)
class kdTree(object):
    """
    A kd tree built from a training set.

    Attributes:
        dataNum: number of samples in the training set
        root:    root Node of the constructed kd tree
    """
    def __init__(self, dataSet, labelList):
        self.dataNum = 0
        self.root = self.buildKdTree(dataSet, labelList)  # parent links are set during the build

    def buildKdTree(self, dataSet, labelList, parentNode=None):
        """
        Recursively build the kd tree and return the root Node.

        The split dimension is the one with the largest variance, and the
        median point along that dimension becomes the subtree root.
        Returns None for an empty data set.
        """
        data = np.array(dataSet)
        if data.size == 0:
            # Guard before unpacking data.shape: np.array([]) is 1-D, so the
            # original `dataNum, dimNum = data.shape` raised ValueError for
            # an empty top-level data set.
            return None
        dataNum, dimNum = data.shape           # sample count, dimensionality
        label = np.array(labelList).reshape(dataNum, 1)
        varList = self.getVar(data)            # variance of every dimension
        mid = dataNum // 2                     # median position
        maxVarDimIndex = varList.index(max(varList))  # split on the widest dimension
        sortedDataIndex = data[:, maxVarDimIndex].argsort()
        midDataIndex = sortedDataIndex[mid]    # the median point becomes this subtree's root
        if dataNum == 1:                       # a single point is a leaf
            self.dataNum = dataNum
            return Node(val=data[midDataIndex], label=label[midDataIndex],
                        dim=maxVarDimIndex, left=None, right=None, parent=parentNode)
        root = Node(data[midDataIndex], label[midDataIndex], maxVarDimIndex, parent=parentNode)
        # Split the remaining points into the two subtrees and recurse.
        # (Note: slice with `mid`, not `midDataIndex`.)
        leftDataSet = data[sortedDataIndex[:mid]]
        leftLabel = label[sortedDataIndex[:mid]]
        rightDataSet = data[sortedDataIndex[mid + 1:]]
        rightLabel = label[sortedDataIndex[mid + 1:]]
        root.left = self.buildKdTree(leftDataSet, leftLabel, parentNode=root)
        root.right = self.buildKdTree(rightDataSet, rightLabel, parentNode=root)
        self.dataNum = dataNum                 # record the training-set size
        return root

    def root(self):
        # NOTE(review): dead code — __init__ rebinds `self.root` to the root
        # Node, which shadows this method on every instance; kept unchanged
        # to preserve the class interface.
        return self.root

    def getVar(self, data):
        """Return a list with the variance of each column (dimension) of `data`."""
        rowLen, colLen = data.shape
        return [np.var(data[:, i]) for i in range(colLen)]
``````

## 2. Converting the kd tree to a list and a dict

### 2.1 Converting to a list

`````` """
Each element of the list is a dictionary describing one node. The keys of the dictionary are:
the node's point, its split dimension, its class label, and the points of its left child, right child and parent node.
Every dictionary represents one node.
"""
def transferTreeToList(self, root, rootList=None):
    """
    Flatten the subtree rooted at `root` into a list of per-node dicts,
    in pre-order.  Each dict holds the node's point, the points of its
    left/right children and parent, its class label and split dimension.
    Returns None when called on an empty subtree.

    Fixed: `rootList` used a mutable default argument (`rootList=[]`),
    so repeated calls kept appending to the same shared list.
    """
    if rootList is None:
        rootList = []
    if root is None:
        return None
    rootList.append({
        "data": root.val,
        "left": root.left.val if root.left else None,
        "right": root.right.val if root.right else None,
        "parent": root.parent.val if root.parent else None,
        "label": root.label[0],  # label is a length-1 array; take the scalar
        "dim": root.dim,
    })
    self.transferTreeToList(root.left, rootList)
    self.transferTreeToList(root.right, rootList)
    return rootList
``````

### 2.2 Converting to a dict

def transferTreeToDict(self, root):
    """
    Convert the subtree rooted at `root` into a nested dict and return it
    (None for an empty subtree).

    Dictionary keys must be hashable, so each node's point is converted to
    a tuple before being used as a key.

    Fixed: no longer shadows the builtin `dict`, and `tuple(root.val)` is
    computed once instead of once per field.
    """
    if root is None:
        return None
    key = tuple(root.val)
    # `root.label` is a length-1 array, so take the scalar with [0].
    return {
        key: {
            "label": root.label[0],
            "dim": root.dim,
            "parent": root.parent.val if root.parent else None,
            "left": self.transferTreeToDict(root.left),
            "right": self.transferTreeToDict(root.right),
        }
    }
``````

## 3. Searching the kd tree

### 3.1 Finding the leaf node closest to x

def findtheNearestLeafNode(self, root, x):
    """
    Descend from `root`, at each node comparing x against the node's value
    in that node's split dimension, and return the node where the descent
    stops: a leaf, or a node missing the child the descent would enter.
    Returns None for an empty tree.
    """
    if root is None:
        return None
    current = root
    while True:
        splitDim = current.dim
        nextNode = current.left if x[splitDim] < current.val[splitDim] else current.right
        if nextNode is None:
            return current
        current = nextNode
``````

### 3.2 Searching for the k nearest neighbors

`````` """
Below is the search for the k nearest neighbors. The only difference from the plain nearest-neighbor algorithm is that an array holds the current k nearest points found so far,
and the acceptance test uses the k-th smallest distance (the "gatekeeper" of the result list) rather than the single nearest distance:
a node enters the result array only while it holds fewer than k entries, or when its distance to the input instance is smaller than that k-th smallest distance.
"""
def knnSearch(self,x,k):
    """
    Return the majority class label among the k nearest neighbours of x.

    A result list keeps the current k closest points found so far; the
    pruning test uses the k-th smallest distance (the "gatekeeper")
    rather than the single nearest distance.  A point enters the list
    only while it holds fewer than k entries, or when the point is closer
    than the current k-th smallest distance.
    """
    # When the whole training set holds no more than k samples, every
    # sample is a neighbour: count the labels and take the majority.
    if self.dataNum <= k:
        labelDict = {}
        for element in self.transferTreeToList(self.root):
            if element["label"] not in labelDict:
                labelDict[element['label']] = 0
            labelDict[element["label"]] += 1
        # Sorting dict items yields a list of (label, count) tuples.
        sortedLabelList = sorted(labelDict.items(), key=lambda item:item[1],reverse=True)
        return sortedLabelList[0][0]
    # Otherwise: first descend to the nearest leaf, then backtrack upwards.
    node = self.findtheNearestLeafNode(self.root,x)
    nodeList = []
    if node == None:  # empty tree: nothing to search
        return None
    x = np.array(x)
    # Distance between the input instance and the nearest leaf node.
    distance = np.sqrt(sum((x-node.val)**2))
    # Each result entry is [distance, point (as a tuple), class label].
    nodeList.append([distance, tuple(node.val), node.label[0]])
    while True:  # backtrack towards the root
        if node == self.root:  # stop once the root has been processed
            break
        parentNode = node.parent
        # Distance between the input instance x and the parent node.
        parentDis = np.sqrt(sum((x-parentNode.val)**2))
        # Accept the parent while fewer than k results are held, or when
        # it is closer than the current gatekeeper distance.
        if k > len(nodeList) or distance > parentDis:
            nodeList.append([parentDis,tuple(parentNode.val),parentNode.label[0]])
            nodeList.sort()
            # Update the gatekeeper: the k-th smallest distance once k
            # entries exist, otherwise the largest distance held so far.
            distance = nodeList[-1][0] if k > len(nodeList) else nodeList[k-1][0]
        # The sibling subtree can only hold a closer point when the
        # splitting hyperplane is nearer than the gatekeeper distance.
        if k > len(nodeList) or abs(x[parentNode.dim] - parentNode.val[parentNode.dim]) < distance:
            if x[parentNode.dim] < parentNode.val[parentNode.dim]:
                # x lies in the parent's left subtree, so examine the right.
                otherChild = parentNode.right
                self.search(nodeList,otherChild,x,k)
            else:
                # x lies in the parent's right subtree, so examine the left.
                otherChild = parentNode.left
                self.search(nodeList, otherChild, x, k)
        node = node.parent
    labelDict = {}  # count class labels among the k nearest and vote
    nodeList = nodeList[:k] if k <= len(nodeList) else nodeList
    for element in nodeList:
        if element[2] not in labelDict:
            labelDict[element[2]] = 0
        labelDict[element[2]] += 1
    sortedLabel = sorted(labelDict.items(),key=lambda x:x[1],reverse=True)
    return sortedLabel[0][0]
def search(self, nodeList, root, x, k):
    """
    Recursive helper for knnSearch: examine the subtree rooted at `root`
    for points that belong in `nodeList` (the current k-nearest list,
    mutated in place).  Same backtracking scheme as knnSearch, but without
    the final majority vote.

    Bug fix: the sibling-subtree test read `parentNode.val[parentNode.val]`
    (indexing a point by itself, a runtime TypeError); it must index by the
    node's split dimension, mirroring knnSearch.
    """
    if root is None:
        return nodeList
    nodeList.sort()
    # Gatekeeper: the k-th smallest distance once k entries exist,
    # otherwise the largest distance held so far.
    dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
    x = np.array(x)
    node = self.findtheNearestLeafNode(root, x)
    distance = np.sqrt(sum((x - node.val) ** 2))
    if k > len(nodeList) or distance < dis:
        nodeList.append([distance, tuple(node.val), node.label[0]])
        nodeList.sort()
        dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
    while True:  # backtrack from the leaf up to this subtree's root
        if node == root:
            break
        parentNode = node.parent
        parentDis = np.sqrt(sum((x - parentNode.val) ** 2))
        if k > len(nodeList) or parentDis < dis:
            nodeList.append([parentDis, tuple(parentNode.val), parentNode.label[0]])
            nodeList.sort()
            dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
        # Only descend into the sibling subtree when the splitting
        # hyperplane is closer than the gatekeeper distance.
        if k > len(nodeList) or abs(x[parentNode.dim] - parentNode.val[parentNode.dim]) < dis:
            if x[parentNode.dim] < parentNode.val[parentNode.dim]:  # fixed: was val[parentNode.val]
                otherChild = parentNode.right
                self.search(nodeList, otherChild, x, k)
            else:
                otherChild = parentNode.left
                self.search(nodeList, otherChild, x, k)
        node = node.parent
``````

### 4. An example

``````if __name__ == "__main__":
dataArray = [[7, 2], [5, 4], [2, 3], [4, 7], [9, 6], [8, 1]]
label = [[0], [1], [0], [1], [1], [1]]
kd = kdTree(dataArray, label)
Tree = kd.buildKdTree(dataArray, label) ## tree Root node
list = kd.transferTreeToList(Tree, [])
dict = kd.transferTreeToDict(Tree)
node = kd.findtheNearestLeafNode(Tree, [6, 3])
result = kd.knnSearch([6,3],1)
print(list)
print(result)
"""
The output is ：[
{'data': array([7, 2]), 'left': array([5, 4]), 'right': array([9, 6]), 'parent': None, 'label': 0, 'dim': 0},
{'data': array([5, 4]), 'left': array([2, 3]), 'right': array([4, 7]), 'parent': array([7, 2]), 'label': 1, 'dim': 1},
{'data': array([2, 3]), 'left': None, 'right': None, 'parent': array([5, 4]), 'label': 0, 'dim': 0},
{'data': array([4, 7]), 'left': None, 'right': None, 'parent': array([5, 4]), 'label': 1, 'dim': 0},
{'data': array([9, 6]), 'left': array([8, 1]), 'right': None, 'parent': array([7, 2]), 'label': 1, 'dim': 1},
{'data': array([8, 1]), 'left': None, 'right': None, 'parent': array([9, 6]), 'label': 1, 'dim': 0}]
"""
# Category is ：1
``````

## 5. Complete code

``````
```python
import numpy as np
class Node(object):
    """
    A single kd-tree node.

    Attributes:
        val:    the instance point stored at this node
        label:  the class of the instance (kept as a length-1 array)
        dim:    the dimension this node splits on
        left:   the node's left subtree
        right:  the node's right subtree
        parent: the node's parent (None for the root)
    """
    def __init__(self, val=None, label=None, dim=None, left=None, right=None, parent=None):
        self.val = val
        self.label = label
        self.dim = dim
        self.left = left
        self.right = right
        self.parent = parent

    def __repr__(self):
        # Added for debuggability.
        return "Node(val=%r, label=%r, dim=%r)" % (self.val, self.label, self.dim)


class kdTree(object):
    """
    A kd tree supporting construction and k-nearest-neighbour search.

    Attributes:
        dataNum: number of samples in the training set
        root:    root Node of the constructed tree
    """
    def __init__(self, dataSet, labelList):
        self.dataNum = 0
        self.root = self.buildKdTree(dataSet, labelList)  # parent links are set during the build

    def buildKdTree(self, dataSet, labelList, parentNode=None):
        """
        Recursively build the kd tree and return the root Node.

        The split dimension is the one with the largest variance, and the
        median point along that dimension becomes the subtree root.
        Returns None for an empty data set.
        """
        data = np.array(dataSet)
        if data.size == 0:
            # Guard before unpacking data.shape: np.array([]) is 1-D, so the
            # original `dataNum, dimNum = data.shape` raised ValueError for
            # an empty top-level data set.
            return None
        dataNum, dimNum = data.shape           # sample count, dimensionality
        label = np.array(labelList).reshape(dataNum, 1)
        varList = self.getVar(data)            # variance of every dimension
        mid = dataNum // 2                     # median position
        maxVarDimIndex = varList.index(max(varList))  # split on the widest dimension
        sortedDataIndex = data[:, maxVarDimIndex].argsort()
        midDataIndex = sortedDataIndex[mid]    # the median point becomes this subtree's root
        if dataNum == 1:                       # a single point is a leaf
            self.dataNum = dataNum
            return Node(val=data[midDataIndex], label=label[midDataIndex],
                        dim=maxVarDimIndex, left=None, right=None, parent=parentNode)
        root = Node(data[midDataIndex], label[midDataIndex], maxVarDimIndex, parent=parentNode)
        # Split the remaining points into the two subtrees and recurse.
        # (Note: slice with `mid`, not `midDataIndex`.)
        leftDataSet = data[sortedDataIndex[:mid]]
        leftLabel = label[sortedDataIndex[:mid]]
        rightDataSet = data[sortedDataIndex[mid + 1:]]
        rightLabel = label[sortedDataIndex[mid + 1:]]
        root.left = self.buildKdTree(leftDataSet, leftLabel, parentNode=root)
        root.right = self.buildKdTree(rightDataSet, rightLabel, parentNode=root)
        self.dataNum = dataNum                 # record the training-set size
        return root

    def root(self):
        # NOTE(review): dead code — __init__ rebinds `self.root` to the root
        # Node, which shadows this method on every instance; kept unchanged
        # to preserve the class interface.
        return self.root

    def transferTreeToDict(self, root):
        """
        Convert the subtree rooted at `root` into a nested dict and return
        it (None for an empty subtree).  Keys must be hashable, so each
        node's point is converted to a tuple.

        Fixed: no longer shadows the builtin `dict`; `tuple(root.val)` is
        computed once instead of once per field.
        """
        if root is None:
            return None
        key = tuple(root.val)
        # `root.label` is a length-1 array, so take the scalar with [0].
        return {
            key: {
                "label": root.label[0],
                "dim": root.dim,
                "parent": root.parent.val if root.parent else None,
                "left": self.transferTreeToDict(root.left),
                "right": self.transferTreeToDict(root.right),
            }
        }

    def transferTreeToList(self, root, rootList=None):
        """
        Flatten the subtree rooted at `root` into a list of per-node dicts
        in pre-order; None for an empty subtree.

        Fixed: `rootList` used a mutable default argument, so repeated
        calls kept appending to the same shared list.
        """
        if rootList is None:
            rootList = []
        if root is None:
            return None
        rootList.append({
            "data": root.val,
            "left": root.left.val if root.left else None,
            "right": root.right.val if root.right else None,
            "parent": root.parent.val if root.parent else None,
            "label": root.label[0],
            "dim": root.dim,
        })
        self.transferTreeToList(root.left, rootList)
        self.transferTreeToList(root.right, rootList)
        return rootList

    def getVar(self, data):
        """Return a list with the variance of each column of `data`."""
        rowLen, colLen = data.shape
        return [np.var(data[:, i]) for i in range(colLen)]

    def findtheNearestLeafNode(self, root, x):
        """
        Descend from `root` along split dimensions and return the node
        where the descent stops (a leaf, or a node missing the child the
        descent would enter); None for an empty tree.
        """
        if root is None:
            return None
        node = root
        while True:
            curDim = node.dim
            nextNode = node.left if x[curDim] < node.val[curDim] else node.right
            if nextNode is None:
                return node
            node = nextNode

    def knnSearch(self, x, k):
        """
        Return the majority class label among the k nearest neighbours of x.

        A result list keeps the current k closest points; the pruning test
        uses the k-th smallest distance found so far (the "gatekeeper")
        rather than the single nearest distance.
        """
        # With no more than k training samples, every sample is a
        # neighbour: count the labels and take the majority directly.
        if self.dataNum <= k:
            labelDict = {}
            for element in self.transferTreeToList(self.root):
                if element["label"] not in labelDict:
                    labelDict[element["label"]] = 0
                labelDict[element["label"]] += 1
            sortedLabelList = sorted(labelDict.items(), key=lambda item: item[1], reverse=True)
            return sortedLabelList[0][0]
        # Descend to the nearest leaf first, then backtrack towards the root.
        node = self.findtheNearestLeafNode(self.root, x)
        if node is None:  # empty tree
            return None
        x = np.array(x)
        distance = np.sqrt(sum((x - node.val) ** 2))
        # Each result entry is [distance, point-as-tuple, class label].
        nodeList = [[distance, tuple(node.val), node.label[0]]]
        while node != self.root:
            parentNode = node.parent
            parentDis = np.sqrt(sum((x - parentNode.val) ** 2))
            # Accept the parent while fewer than k results are held, or
            # when it is closer than the current gatekeeper distance.
            if k > len(nodeList) or distance > parentDis:
                nodeList.append([parentDis, tuple(parentNode.val), parentNode.label[0]])
                nodeList.sort()
                distance = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
            # The sibling subtree can only hold a closer point when the
            # splitting hyperplane is nearer than the gatekeeper distance.
            if k > len(nodeList) or abs(x[parentNode.dim] - parentNode.val[parentNode.dim]) < distance:
                if x[parentNode.dim] < parentNode.val[parentNode.dim]:
                    self.search(nodeList, parentNode.right, x, k)  # x lies left: try the right subtree
                else:
                    self.search(nodeList, parentNode.left, x, k)   # x lies right: try the left subtree
            node = node.parent
        # Majority vote over the k nearest entries.
        labelDict = {}
        nodeList = nodeList[:k] if k <= len(nodeList) else nodeList
        for element in nodeList:
            if element[2] not in labelDict:
                labelDict[element[2]] = 0
            labelDict[element[2]] += 1
        sortedLabel = sorted(labelDict.items(), key=lambda item: item[1], reverse=True)
        return sortedLabel[0][0]

    def search(self, nodeList, root, x, k):
        """
        Recursive helper for knnSearch: examine the subtree rooted at
        `root` for points that belong in `nodeList` (mutated in place).

        Bug fix: the sibling-subtree test read
        `parentNode.val[parentNode.val]` (indexing a point by itself, a
        runtime TypeError); it must index by the node's split dimension.
        """
        if root is None:
            return nodeList
        nodeList.sort()
        dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
        x = np.array(x)
        node = self.findtheNearestLeafNode(root, x)
        distance = np.sqrt(sum((x - node.val) ** 2))
        if k > len(nodeList) or distance < dis:
            nodeList.append([distance, tuple(node.val), node.label[0]])
            nodeList.sort()
            dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
        while node != root:  # backtrack from the leaf up to this subtree's root
            parentNode = node.parent
            parentDis = np.sqrt(sum((x - parentNode.val) ** 2))
            if k > len(nodeList) or parentDis < dis:
                nodeList.append([parentDis, tuple(parentNode.val), parentNode.label[0]])
                nodeList.sort()
                dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
            if k > len(nodeList) or abs(x[parentNode.dim] - parentNode.val[parentNode.dim]) < dis:
                if x[parentNode.dim] < parentNode.val[parentNode.dim]:  # fixed: was val[parentNode.val]
                    self.search(nodeList, parentNode.right, x, k)
                else:
                    self.search(nodeList, parentNode.left, x, k)
            node = node.parent
if __name__ == "__main__":
    # Demo: build the tree from the article's sample data, convert it to a
    # list and a dict, then run a 1-nearest-neighbour query for [6, 3].
    trainData = [[7, 2], [5, 4], [2, 3], [4, 7], [9, 6], [8, 1]]
    trainLabels = [[0], [1], [0], [1], [1], [1]]
    kd = kdTree(trainData, trainLabels)
    tree = kd.buildKdTree(trainData, trainLabels)  # root node of a freshly built tree
    flatTree = kd.transferTreeToList(tree, [])
    treeDict = kd.transferTreeToDict(tree)
    nearestLeaf = kd.findtheNearestLeafNode(tree, [6, 3])
    result = kd.knnSearch([6, 3], 1)
    print(flatTree)
    print(treeDict)
    print(result)
    print(nearestLeaf.val)
``````
``````
``````