简介
一个KDTree的例子
分割的概念
树的建立
kd树节点的结构如下(字段名 - 类型:含义):

    struct kdtree {
        Node-data - 数据矢量:数据集中的某个数据点,是 n 维矢量(这里也就是 k 维)
        Range     - 空间矢量:该节点所代表的空间范围
        split     - 整数:垂直于分割超平面的方向轴序号
        Left      - kd 树:由位于该节点分割超平面左子空间内所有数据点所构成的 k-d 树
        Right     - kd 树:由位于该节点分割超平面右子空间内所有数据点所构成的 k-d 树
        parent    - kd 树:父节点
    }
构建示例
二维样例:{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}
构建步骤:
1、初始化分割轴:
分别计算各维坐标的方差:x 坐标 {2,5,9,4,8,7} 的方差约为 5.81,y 坐标 {3,4,6,7,1,2} 的方差约为 4.47。x 轴的方差较大,所以最开始的分割轴为 x 轴。
2、确定当前节点:
对{2,5,9,4,8,7}找中位数:排序后为{2,4,5,7,8,9},中间的5和7都可以作为中位数,这里我们选择7(即排序后下标为 size/2 的元素),也就是点(7,2);
3、划分双支数据:
在x轴维度上,将其余各点的x坐标与7比较:小于等于7的划入左支,大于7的划入右支:
左支:{(2,3),(5,4),(4,7)}
右支:{(9,6),(8,1)}
4、更新分割轴:
一共只有两个维度,所以下一个分割轴是y轴。(这里按维度轮换选取分割轴;代码清单中的实现则在每一层重新按方差选取分割轴,对本例两者结果相同。)
5、确定子节点:
左节点:在左支中找到y轴的中位数(5,4),左支数据更新为{(2,3)},右支数据更新为{(4,7)}
右节点:在右支中找到y轴的中位数(9,6)(只有两个点时取排序后下标为 size/2 的元素),左支数据更新为{(8,1)},右支数据为空。
6、更新分割轴:
下一个维度为x轴。
7、确定(5,4)的子节点:
左节点:由于只有一个数据,所以,左节点为(2,3)
右节点:由于只有一个数据,所以,右节点为(4,7)
8、确定(9,6)的子节点:
左节点:由于只有一个数据,所以,左节点为(8,1)
右节点:右节点为空。
最终,整棵kd-tree就构建完成了:根节点为(7,2),其左子节点(5,4)的两个子节点为(2,3)和(4,7),右子节点(9,6)只有左子节点(8,1)。
最近邻搜索
再举一个稍微复杂的例子:查找点(2,4.5)的最近邻。从根节点(7,2)开始,由于2 ≤ 7,进入左子树到达(5,4);在(5,4)处,由于4.5 > 4,进入右子树到达(4,7)。此时search_path中的结点为<(7,2), (5,4), (4,7)>,从search_path中取出(4,7)作为当前最佳结点nearest,dist为3.202(即√((2−4)²+(4.5−7)²)=√10.25);
然后回溯至(5,4):以(2,4.5)为圆心、以dist=3.202为半径的圆与超平面y=4相交(圆心到该超平面的距离为|4.5−4|=0.5 < 3.202),所以需要跳到(5,4)的左子空间去搜索,即将(2,3)加入到search_path中,现在search_path中的结点为<(7,2), (2,3)>;另外,(5,4)与(2,4.5)的距离为3.04 < dist = 3.202,所以将(5,4)赋给nearest,并且dist=3.04。
回溯至(2,3):(2,3)是叶子节点,直接判断(2,3)是否离(2,4.5)更近,计算得到距离为1.5,所以nearest更新为(2,3),dist更新为1.5。
回溯至(7,2):同理,以(2,4.5)为圆心、以dist=1.5为半径的圆不与超平面x=7相交(|2−7|=5 > 1.5),所以不用跳到结点(7,2)的右子空间去搜索。此时search_path为空,搜索结束,最近邻为(2,3),距离为1.5。
代码清单
#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <stack>
#include <vector>

