Spoj ORDERSET – CPP solution

Problem: http://www.spoj.com/problems/ORDERSET/

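At first glance the problem looks easy to solve with std::set. However, std::set has no order-statistic queries: finding the k-th smallest element, or counting the elements smaller than x, means walking the set element by element, which is too slow here.
Instead, one can use a custom balanced binary search tree in which every node also stores the size of its subtree. I used a treap.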

#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <algorithm>
#include <limits>
#include <set>
typedef long Key;
typedef long Priority;

struct Treap;
Treap * root;  // root of the global treap; equal to `nil` when the set is empty
Treap * nil;   // shared sentinel used in place of every missing child (size 0)

// Treap node: a binary search tree ordered by `key` and a max-heap ordered by
// `priority`, with `size` caching the number of real nodes in the subtree so
// rank queries (k-th element, count of smaller keys) run in O(depth).
struct Treap {
    Key key;
    Priority priority;
    Treap * left;
    Treap * right;
    int size;
    // Sentinel constructor. Every field is initialized: callers compare
    // against nil->priority, so an indeterminate value would be undefined
    // behaviour. The lowest possible priority guarantees the sentinel never
    // wins a heap-order comparison against a real node. The children point
    // back at the sentinel itself instead of at garbage.
    Treap() :
        key(), priority(std::numeric_limits<Priority>::min()), left(this), right(this), size(0) {}
    // Single-node subtree holding `_key`; both children are the sentinel.
    Treap(Key _key, Priority _priority) :
        key(_key), priority(_priority), left(nil), right(nil), size(1) {}
    // Childless node with an explicitly supplied cached size.
    Treap(Key _key, Priority _priority, int _size) :
        key(_key), priority(_priority), left(nil), right(nil), size(_size) {}
};

// Debug helper: pre-order dump of the subtree rooted at `node`, one line per node.
// `lvl` is the depth reported for `node`; `l` tells whether `node` is a left child.
// The format strings previously ended in a literal 'n' instead of '\n', so all
// nodes were printed on one line.
void print(Treap * node, int lvl, bool l) {
    if (node == nil) return;
    printf("%s, Key: %ld, Priority: %ld, Level: %d, Size: %d\n",
           l ? "Left" : "Right", node->key, node->priority, lvl, node->size);
    print(node->left, lvl + 1, true);
    print(node->right, lvl + 1, false);
}

// Recomputes the cached subtree size of `cur` from its children's sizes
// (the sentinel child contributes its own size field) plus cur itself.
void updateSize(Treap * cur) {
    int total = 1;
    total += cur->left->size;
    total += cur->right->size;
    cur->size = total;
}

// Right rotation: `leftChild` (cur's left child) is lifted above `cur`.
// The caller is responsible for re-linking cur's former parent to leftChild.
void rotateRight(Treap * cur, Treap * leftChild) {
    Treap * moved = leftChild->right;  // subtree that switches parent
    leftChild->right = cur;
    cur->left = moved;
    // cur is now below leftChild, so its size must be fixed first.
    updateSize(cur);
    updateSize(leftChild);
}

// Left rotation: `rightChild` (cur's right child) is lifted above `cur`.
// The caller is responsible for re-linking cur's former parent to rightChild.
void rotateLeft(Treap * cur, Treap * rightChild) {
    Treap * moved = rightChild->left;  // subtree that switches parent
    rightChild->left = cur;
    cur->right = moved;
    // cur is now below rightChild, so its size must be fixed first.
    updateSize(cur);
    updateSize(rightChild);
}

// Inserts the detached `node` into the non-empty subtree rooted at `cur` and
// returns the (possibly new) subtree root. Rotations restore heap order on the
// way back up. If the key is already present the subtree is returned unchanged
// and `node` is not linked anywhere.
Treap * add(Treap * cur, Treap * node) {
    if (cur->key == node->key) {
        return cur;
    }
    if (node->key < cur->key) {
        // Descend left; the returned subtree root becomes cur's new left child.
        Treap * child = cur->left;
        if (child == nil) {
            child = node;
        } else {
            child = add(child, node);
            if (child == nil) return nil;
        }
        cur->left = child;
        if (cur->priority < child->priority) {
            rotateRight(cur, child);  // child floats above cur
            return child;
        }
    } else {
        // Descend right; the returned subtree root becomes cur's new right child.
        Treap * child = cur->right;
        if (child == nil) {
            child = node;
        } else {
            child = add(child, node);
            if (child == nil) return nil;
        }
        cur->right = child;
        if (cur->priority < child->priority) {
            rotateLeft(cur, child);  // child floats above cur
            return child;
        }
    }
    updateSize(cur);
    return cur;
}

// Inserts `key` into the global treap with a random priority; does nothing if
// the key is already stored.
void add(Key key) {
    // Look the key up first: add(Treap*, Treap*) ignores duplicates without
    // linking the new node, so allocating unconditionally leaked one node per
    // duplicate insert.
    for (Treap * probe = root; probe != nil;
         probe = (key < probe->key) ? probe->left : probe->right) {
        if (probe->key == key) return;
    }
    Priority priority = rand();
    Treap * node = new Treap(key, priority);
    if (root == nil) {
        root = node;
    } else {
        root = add(root, node);
    }
}

// Removes `key` from the subtree rooted at `cur` (no-op if absent) and returns
// the new subtree root. The node to delete is rotated down, always lifting its
// higher-priority child, until it becomes a leaf; then it is freed.
Treap * rem(Treap * cur, Key key) {
    if (cur == nil) return nil;  // key not present in this subtree
    if (cur->key == key) {
        if (cur->left == nil && cur->right == nil) {
            delete cur;
            return nil;
        }
        // Lift the right child only if it is a real node and either the left
        // child is missing or the right one has the higher priority. Checking
        // `cur->right != nil` explicitly avoids comparing a real priority
        // against the sentinel's, which previously could rotate with `nil`
        // itself and corrupt the sentinel.
        bool liftRight = cur->right != nil &&
                         (cur->left == nil || cur->right->priority > cur->left->priority);
        if (liftRight) {
            Treap * rht = cur->right;
            rotateLeft(cur, rht);
            rht->left = rem(cur, key);
            cur = rht;
        } else {
            Treap * lft = cur->left;
            rotateRight(cur, lft);
            lft->right = rem(cur, key);
            cur = lft;
        }
    } else if (cur->key > key) {
        cur->left = rem(cur->left, key);
    } else {
        cur->right = rem(cur->right, key);
    }
    updateSize(cur);
    return cur;
}

// Removes `key` from the global treap (no-op if absent) and records the new root.
void rem(Key key) {
    Treap * newRoot = rem(root, key);
    root = newRoot;
}

// Returns the node holding the k-th smallest key (1-based) in the subtree
// rooted at `cur`, or `nil` when no such node exists.
Treap * kth(Treap * cur, int k) {
    while (cur != nil) {
        int rank = cur->left->size + 1;  // position of cur within its subtree
        if (k == rank) {
            return cur;
        }
        if (k < rank) {
            cur = cur->left;
        } else {
            k -= rank;  // skip cur and its whole left subtree
            cur = cur->right;
        }
    }
    return nil;
}

// Returns how many keys in the subtree rooted at `cur` are strictly less than `key`.
int cnt(Treap * cur, Key key) {
    int smaller = 0;
    while (cur != nil) {
        if (cur->key < key) {
            // cur and its entire left subtree are smaller; continue right.
            smaller += cur->left->size + 1;
            cur = cur->right;
        } else if (cur->key == key) {
            return smaller + cur->left->size;
        } else {
            cur = cur->left;
        }
    }
    return smaller;
}

// Debug check: validates the cached size, heap order on priorities, and BST
// order on keys for every node under `cur`; prints a message and exits(1) on
// the first violation found.
void test(Treap * cur) {
    if (cur == nil) return;
    int sz = 1 + cur->left->size + cur->right->size;
    if (sz != cur->size) {
        printf("SIZE ERROR\n");
        exit(1);
    }
    // Priority checks skip the sentinel: its priority is not a real node's.
    if (cur->left != nil && cur->left->priority > cur->priority) {
        printf("LEFT PRIORITY ERROR\n");
        exit(1);
    }
    if (cur->right != nil && cur->right->priority > cur->priority) {
        printf("RIGHT PRIORITY ERROR\n");
        exit(1);
    }
    if (cur->right != nil && cur->right->key < cur->key) {
        printf("RIGHT KEY ERROR\n");
        exit(1);
    }
    if (cur->left != nil && cur->left->key > cur->key) {
        printf("LEFT KEY ERROR\n");
        exit(1);
    }
    test(cur->left);
    test(cur->right);
}

// Reference set for debugging: test() expects it to hold the same keys as the treap.
// NOTE(review): nothing in the code shown inserts into this `ss` (gen() uses its
// own local set), so test() only passes while the treap is empty unless some
// caller keeps it in sync.
std::set<int> ss;

// Debug check: compares the treap's element count with `ss`, then validates
// every node's invariants via test(Treap*).
void test() {
    // root->size is a non-negative int; convert explicitly instead of relying
    // on an implicit signed/unsigned comparison.
    if (static_cast<std::set<int>::size_type>(root->size) != ss.size()) {
        printf("SET SIZE ERROR\n");
        exit(1);
    }
    test(root);
}

// Debug helper: writes a random ORDERSET input to "orderset.test" and the
// expected output to "orderset.ans", using std::set as the reference answer.
void gen() {
    using namespace std;
    FILE * out = fopen("orderset.test", "w+");
    FILE * ans = fopen("orderset.ans", "w+");
    if (out == NULL || ans == NULL) {
        fprintf(stderr, "cannot open orderset.test / orderset.ans\n");
        if (out != NULL) fclose(out);
        if (ans != NULL) fclose(ans);
        return;
    }
    int n = 50000;
    fprintf(out, "%d\n", n);
    set<int> ss;
    set<int>::iterator it;
    for (int i = 0; i < n; ++i) {
        int op = rand() % 100;
        if (op < 15) {
            // delete an existing key; retry this iteration if the set is empty
            if (ss.empty()) {
                --i;
                continue;
            }
            it = ss.begin();
            // Any element, including the last. The old `% (size - 1)` divided
            // by zero for a one-element set and never chose the last element.
            advance(it, rand() % ss.size());
            int key = *it;
            fprintf(out, "D %d\n", key);
            ss.erase(key);
        } else if (op < 60) {
            // insert
            int key = rand();
            ss.insert(key);
            fprintf(out, "I %d\n", key);
        } else if (op < 85) {
            // count keys smaller than `key`
            int key = rand();
            fprintf(out, "C %d\n", key);
            it = ss.lower_bound(key);
            fprintf(ans, "%ld\n", static_cast<long>(distance(ss.begin(), it)));
        } else {
            // k-th smallest; sometimes beyond the size to exercise "invalid".
            // The +1 keeps the modulus non-zero when the set is empty.
            int num = static_cast<int>(rand() % (ss.size() + rand() % 50 + 1)) + 1;
            fprintf(out, "K %d\n", num);
            if (static_cast<set<int>::size_type>(num) > ss.size()) {
                fprintf(ans, "invalid\n");
            } else {
                it = ss.begin();
                advance(it, num - 1);
                fprintf(ans, "%d\n", *it);
            }
        }
    }
    fclose(out);
    fclose(ans);
}

//#define DEBUG

// Reads q operations ("I x", "D x", "C x", "K k") and answers them with the
// global treap. C prints the count of keys smaller than x; K prints the k-th
// smallest key or "invalid". With DEBUG defined it only generates test files.
int main() {
    srand(time(NULL));
    #ifdef DEBUG
    gen();
    return 0;
    #endif
    nil = new Treap();
    root = nil;
    int q;
    if (scanf("%d", &q) != 1) return 0;
    while (q--) {
        char op;
        Key key;
        // " %c" skips whitespace and reads exactly one character, so an
        // over-long token cannot overflow a buffer as "%s" into char[2] could.
        if (scanf(" %c%ld", &op, &key) != 2) break;
        if (op == 'I') {
            add(key);
        } else if (op == 'D') {
            rem(key);
        } else if (op == 'C') {
            printf("%d\n", cnt(root, key));
        } else if (op == 'K') {
            // k < 1 would make kth() return the sentinel, so reject it as well.
            if (key < 1 || key > root->size) {
                printf("invalid\n");
            } else {
                printf("%ld\n", kth(root, static_cast<int>(key))->key);
            }
        }
    }
    return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *