TWCOS Kernel

Artifact Content
Login

Artifact c7d49e1daa41a5721b376e4c9e101bfa4d56f9bf:


#include "tree.h"

#if INTERFACE
/*
 * Balancing strategy tree_put() applies after inserting a new node:
 *   TREE_SPLAY - splay the new node to the root; exact-match lookups through
 *                tree_get_node() also splay the found node.
 *   TREE_TREAP - give the node an address-derived priority and rotate it up
 *                while its priority is smaller than its parent's.
 *   TREE_COUNT - rotate along the insertion path based on subtree counts.
 * Any other value (tree_test() passes 0) uses node_simple_balance().
 */
enum treemode { TREE_SPLAY=1, TREE_TREAP, TREE_COUNT };

#endif

/* Thrown by node_ordinal() when asked for an index outside the subtree. */
exception_def OutOfBoundsException = { "OutOfBoundsException", &RuntimeException };

typedef struct tree_t tree_t;
typedef struct node_t node_t;

/*
 * One tree node.  Each node records the size of the subtree rooted at it so
 * that ordinal lookup (node_ordinal) and tree_size() need no traversal.
 */
struct node_t {
	map_key key;
	map_data data;

	/* Count of nodes, including this one */
	int count;

	/* treap node priority: node_prioritize() keeps smaller values nearer the root */
	int priority;

	/* Nodes connectivity */
	node_t * parent;
	node_t * left;
	node_t * right;
};

/*
 * A tree-backed map.  The embedded map_t is what callers hold; the methods
 * recover the enclosing tree_t with container_of(map, tree_t, map).
 */
struct tree_t {
	map_t map;

	/* Topmost node, or NULL for an empty tree */
	node_t * root;

	/* Balancing strategy (an enum treemode value, or 0 for the simple rebalance) */
	int mode;

	/* Key ordering: negative, zero or positive, as used by tree_put/tree_get_node */
	int (*comp)(map_key k1, map_key k2);
};

/* Forward declarations for helpers defined further down the file. */
static node_t * tree_node_first(tree_t * tree);
static const node_t * node_next( const node_t * const current );
/*
 * Disabled: garbage-collector mark hooks and a no-op finalizer for the
 * slab allocator, plus slab types that registered them.
 * NOTE(review): node_mark does not mark node->parent — check before enabling.
 */
#if 0
static void tree_mark(void * p)
{
	tree_t * tree = (tree_t*)p;
	slab_gc_mark(tree->root);
}

static void node_mark(void * p)
{
	node_t * node = (node_t*)p;
	slab_gc_mark((void*)node->key);
	slab_gc_mark((void*)node->data);
	slab_gc_mark(node->left);
	slab_gc_mark(node->right);
}

void debug_finalize(void * p)
{
}

static slab_type_t nodes[1] = { SLAB_TYPE(sizeof(node_t), node_mark, debug_finalize)};
static slab_type_t trees[1] = { SLAB_TYPE(sizeof(tree_t), tree_mark, debug_finalize)};
#endif
/*
 * Active slab types for nodes and trees.  The two trailing SLAB_TYPE
 * arguments, which the disabled version above filled with mark/finalize
 * callbacks, are 0 here.
 */
static slab_type_t nodes[1] = { SLAB_TYPE(sizeof(node_t), 0, 0)};
static slab_type_t trees[1] = { SLAB_TYPE(sizeof(tree_t), 0, 0)};

/*
 * Rotate left about b:
 *    B            D
 *   / \          / \
 *  A   D   =>   B   E
 *     / \      / \
 *    C   E    A   C
 *
 * b must have a right child d.  d takes b's place under b's former parent
 * (if any); the caller must update tree->root when b was the root.  Only the
 * counts of b and d are recomputed: the rotated subtree still holds the same
 * nodes, so every ancestor's count remains valid.
 */
static void node_rotate_left( node_t * b )
{
        node_t * d = b->right;

        /* Validate d before anything below dereferences it. */
        assert(d);
        assert(d->parent == b);
        assert(NULL == b->parent || b->parent->left == b || b->parent->right == b);

        node_t * a = b->left;
        node_t * c = d->left;
        node_t * e = d->right;

        assert(NULL == c || c->parent == d);

        /* Link d into b's parent */
        d->parent = b->parent;

        /* Reparent b */
        b->parent = d;
        d->left = b;

        /* Move c across to become b's right child */
        b->right = c;
        if (b->right) {
                b->right->parent = b;
        }

        /* Point b's former parent at d; that parent still references b in one slot */
        if (d->parent) {
                if (b == d->parent->left) {
                        d->parent->left = d;
                } else {
                        d->parent->right = d;
                }
        }

        /* Fix node counts: b first, since d's count includes it */
        b->count = 1 + ((a) ? a->count : 0) + ((c) ? c->count : 0);
        d->count = 1 + b->count + ((e) ? e->count : 0);

        assert(NULL == d->parent || d->parent->left == d || d->parent->right == d);
}

/*
 * Rotate right about d:
 *      D        B
 *     / \      / \
 *    B   E => A   D
 *   / \          / \
 *  A   C        C   E
 *
 * d must have a left child b.  b takes d's place under d's former parent
 * (if any); the caller must update tree->root when d was the root.  Only the
 * counts of d and b are recomputed: the rotated subtree still holds the same
 * nodes, so every ancestor's count remains valid.
 */
static void node_rotate_right( node_t * d )
{
        node_t * b = d->left;

        /* Validate b before anything below dereferences it. */
        assert(b);
        assert(b->parent == d);
        assert(NULL == d->parent || d->parent->left == d || d->parent->right == d);

        node_t * a = b->left;
        node_t * c = b->right;
        node_t * e = d->right;

        assert(NULL == c || c->parent == b);

        /* Link b into d's parent */
        b->parent = d->parent;

        /* Reparent d */
        d->parent = b;
        b->right = d;

        /* Move c across to become d's left child */
        d->left = c;
        if (d->left) {
                d->left->parent = d;
        }

        /* Point d's former parent at b; that parent still references d in one slot */
        if (b->parent) {
                if (d == b->parent->left) {
                        b->parent->left = b;
                } else {
                        b->parent->right = b;
                }
        }

        /* Fix node counts: d first, since b's count includes it */
        d->count = 1 + ((c) ? c->count : 0) + ((e) ? e->count : 0);
        b->count = 1 + d->count + ((a) ? a->count : 0);

        assert(NULL == b->parent || b->parent->left == b || b->parent->right == b);
}

static int node_is_left( const node_t * const node )
{
        return (node->parent && node == node->parent->left);
}

static int node_is_right( const node_t * const node )
{
        return (node->parent && node == node->parent->right);
}

/*
 * Splay node to the top of its tree using zig, zig-zig and zig-zag steps.
 *
 * Only node links and counts change; the caller must refresh tree->root
 * afterwards (node ends up with no parent).
 */
static void node_splay( node_t * node )
{
        while(node->parent) {
                if (node_is_left(node)) {
                        if (node_is_left(node->parent)) {
                                /* zig-zig: rotate the grandparent first, then the parent */
                                node_rotate_right(node->parent->parent);
                                node_rotate_right(node->parent);
                        } else if (node_is_right(node->parent)) {
                                /* zig-zag: left child of a right child */
                                node_rotate_right(node->parent);
                                node_rotate_left(node->parent);
                        } else {
                                /* zig: parent is the root */
                                node_rotate_right(node->parent);
                        }
                } else if (node_is_right(node)) {
                        if (node_is_right(node->parent)) {
                                /* zig-zig: rotate the grandparent first, then the parent */
                                node_rotate_left(node->parent->parent);
                                node_rotate_left(node->parent);
                        } else if (node_is_left(node->parent)) {
                                /* zig-zag: right child of a left child (mirror of the case above) */
                                node_rotate_left(node->parent);
                                node_rotate_right(node->parent);
                        } else {
                                /* zig: parent is the root */
                                node_rotate_left(node->parent);
                        }
                }
        }
}

/*
 * Treap insert step: assign node a pseudo-random priority derived from its
 * address, then rotate it upwards while its priority is smaller than its
 * parent's (smaller priorities are kept nearer the root).
 *
 * Why a real hash: an expression like `(addr * K) & 0xff` depends only on
 * the low 8 address bits.  Allocation alignment forces some of those bits to
 * zero, and consecutively allocated nodes step them by a fixed stride.  That
 * gives few distinct priorities that follow allocation order, which defeats
 * the randomisation a treap relies on.  The MurmurHash3 32-bit finaliser
 * below mixes every address bit into the result.
 */
static void node_prioritize( node_t * node )
{
        uintptr_t h = (uintptr_t)node;

        /* Fold bits above 32 into the low word; two 16-bit shifts stay
         * well-defined when uintptr_t is only 32 bits wide. */
        h = (h ^ (h >> 16 >> 16)) & 0xffffffffu;

        /* MurmurHash3 fmix32, computed modulo 2^32 */
        h ^= h >> 16;
        h = (h * 0x85ebca6bu) & 0xffffffffu;
        h ^= h >> 13;
        h = (h * 0xc2b2ae35u) & 0xffffffffu;
        h ^= h >> 16;

        /* 24 bits keeps the value well inside int */
        node->priority = (int)(h & 0xffffff);

        while(node->parent && node->priority < node->parent->priority) {
                if (node_is_left(node)) {
                        node_rotate_right(node->parent);
                } else if (node_is_right(node)) {
                        node_rotate_left(node->parent);
                }
        }
}

/* Number of nodes in the subtree rooted at node; an empty (NULL) subtree has 0. */
static int node_count( node_t * node )
{
        if (NULL == node) {
                return 0;
        }
        return node->count;
}

/*
 * Count-based rebalancing after an insert.  Walk from node towards the root
 * and, at each step, rotate the current node up over its parent when the
 * current node's outer subtree holds more nodes than its sibling subtree.
 * After a rotation the walk continues from the current node's new parent.
 */
static void node_count_balance( node_t * node )
{
        for (node_t * cur = node; cur && cur->parent; cur = cur->parent) {
                node_t * const up = cur->parent;

                if (node_is_right(cur)) {
                        if (node_count(up->left) < node_count(cur->right)) {
                                node_rotate_left(up);
                        }
                } else if (node_is_left(cur)) {
                        if (node_count(up->right) < node_count(cur->left)) {
                                node_rotate_right(up);
                        }
                }
        }
}

/*
 * Default-mode rebalance after inserting node (used when tree->mode is not
 * one of the enum treemode values).
 *
 * Walking from node towards the root, each parent/child pair is classified
 * by a 4-bit pattern:
 *   bit 3 (8) parent has a left child    bit 1 (2) node has a left child
 *   bit 2 (4) parent has a right child   bit 0 (1) node has a right child
 * Only the four "three nodes in a chain" shapes are rewritten; each one
 * becomes a parent with two children.  All other shapes are left alone.
 */
static void node_simple_balance( node_t * node )
{
	node_t * parent = node->parent;

	while(parent) {
		int i = 0;

		i |= (parent->left != 0);
		i <<= 1;
		i |= (parent->right != 0);
		i <<= 1;
		i |= (node->left != 0);
		i <<= 1;
		i |= (node->right != 0);

		switch(i) {
		/* 1001: node is parent's only (left) child and has only a right child (zig-zag) */
		case 9:
                        node_rotate_left(node);
			/* fallthrough: the chain is now a straight left line */
		/* 1010: straight left chain parent -> left -> left */
		case 10:
                        node_rotate_right(parent);
			break;
		/* 0110: node is parent's only (right) child and has only a left child (zig-zag) */
		case 6:
                        node_rotate_right(node);
			/* fallthrough: the chain is now a straight right line */
		/* 0101: straight right chain parent -> right -> right */
		case 5:
                        node_rotate_left(parent);
			break;
		}
		/* Continue upwards from the (possibly rotated) parent */
		node = parent;
		parent = node->parent;
	}
}

/*
 * Allocate a node holding key/data as a leaf under parent (NULL for a new
 * root).  The caller stores the result in the parent's left or right slot.
 * Every ancestor's subtree count is incremented for the new node.
 *
 * priority is not set here; tree_put() sets it through node_prioritize() in
 * TREE_TREAP mode.  NOTE(review): slab_alloc's result is used unchecked.
 */
