## KD tree optimization of k-nearest neighbor algorithm (construction and search of KD tree) -- Based on Python

## Preface

I covered the principles behind the kd-tree in an earlier blog post, "k-nearest-neighbor algorithm optimized with a kd-tree".
Reference article: wenffe, "Implementing KD trees in Python".

## 1. Building the kd-tree

``````import numpy as np
class Node(object):
    """A single kd-tree node.

    Attributes:
        val: the instance point stored in the node.
        label: the class of that instance.
        dim: the dimension this node splits on.
        left: left subtree (or None).
        right: right subtree (or None).
        parent: parent node (or None).
    """

    def __init__(self, val=None, label=None, dim=None,
                 left=None, right=None, parent=None):
        # Payload of the node.
        self.val, self.label, self.dim = val, label, dim
        # Links to the neighbouring nodes in the tree.
        self.left, self.right = left, right
        self.parent = parent
class kdTree(object):
    """kd-tree built from a training set.

    Attributes:
        dataNum: number of samples; rewritten by every non-empty
            ``buildKdTree`` call, so after a top-level call it equals the
            size of the data set passed to that call (0 for an empty tree).
        root: root node of the constructed tree, or None when the training
            set is empty.
    """

    def __init__(self, dataSet, labelList):
        self.dataNum = 0
        # The root is built with parentNode=None, so its ``parent`` is None.
        self.root = self.buildKdTree(dataSet, labelList)

    def buildKdTree(self, dataSet, labelList, parentNode=None):
        """Recursively build a kd-tree and return its root node.

        Args:
            dataSet: 2-D array-like of samples, one row per sample.
            labelList: labels, one per sample (any shape that can be
                reshaped to ``(n, 1)``).
            parentNode: node to record as the parent of the returned root.

        Returns:
            The root ``Node`` of the (sub)tree, or None for an empty data set.
        """
        data = np.array(dataSet)
        # Check emptiness BEFORE looking at the shape: ``np.array([])`` is
        # one-dimensional, so unpacking two dimensions from it would raise.
        if data.size == 0:
            return None
        dataNum = data.shape[0]
        label = np.array(labelList).reshape(dataNum, 1)

        # Split on the dimension with the largest variance, at the median.
        varList = self.getVar(data)
        maxVarDimIndex = varList.index(max(varList))
        sortedDataIndex = data[:, maxVarDimIndex].argsort()
        mid = dataNum // 2
        midDataIndex = sortedDataIndex[mid]

        root = Node(val=data[midDataIndex], label=label[midDataIndex],
                    dim=maxVarDimIndex, parent=parentNode)
        if dataNum > 1:
            # Slice with ``mid`` (a position in the sorted order), not with
            # ``midDataIndex`` (a row index in the unsorted data).
            leftIndex = sortedDataIndex[:mid]
            rightIndex = sortedDataIndex[mid + 1:]
            root.left = self.buildKdTree(data[leftIndex], label[leftIndex],
                                         parentNode=root)
            root.right = self.buildKdTree(data[rightIndex], label[rightIndex],
                                          parentNode=root)
        # The outermost call assigns last, so the final value is the total.
        self.dataNum = dataNum
        return root

    def root(self):
        # Kept for backward compatibility only: on instances this method is
        # shadowed by the ``root`` attribute assigned in ``__init__``.
        return self.root

    def getVar(self, data):
        """Return a list with the variance of each column of 2-D ``data``."""
        rowLen, colLen = data.shape
        return [np.var(data[:, i]) for i in range(colLen)]
``````

## 2. Converting the kd-tree to a list and a dict

### 2.1 Converting to a list

`````` """
Every element of the list is a dictionary that represents one node.
The keys of each dictionary are: the value of the node, its split dimension, its label,
the values of its left and right children, and the value of its parent node.
"""
def transferTreeToList(self, root, rootList=None):
    """Flatten the subtree rooted at ``root`` into a list of dictionaries.

    Each node becomes one dictionary with the keys ``data``, ``left``,
    ``right``, ``parent``, ``label`` and ``dim``; ``left``/``right``/``parent``
    hold the neighbouring node's value (or None). Nodes are appended in
    pre-order (node, left subtree, right subtree).

    Args:
        root: root node of the subtree to flatten.
        rootList: optional list to append to. A fresh list is created when
            omitted, so separate calls never share results.

    Returns:
        The list that was appended to, or None when ``root`` is None.
    """
    if root is None:
        return None
    if rootList is None:
        # A ``[]`` default would be created once and shared by every call.
        rootList = []
    pending = [root]
    while pending:
        node = pending.pop()
        rootList.append({
            "data": node.val,
            "left": node.left.val if node.left else None,
            "right": node.right.val if node.right else None,
            "parent": node.parent.val if node.parent else None,
            "label": node.label[0],
            "dim": node.dim,
        })
        # Push right first so the left subtree is emitted first (pre-order).
        if node.right is not None:
            pending.append(node.right)
        if node.left is not None:
            pending.append(node.left)
    return rootList
``````

### 2.2 Converting to a dictionary

def transferTreeToDict(self, root):
    """Convert the subtree rooted at ``root`` into nested dictionaries.

    The result has a single key, the node's value as a tuple (dictionary keys
    must be immutable, so an array or list cannot be used). It maps to a
    dictionary with the keys ``label``, ``dim``, ``parent``, ``left`` and
    ``right``; ``left`` and ``right`` are the converted subtrees.

    Returns None when ``root`` is None.
    """
    def convert(node):
        if node is None:
            return None
        key = tuple(node.val)
        info = {}
        # ``label`` is an indexable container; take its first element so a
        # plain value is stored.
        info["label"] = node.label[0]
        info["dim"] = node.dim
        info["parent"] = None if not node.parent else node.parent.val
        info["left"] = convert(node.left)
        info["right"] = convert(node.right)
        return {key: info}

    return convert(root)