/* Function of this program: build a 2d tree using the input training data.
   The input is exm_set, which contains a list of tuples (x, y);
   the output is a 2d tree pointer. The tree is then used to answer
   nearest-neighbour queries read from standard input.

   NOTE: `using namespace std;` is deliberately not used. Under C++17 it brings
   std::data into the global namespace, which makes every use of the
   `data` struct below ambiguous. */

// A 2-D point.
struct data {
    double x = 0;
    double y = 0;
};

// One kd-tree node: the splitting point, the split axis (0 = x, 1 = y) and both subtrees.
struct Tnode {
    struct data dom_elt;
    int split;
    struct Tnode *left;
    struct Tnode *right;
};

// Orders points by their x coordinate.
bool cmp1(data a, data b) { return a.x < b.x; }

// Orders points by their y coordinate.
bool cmp2(data a, data b) { return a.y < b.y; }

// True when both coordinates of a and b are exactly equal.
bool equal(data a, data b) { return a.x == b.x && a.y == b.y; }

// Chooses the split axis as the dimension with the larger variance (x is chosen
// only when its variance is strictly larger) and returns, via SplitChoice, the
// median point along that axis (the element at index size / 2 in sorted order).
// exm_set[0..size) is reordered in place. Requires size > 0.
void ChooseSplit(data exm_set[], int size, int &split, data &SplitChoice) {
    // Two-pass variance (mean first, then squared deviations). This avoids the
    // cancellation error of the one-pass E[X^2] - (E[X])^2 formula.
    double mean_x = 0;
    double mean_y = 0;
    for (int i = 0; i < size; ++i) {
        mean_x += exm_set[i].x;
        mean_y += exm_set[i].y;
    }
    mean_x /= size;
    mean_y /= size;

    double var_x = 0;
    double var_y = 0;
    for (int i = 0; i < size; ++i) {
        const double dx = exm_set[i].x - mean_x;
        const double dy = exm_set[i].y - mean_y;
        var_x += dx * dx;
        var_y += dy * dy;
    }
    // Both sums share the same 1/size factor, so they can be compared directly.
    split = var_x > var_y ? 0 : 1;

    // Only the median is needed, so a partial selection replaces the full sort.
    data *median = exm_set + size / 2;
    if (split == 0) {
        std::nth_element(exm_set, median, exm_set + size, cmp1);
    } else {
        std::nth_element(exm_set, median, exm_set + size, cmp2);
    }
    SplitChoice = *median;
}

// Recursively builds a kd-tree over exm_set[0..size) and returns its root
// (nullptr for an empty set). exm_set is reordered. The incoming value of T is
// ignored; the parameter is kept only so existing callers still compile.
// Points with coordinate <= the split value go left, greater ones go right.
// Points identical to a node's split point are not stored again in its subtrees.
Tnode *build_kdtree(data exm_set[], int size, Tnode *T) {
    if (size <= 0) {
        return nullptr;
    }
    int split;
    data dom_elt;
    ChooseSplit(exm_set, size, split, dom_elt);

    // Sized by the input. The former fixed 100-element buffers overflowed for larger sets.
    std::vector<data> exm_set_left;
    std::vector<data> exm_set_right;
    for (int i = 0; i < size; ++i) {
        const data &p = exm_set[i];
        if (equal(p, dom_elt)) {
            continue;
        }
        const double coord = split == 0 ? p.x : p.y;
        const double pivot = split == 0 ? dom_elt.x : dom_elt.y;
        if (coord <= pivot) {
            exm_set_left.push_back(p);
        } else {
            exm_set_right.push_back(p);
        }
    }

    T = new Tnode();
    T->dom_elt = dom_elt;
    T->split = split;
    // Pass nullptr rather than the (formerly uninitialized) child pointers.
    T->left = build_kdtree(exm_set_left.data(), static_cast<int>(exm_set_left.size()), nullptr);
    T->right = build_kdtree(exm_set_right.data(), static_cast<int>(exm_set_right.size()), nullptr);
    return T;
}

// Releases every node of the tree rooted at T.
static void free_kdtree(Tnode *T) {
    if (T == nullptr) {
        return;
    }
    free_kdtree(T->left);
    free_kdtree(T->right);
    delete T;
}

// Euclidean distance between a and b.
double Distance(data a, data b) {
    const double dx = a.x - b.x;
    const double dy = a.y - b.y;
    return std::sqrt(dx * dx + dy * dy);
}

// Pushes node and every node on the path from it towards target onto path,
// using the same side rule as the tree build (<= goes left), until a null child.
static void descend(Tnode *node, const data &target, std::stack<Tnode *> &path) {
    while (node != nullptr) {
        path.push(node);
        const bool go_left = node->split == 0 ? target.x <= node->dom_elt.x
                                              : target.y <= node->dom_elt.y;
        node = go_left ? node->left : node->right;
    }
}

// Finds the point stored in the tree rooted at Kd that is nearest to target.
// On return, nearestpoint holds that point and distance holds its Euclidean distance.
// For an empty tree, distance is set to +infinity and nearestpoint is left unchanged.
void searchNearest(Tnode *Kd, data target, data &nearestpoint, double &distance) {
    // 1. Empty tree: there is nothing to find.
    if (Kd == nullptr) {
        distance = std::numeric_limits<double>::infinity();
        return;
    }

    // 2. Walk down to a leaf, recording the path.
    std::stack<Tnode *> search_path;
    descend(Kd, target, search_path);

    data nearest = Kd->dom_elt;  // placeholder; replaced by the first popped node
    double dist = std::numeric_limits<double>::infinity();

    // 3. Backtrack. When a node is popped, everything on its near side has
    //    already been visited, so only its own point and its far side remain.
    while (!search_path.empty()) {
        Tnode *pBack = search_path.top();
        search_path.pop();

        const double d = Distance(pBack->dom_elt, target);
        if (d < dist) {
            dist = d;
            nearest = pBack->dom_elt;
        }

        // Signed distance from target to pBack's splitting hyperplane.
        const double diff = pBack->split == 0 ? target.x - pBack->dom_elt.x
                                              : target.y - pBack->dom_elt.y;
        // If the circle of radius dist around target crosses the hyperplane, the
        // other side may hold a closer point. Descend that side all the way down:
        // pushing only its root (as before) never searched the root's near subtree.
        if (std::fabs(diff) < dist) {
            Tnode *far_side = diff <= 0 ? pBack->right : pBack->left;
            descend(far_side, target, search_path);
        }
    }

    nearestpoint = nearest;
    distance = dist;
}

int main() {
    std::vector<data> exm_set;  // grows as needed; the old fixed array held at most 100 points
    double x;
    double y;
    std::cout << "Please input the training data in the form x y. One instance per line. Enter -1 -1 to stop." << '\n';
    while (std::cin >> x >> y) {
        // Stop only on the sentinel announced in the prompt (both coordinates -1).
        if (x == -1 && y == -1) {
            break;
        }
        data point;
        point.x = x;
        point.y = y;
        exm_set.push_back(point);
    }

    Tnode *root = build_kdtree(exm_set.data(), static_cast<int>(exm_set.size()), nullptr);
    if (root == nullptr) {
        std::cout << "No training data." << '\n';
        return 0;
    }

    data nearestpoint;
    double distance;
    data target;
    std::cout << "Enter search point" << '\n';
    while (std::cin >> target.x >> target.y) {
        searchNearest(root, target, nearestpoint, distance);
        std::cout << "The nearest distance is " << distance << ",and the nearest point is "
                  << nearestpoint.x << "," << nearestpoint.y << '\n';
        std::cout << "Enter search point" << '\n';
    }

    free_kdtree(root);
    return 0;
}
《本文》有 0 条评论