static node_t * tree_node_new( node_t * parent, map_key key, map_data data )
{
        node_t * fresh = slab_alloc(nodes);

        fresh->key = key;
        fresh->data = data;
        fresh->count = 1;
        fresh->parent = parent;

        for (node_t * up = parent; up; up = up->parent) {
                up->count++;
        }

        fresh->left = NULL;
        fresh->right = NULL;

        return fresh;
}

/*
 * map_t destroy hook.  It currently releases nothing.
 * NOTE(review): nodes are never handed back to the slab here — confirm
 * whether the slab allocator reclaims them some other way.
 */
static void tree_destroy( const map_t * map )
{
	(void)map;
}

/*
 * In-order predecessor of current, or 0 when current holds the smallest key.
 */
static node_t * node_prev( node_t * current )
{
	/* With a left subtree, the predecessor is that subtree's maximum. */
	if (NULL != current->left) {
		node_t * max = current->left;

		for (; max->right; max = max->right) {
			/* keep descending right */
		}
		return max;
	}

	/*
	 * Otherwise climb while we are a left child.  If we then stand in a
	 * right slot, that parent is the predecessor; otherwise there is none.
	 */
	node_t * up = current;
	while (node_is_left(up)) {
		up = up->parent;
	}
	return node_is_right(up) ? up->parent : 0;
}

/*
 * In-order successor of current, or 0 when current holds the largest key.
 */
static const node_t * node_next( const node_t * const current )
{
	/* With a right subtree, the successor is that subtree's minimum. */
	if (NULL != current->right) {
		const node_t * min = current->right;

		for (; min->left; min = min->left) {
			/* keep descending left */
		}
		return min;
	}

	/*
	 * Otherwise climb while we are a right child.  If we then stand in a
	 * left slot, that parent is the successor; otherwise there is none.
	 */
	const node_t * up = current;
	while (node_is_right(up)) {
		up = up->parent;
	}
	return node_is_left(up) ? up->parent : 0;
}

/*
 * Disabled: an older in-order walker over a subtree.  tree_walk_nodes()
 * below performs the same left-to-right traversal between two nodes.
 * NOTE(review): if re-enabled, `node` must become const node_t *, because
 * node_next() returns a const pointer.
 */
#if 0
static void tree_walk_node( node_t * node, walk_func func, void * p )
{
        if (NULL == node) {
                return;
        }

	/* Find left most node */
	while(node->left) {
		node = node->left;
	}

	/* Step through all nodes */
	while(node) {
		func(p, node->key, node->data);
		node = node_next(node);
	}
}
#endif

/* Leftmost (smallest-key) node of the tree, or NULL when the tree is empty. */
static node_t * tree_node_first(tree_t * tree)
{
	node_t * lowest = tree->root;

	if (NULL == lowest) {
		return NULL;
	}
	for (; lowest->left; lowest = lowest->left) {
		/* keep descending left */
	}
	return lowest;
}

/* Rightmost (largest-key) node of the tree, or NULL when the tree is empty. */
static node_t * tree_node_last(tree_t * tree)
{
	node_t * highest = tree->root;

	while (highest && highest->right) {
		highest = highest->right;
	}
	return highest;
}

/*
 * Call func(p, key, data) on start and then on each in-order successor.
 * The walk stops after visiting end (inclusive), or when successors run out
 * if end is never reached.  A NULL start visits nothing.
 */
static void tree_walk_nodes( const node_t * start, const node_t * end, const walk_func func, const void * p)
{
	for (const node_t * cur = start; cur; cur = node_next(cur)) {
		func(p, cur->key, cur->data);
		if (cur == end) {
			break;
		}
	}
}

/* Visit every entry of the map in ascending key order via func(p, key, data). */
void tree_walk( const map_t * map, const walk_func func, const void * p )
{
	tree_t * tree = container_of(map, tree_t, map);
	node_t * first = tree_node_first(tree);
	node_t * last = tree_node_last(tree);

	tree_walk_nodes(first, last, func, p);
}

static const node_t * tree_get_node( tree_t * tree, map_key key, map_eq_test cond );