``````

## 3. Searching the kd-tree

### 3.1 Finding the leaf node for x

def findtheNearestLeafNode(self, root, x):
    """Descend from ``root`` towards ``x`` and return the node where it stops.

    At each node the split dimension decides the direction: left when
    ``x[dim] < val[dim]``, right otherwise. The walk ends at a leaf, or at
    the first node that has no child on the chosen side, so the result is
    not necessarily a true leaf. Returns None for an empty tree.
    """
    if root is None:
        return None
    # A childless root is returned without looking at ``x`` at all.
    if root.left is None and root.right is None:
        return root
    current = root
    while True:
        axis = current.dim
        goLeft = x[axis] < current.val[axis]
        child = current.left if goLeft else current.right
        if not child:
            # No subtree on the side x falls into: stop here.
            return current
        current = child
``````

### 3.2 Searching for the k nearest neighbors

`````` """
Searching for the k nearest neighbors differs from the nearest-neighbor search in only one way: an array is needed to hold the current k nearest points,
and the threshold is not the smallest distance but the k-th smallest distance (the "gatekeeper" of the result).
A node enters the result array only if the array holds fewer than k nodes, or if the distance between the node and the input instance is smaller than the k-th smallest distance.
"""
def knnSearch(self, x, k):
    """Classify ``x`` by majority vote among its ``k`` nearest training points.

    When the tree holds ``k`` points or fewer, every point takes part in the
    vote. Ties between labels are resolved in favour of the label whose
    first supporter is nearest to ``x``.

    Args:
        x: query point (sequence or array with one value per dimension).
        k: number of neighbours to consult; must be at least 1.

    Returns:
        The winning label, or None when the tree is empty.

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")
    neighbours = []
    self.search(neighbours, self.root, x, k)
    if not neighbours:
        # Empty tree: there is nothing to vote with.
        return None
    votes = {}
    for _distance, _point, category in neighbours[:k]:
        votes[category] = votes.get(category, 0) + 1
    # ``max`` returns the first label reaching the top count, i.e. the one
    # that entered the distance-sorted neighbour list earliest.
    return max(votes, key=votes.get)


def search(self, nodeList, root, x, k):
    """Collect the ``k`` nearest neighbours of ``x`` from the subtree ``root``.

    ``nodeList`` is updated in place and kept sorted by distance. Each entry
    is ``[distance, tuple(node.val), node.label[0]]``. Entries already in the
    list take part in the pruning decisions, and the list may end up holding
    more than ``k`` entries; only the first ``k`` are the nearest neighbours.

    Args:
        nodeList: list of candidate entries, modified in place.
        root: root node of the subtree to search (may be None).
        x: query point.
        k: number of neighbours wanted.

    Returns:
        ``nodeList`` itself.
    """
    if root is None or k < 1:
        return nodeList
    query = np.asarray(x)

    def distanceOf(entry):
        return entry[0]

    # Sort on the distance only, so labels/points never have to be comparable.
    nodeList.sort(key=distanceOf)

    def visit(node):
        if node is None:
            return
        axis = node.dim
        gap = query[axis] - node.val[axis]
        if gap < 0:
            near, far = node.left, node.right
        else:
            near, far = node.right, node.left

        # 1. Descend first into the half-space that contains the query.
        visit(near)

        # 2. Consider the node itself: accepted while fewer than k candidates
        #    exist or when it beats the current k-th smallest distance.
        distance = float(np.sqrt(np.sum((query - node.val) ** 2)))
        if len(nodeList) < k or distance < nodeList[k - 1][0]:
            nodeList.append([distance, tuple(node.val), node.label[0]])
            nodeList.sort(key=distanceOf)

        # 3. The other half-space can only hold a better candidate when the
        #    splitting plane is closer than the current k-th distance.
        if len(nodeList) < k or abs(gap) < nodeList[k - 1][0]:
            visit(far)

    visit(root)
    return nodeList
``````

## 4. Example

if __name__ == "__main__":
    # Small 2-D demo: six training points with one label each.
    dataArray = [[7, 2], [5, 4], [2, 3], [4, 7], [9, 6], [8, 1]]
    label = [[0], [1], [0], [1], [1], [1]]
    kd = kdTree(dataArray, label)
    tree = kd.root  # root node, already built by the constructor
    # Names chosen so the builtins ``list`` and ``dict`` are not shadowed.
    nodeList = kd.transferTreeToList(tree, [])
    nodeDict = kd.transferTreeToDict(tree)
    leafNode = kd.findtheNearestLeafNode(tree, [6, 3])
    result = kd.knnSearch([6, 3], 1)
    print(nodeList)
    print(result)
    # The output is:
    # [
    # {'data': array([7, 2]), 'left': array([5, 4]), 'right': array([9, 6]), 'parent': None, 'label': 0, 'dim': 0},
    # {'data': array([5, 4]), 'left': array([2, 3]), 'right': array([4, 7]), 'parent': array([7, 2]), 'label': 1, 'dim': 1},
    # {'data': array([2, 3]), 'left': None, 'right': None, 'parent': array([5, 4]), 'label': 0, 'dim': 0},
    # {'data': array([4, 7]), 'left': None, 'right': None, 'parent': array([5, 4]), 'label': 1, 'dim': 0},
    # {'data': array([9, 6]), 'left': array([8, 1]), 'right': None, 'parent': array([7, 2]), 'label': 1, 'dim': 1},
    # {'data': array([8, 1]), 'left': None, 'right': None, 'parent': array([9, 6]), 'label': 1, 'dim': 0}]
    # Category is: 1
``````

## 5. Complete code