/*
 * Call func(p, key, data) for every entry with from <= key < to, in
 * ascending key order.  A zero `from` starts at the smallest key; a zero
 * `to` runs through the largest key.
 *
 * Nothing is visited when the range contains no entry.  That happens when
 * no key is >= from, when no key is < to, or when the first candidate sorts
 * after the last (e.g. from > to, or no key lies between them).
 * tree_walk_nodes() stops only when it meets `end`, so these cases must be
 * rejected here.  Otherwise it would run on to the end of the tree.
 */
void tree_walk_range( const map_t * map, walk_func func, const void * p, map_key from, map_key to )
{
        tree_t * tree = container_of(map, tree_t, map);
	const node_t * start = (from) ? tree_get_node(tree, from, MAP_GE) : tree_node_first(tree);
	const node_t * end = (to) ? tree_get_node(tree, to, MAP_LT) : tree_node_last(tree);

	/* Nothing >= from, nothing < to, or an empty tree */
	if (NULL == start || NULL == end) {
		return;
	}

	/* Inverted range: walking from start would never meet end */
	if (tree->comp(start->key, end->key) > 0) {
		return;
	}

        tree_walk_nodes(start, end, func, p);
}

/*
 * Recursively check the subtree rooted at node (NULL is accepted).  Each
 * child must point back to its parent, a node with count 1 must be a leaf,
 * and every count must equal 1 plus its children's counts.  The checks use
 * assert().
 */
static void node_verify( tree_t * tree, node_t * node )
{
	if (NULL == node) {
		return;
	}

	/* A count of one means a leaf */
	if (1 == node->count) {
		assert(0 == node->left);
		assert(0 == node->right);
	}

	int expected = 1;
	node_t * const kids[2] = { node->left, node->right };

	for (int k = 0; k < 2; k++) {
		if (kids[k]) {
			expected += kids[k]->count;
			assert(node == kids[k]->parent);
			node_verify(tree, kids[k]);
		}
	}

	assert(expected == node->count);
}

/*
 * Debug consistency check; tree_put() and tree_remove() call it.
 *
 * When node is non-NULL, climbing its parent links must end at tree->root,
 * i.e. the node is still attached to this tree.  A non-empty tree's root
 * must have no parent.  node_verify() then checks back-links and subtree
 * counts across the whole tree.
 *
 * NOTE(review): node_verify() recurses over every node, so each call costs
 * O(n) whether or not assert() is active.
 */
static void tree_verify( tree_t * tree, node_t * node )
{
	if (NULL != node) {
		const node_t * top = node;

		while (top->parent) {
			top = top->parent;
		}
		assert(tree->root == top);
	}

	/* Per-node counts are checked by node_verify; here only the root's parent link */
	if (NULL != tree->root) {
		assert(tree->root->parent == 0);
	}

	node_verify(tree, tree->root);
}

/*
 * Insert or replace key -> data.
 *
 * If an equal key (per tree->comp) exists, both its stored key and its data
 * are overwritten and the previous data is returned.  Otherwise a new leaf
 * is attached and the balancing strategy selected by tree->mode runs.
 * tree->root is then refreshed, because rotations may have moved another
 * node to the top, and 0 is returned.
 */
static map_data tree_put( const map_t * map, map_key key, map_data data )
{
	tree_t * tree = container_of(map, tree_t, map);
	node_t * above = NULL;
	node_t ** slot = &tree->root;

	tree_verify(tree, NULL);

	/* Descend to the matching node, or to the empty child slot where key belongs */
	while (*slot) {
		node_t * const here = *slot;
		const int diff = tree->comp(key, here->key);

		if (0 == diff) {
			map_data previous = here->data;

			here->key = key;
			here->data = data;
			tree_verify(tree, here);
			return previous;
		}

		above = here;
		slot = (diff < 0) ? &here->left : &here->right;
	}

	node_t * const fresh = tree_node_new(above, key, data);
	*slot = fresh;

	/* Insert-time balancing */
	if (TREE_SPLAY == tree->mode) {
		node_splay(fresh);
	} else if (TREE_TREAP == tree->mode) {
		node_prioritize(fresh);
	} else if (TREE_COUNT == tree->mode) {
		node_count_balance(fresh);
	} else {
		node_simple_balance(fresh);
	}

	/* Either fresh is now on top, or the old root may have gained a parent */
	if (NULL == fresh->parent) {
		tree->root = fresh;
	}
	while (tree->root->parent) {
		tree->root = tree->root->parent;
	}

	tree_verify(tree, fresh);

	return 0;
}

/*
 * Find a node relative to key according to cond:
 *   MAP_LT - largest key strictly less than key
 *   MAP_LE - largest key less than or equal to key
 *   MAP_GT - smallest key strictly greater than key
 *   MAP_GE - smallest key greater than or equal to key
 *   other  - exact match only
 * Returns 0 when no node qualifies.
 *
 * In TREE_SPLAY mode, an exact hit for any cond other than MAP_LT/MAP_GT is
 * splayed to the root and tree->root is updated.
 */
static const node_t * tree_get_node( tree_t * tree, map_key key, map_eq_test cond )
{
	node_t * cursor = tree->root;

	while (cursor) {
		const int diff = tree->comp(key, cursor->key);

		if (0 == diff) {
			/* Exact hit: strict conditions want the neighbour instead */
			switch (cond) {
			case MAP_LT:
				return node_prev(cursor);
			case MAP_GT:
				return node_next(cursor);
			default:
				break;
			}
			if (TREE_SPLAY == tree->mode) {
				node_splay(cursor);
				while (tree->root->parent) {
					tree->root = tree->root->parent;
				}
			}
			return cursor;
		}

		node_t * const child = (diff < 0) ? cursor->left : cursor->right;
		if (child) {
			cursor = child;
			continue;
		}

		/*
		 * key is absent.  cursor is its in-order neighbour: the successor
		 * when key < cursor->key, otherwise the predecessor.
		 */
		switch (cond) {
		case MAP_GT: case MAP_GE:
			return (diff < 0) ? cursor : node_next(cursor);
		case MAP_LT: case MAP_LE:
			return (diff < 0) ? node_prev(cursor) : cursor;
		default:
			return 0;
		}
	}

	return 0;
}

/* Look up key under cond; return the found node's data, or 0 when there is none. */
static map_data tree_get_data( tree_t * tree, map_key key, map_eq_test cond )
{
	const node_t * const hit = tree_get_node(tree, key, cond);

	if (NULL == hit) {
		return 0;
	}
	return hit->data;
}

/* map_t get hook: forwards to tree_get_data() on the enclosing tree_t. */
static map_data tree_get(const map_t * map, map_key key, map_eq_test cond )
{
	tree_t * const owner = container_of(map, tree_t, map);
	map_data const found = tree_get_data(owner, key, cond);

	return found;
}

/*
 * Remove key from the map.
 *
 * Returns the removed entry's data, or 0 when the key is not present.
 *
 * The matching node is rotated down until it has no children, unlinked from
 * its parent, and every ancestor's subtree count is decremented.
 * tree->root is updated whenever a rotation lifts a child into the root
 * position, and cleared when the last node is removed.
 *
 * In TREE_TREAP mode, a node with two children lifts the child with the
 * smaller priority.  node_prioritize() keeps smaller priorities nearer the
 * root.  Always lifting the left child could put a larger priority above a
 * smaller one and break that order for later inserts.
 *
 * NOTE(review): the unlinked node is not handed back to the `nodes` slab
 * here — confirm whether the slab allocator reclaims it.
 */
static map_data tree_remove( const map_t * map, map_key key )
{
        tree_t * tree = container_of(map, tree_t, map);
        node_t * node = tree->root;

        tree_verify(tree, NULL);
        while(node) {
		int diff = tree->comp(key, node->key);

                if (diff<0) {
                        node = node->left;
                } else if (diff>0) {
                        node = node->right;
                } else {
                        map_data data = node->data;
                        node_t * parent = NULL;

                        /* Bubble the node down to a leaf */
                        while(node->left || node->right) {
                                /* By default lift the left child whenever there is one */
                                int lift_left = (NULL != node->left);

                                if (TREE_TREAP == tree->mode && node->left && node->right) {
                                        /* Preserve heap order: lift the smaller-priority child */
                                        lift_left = (node->left->priority <= node->right->priority);
                                }

                                if (lift_left) {
                                        node_rotate_right(node);
                                } else {
                                        node_rotate_left(node);
                                }

                                /* If node was the root, the lifted child is the new root */
                                if (NULL == node->parent->parent) {
                                        tree->root = node->parent;
                                }
                        }

                        /* Node has no children, just unlink it */
                        assert(1 == node->count);
                        if (node->parent && node == node->parent->left) {
                                node->parent->left = NULL;
                        }
                        if (node->parent && node == node->parent->right) {
                                node->parent->right = NULL;
                        }
                        if (NULL == node->parent) {
                                /* It was the only node */
                                tree->root = NULL;
                        }

                        /* Decrement the counts on parent nodes */
                        parent = node->parent;
                        while(parent) {
                                parent->count--;
                                parent = parent->parent;
                        }

                        tree_verify(tree, NULL);

                        return data;
                }
        }

	return 0;
}


/* map_t iterator hook: not implemented; always returns NULL. */
static iterator_t * tree_iterator( const map_t * map)
{
	(void)map;
	return NULL;
}

/*
 * Return the i'th node (0-based, in key order) of the subtree rooted at
 * root.  The per-node subtree counts let it descend in O(height).
 *
 * Throws OutOfBoundsException via KTHROWF when i is negative or not less
 * than the subtree size.  An empty (NULL) subtree has size 0, so every index
 * is out of bounds for it.  It returns 0 if KTHROWF returns.
 */
static node_t * node_ordinal(node_t * root, int i)
{
	/* node_count() treats NULL as size 0 rather than dereferencing it */
	const int size = node_count(root);

	/* A negative index would otherwise keep descending left off the tree */
	if (i < 0) {
		KTHROWF(OutOfBoundsException, "Out of bounds: %d < 0\n", i);
		return 0;
	}
	if (i >= size) {
		KTHROWF(OutOfBoundsException, "Out of bounds: %d >= %d\n", i, size);
		return 0;
	}

	node_t * node = root;
	while(node) {
		int count_left = node_count(node->left);
		if (i<count_left) {
			node = node->left;
		} else if (i == count_left) {
			return node;
		} else {
			/* Skip the left subtree and this node */
			node = node->right;
			i -= (count_left+1);
		}
	}

	/* Unreachable while subtree counts are consistent (see node_verify) */
	return 0;
}

/*
 * Rebuild the subtree rooted at root so that each subtree is rooted at its
 * median node.  The median is rotated up until it sits where root did.
 * Its priority is set to its parent's priority plus 10 (10 at the top), and
 * both halves are then processed recursively.
 * Returns the new subtree root, or 0 for an empty subtree.
 */
static node_t * node_optimize(node_t * root)
{
	if (NULL == root) {
		return 0;
	}

	/* Where this subtree hangs; rotation stops once the median reaches it */
	node_t * const anchor = root->parent;
	node_t * const median = node_ordinal(root, root->count / 2);

	if (anchor) {
		median->priority = anchor->priority + 10;
	} else {
		median->priority = 10;
	}

	for (; median->parent != anchor; ) {
		if (node_is_left(median)) {
			node_rotate_right(median->parent);
		} else {
			node_rotate_left(median->parent);
		}
	}

	median->left = node_optimize(median->left);
	median->right = node_optimize(median->right);

	return median;
}

static void tree_optimize(const map_t * map)
{
        tree_t * tree = container_of(map, tree_t, map);
	tree->root = node_optimize(tree->root);
}

/* map_t size hook: number of entries, read from the root's subtree count. */
static size_t tree_size(const map_t * map)
{
	tree_t * const tree = container_of(map, tree_t, map);

	return (NULL != tree->root) ? (size_t)tree->root->count : 0;
}

/*
 * Module initialisation; tree_new() and tree_test() call it before use.
 * NOTE(review): INIT_ONCE() presumably makes repeat calls a no-op — confirm
 * against its definition.  Nothing else is set up here at present.
 */
void tree_init()
{
	INIT_ONCE();

}

/*
 * Interface plumbing for tree_t.
 *
 * tree_t_map lists the interfaces a tree_t exposes; its single entry ties
 * iid_map_t to the embedded `map` member.  tree_new() passes this table to
 * com_query().
 *
 * The ops table binds each map_t method to its implementation in this file.
 * NOTE(review): tree_new() stores &tree_t_map_t in map.ops, presumably the
 * name INTERFACE_IMPL_NAME(map_t, tree_t) expands to — confirm.
 */
static interface_map_t tree_t_map [] =
{
	INTERFACE_MAP_ENTRY(tree_t, iid_map_t, map),
};
static INTERFACE_IMPL_QUERY(map_t, tree_t, map)
static INTERFACE_OPS_TYPE(map_t) INTERFACE_IMPL_NAME(map_t, tree_t) = {
	INTERFACE_IMPL_QUERY_METHOD(map_t, tree_t)
	INTERFACE_IMPL_METHOD(destroy, tree_destroy)
	INTERFACE_IMPL_METHOD(walk, tree_walk)
	INTERFACE_IMPL_METHOD(walk_range, tree_walk_range)
	INTERFACE_IMPL_METHOD(put, tree_put)
	INTERFACE_IMPL_METHOD(get, tree_get)
	INTERFACE_IMPL_METHOD(optimize, tree_optimize)
	INTERFACE_IMPL_METHOD(remove, tree_remove)
	INTERFACE_IMPL_METHOD(iterator, tree_iterator)
	INTERFACE_IMPL_METHOD(size, tree_size)
};

/*
 * Create an empty ordered map backed by a binary search tree.
 *
 * comp: key ordering returning <0, 0 or >0; NULL selects map_keycmp.
 * mode: insert-time balancing strategy stored in tree->mode (TREE_SPLAY,
 *       TREE_TREAP, TREE_COUNT, or any other value for the simple rebalance).
 *
 * Returns the map_t interface com_query() yields for the new tree.
 * NOTE(review): slab_alloc's result is used unchecked.
 */
map_t * tree_new(int (*comp)(map_key k1, map_key k2), treemode mode)
{
	tree_init();

	tree_t * fresh = slab_alloc(trees);

	fresh->comp = (NULL != comp) ? comp : map_keycmp;
	fresh->mode = mode;
	fresh->root = NULL;
	fresh->map.ops = &tree_t_map_t;

	return com_query(tree_t_map, countof(tree_t_map), iid_map_t, fresh);
}

/* Convenience constructor: a tree map in TREE_SPLAY mode. */
map_t * splay_new(int (*comp)(map_key k1, map_key k2))
{
	map_t * const splay = tree_new(comp, TREE_SPLAY);

	return splay;
}

/* Convenience constructor: a tree map in TREE_TREAP mode. */
map_t * treap_new(int (*comp)(map_key k1, map_key k2))
{
	map_t * const treap = tree_new(comp, TREE_TREAP);

	return treap;
}

/*
 * Debug dump of a subtree in key order, one line per node.  Each line shows
 * the depth, a tab, two spaces of indent per level, and then the node's data.
 * NOTE(review): data is printed with %s, so it is assumed to be a C string.
 */
static void tree_graph_node(node_t * node, int level)
{
	if (NULL == node) {
		return;
	}

	/* A NULL child returns immediately, so no guard is needed */
	tree_graph_node(node->left, level + 1);

	kernel_printk("%d\t", level);
	for (int pad = level; pad > 0; pad--) {
		kernel_printk("  ");
	}
	kernel_printk("%s\n", node->data);

	tree_graph_node(node->right, level + 1);
}

/*
 * Disabled debug walk callback.  It printed each entry's data as a string.
 * When p was non-NULL, it also treated p as a map_t and stored the data
 * there under a key built by map_arraykey2() from the data's first character.
 */
#if 0
static void tree_walk_dump(void * p, void * key, void * data)
{
	kernel_printk("%s\n", data);

	if (p) {
		map_t * akmap = (map_t*)p;
		/* Add the data to the ak map */
		map_key * ak = (map_key*)map_arraykey2((intptr_t)akmap, *((char*)data));
		map_putpp(akmap, ak, data);
	}
}
#endif

void tree_test()
{
	tree_init();
	map_t * map = treap_new(map_strcmp);
	map_t * akmap = tree_new(map_arraycmp, 0);
	map_test(map, akmap);

	tree_graph_node(container_of(akmap, tree_t, map)->root, 0);
	map_optimize(akmap);
	tree_graph_node(container_of(akmap, tree_t, map)->root, 0);
}