Spoj QTREE – CPP solution

Problem: http://www.spoj.com/problems/QTREE/

The keys to solving this problem are:
– some basic graph theory (traversing graphs) (easy)
– LCA on a tree (medium)
– Heavy-light decomposition (medium-hard)
– Segment tree (medium)

My code had a few bugs that took some time to track down. The asserts, the brute-force checker and the random test generator in the code below were written to find out what was going wrong with my solution. :)

Note that this solution runs in O(n log n + Q log n * log n) = O(n log n + Q log^2 n) time, where Q is the number of queries.
I could not get an O(n log n + Q log n) version to work.

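If you want to try it, note that we use the segment tree to find the maximum over (part of) a given chain. On the path from a node up to one of its ancestors there are at most O(log n) chains. The idea is that, except possibly for the first and the last chain, we query the whole chain, so for each chain we could keep a single maximum value that can be retrieved in O(1) time and updated in O(1) time. This would lead to O(log n) per query.
(Caveat: in the code below an intermediate chain is usually entered part-way down, so the query covers only the segment from the entry vertex up to the chain head, not the whole chain. Also, lowering an edge cost can lower a chain's maximum, so that maximum cannot always be updated in O(1).)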

#include <cstdio>
#include <iostream>
#include <fstream>
#include <cstdlib>
#include <ctime>
#include <vector>
#include <iostream>
#include <cassert>
using namespace std;

// An undirected, weighted tree edge between vertices u and v.
struct Edge {
    int u;
    int v;
    int cost;
    // Zero-initialise so a locally constructed Edge never holds
    // indeterminate values.
    Edge() : u(0), v(0), cost(0) {}
    Edge(int _u, int _v, int _cost) : u(_u), v(_v), cost(_cost) {}
    // Returns the endpoint opposite to V. Callers are expected to pass one of
    // the two endpoints; any value other than v yields v.
    int getNb(int V) const {
        return V == v ? u : v;
    }
};
// ----- Input tree and rooted-tree data (vertices are 1..n) -----
int n;                                    // number of vertices in the current test case
vector<int> nb[10001];                    // nb[u]: indices into edges[] of edges touching u
int childSz[10001];                       // childSz[u]: size of u's subtree (set by dfs)
int parent[10001][32];                    // parent[u][k]: 2^k-th ancestor of u, 0 above the root
int depth[10001];                         // depth[u]: the root (vertex 1) has depth 1
Edge edges[10001];                        // edges[i]: the (i+1)-th edge of the input
int LOG;                                  // parent[][k] is filled for k = 0..LOG

// ----- Heavy-light decomposition state (written by hld) -----
int numChains;                            // index of the chain currently being built
char isSpecial[10001];                    // set to 1 by hld() for heavy children
int generatedArrayIDX;                    // next free slot in generatedArray
int generatedArray[10001];                // at vertexPos[u]: cost of edge u -> parent(u); root holds 0
int chainHead[10001], chainSize[10001];   // per chain: topmost vertex / number of vertices
int vertexChain[10001], vertexPos[10001]; // per vertex: chain index / slot in generatedArray

// ----- Max segment tree over generatedArray (root 1, children 2k and 2k+1) -----
int tree[80001];

int slow(int u, int c, int p, int r) {
    if (u == r) {
        return c;
    }
    int res = 0;
    for (int i = 0; i < nb[u].size();i++) {
        int v = edges[nb[u][i]].getNb(u);
        if (p == v) continue;
        int q = slow(v, edges[nb[u][i]].cost, u, r);
        if (q) q = max(q, edges[nb[u][i]].cost);
        res = max(res, q);
    }
    return res;
}

// Clears the per-test-case state for vertices 0..n, so that the next test
// case starts from the same zeroed globals as the first one. The higher
// levels of parent[][] are not cleared because init() recomputes them.
void reset() {
    numChains = 0;
    generatedArrayIDX = 0;
    for (int i = 0; i <= n; i++) {
        nb[i].clear();
        childSz[i] = 0;
        parent[i][0] = 0;
        depth[i] = 0;
        isSpecial[i] = 0;
        chainHead[i] = 0;      // hld() treats chainHead == 0 as "new chain"
        generatedArray[i] = 0;
        chainSize[i] = 0;
        vertexChain[i] = 0;
        vertexPos[i] = 0;
    }
}

void dfs(int u, int p) {
    childSz[u] = 1;
    parent[u][0] = p;
    depth[u] = depth[p]+1;
    for (int i = 0; i < nb[u].size(); i++) {
        int v = edges[nb[u][i]].getNb(u);
        if (v == p) continue;
        dfs(v,u);
        childSz[u] += childSz[v];
    }
}

void hld(int u,int e) {
    if (chainHead[numChains] == 0) {
        // create new chain
        chainHead[numChains] = u;
    }
    vertexChain[u] = numChains;
    vertexPos[u] = generatedArrayIDX;
    chainSize[numChains]+=1;
    if (u != 0) {
        generatedArray[generatedArrayIDX++] = e;
    }

    int special = 0;
    int edge;
    for (int i = 0; i < nb[u].size(); i++) {
        int v =edges[nb[u][i]].getNb(u);
        if (v == parent[u][0]) continue;
        if (childSz[special] < childSz[v]) {
            special = v;
            edge = edges[nb[u][i]].cost;
        }
    }
    if (special == 0) return;
    isSpecial[special] = 1;
    hld(special, edge);
    for (int i = 0; i < nb[u].size(); i++) {
        int v = edges[nb[u][i]].getNb(u);
        if (v == parent[u][0] || v == special) continue;
        ++numChains;
        hld(v,edges[nb[u][i]].cost);
    }
}


// Builds segment-tree node idx so that it holds the maximum of
// generatedArray[sfrom..sto]; node idx has children 2*idx and 2*idx+1.
void build(int idx, int sfrom, int sto) {
    if (sfrom != sto) {
        const int mid = (sfrom + sto) >> 1;
        const int lc = 2 * idx;
        build(lc, sfrom, mid);
        build(lc + 1, mid + 1, sto);
        tree[idx] = max(tree[lc], tree[lc + 1]);
    } else {
        tree[idx] = generatedArray[sfrom];
    }
}

// Prepares one test case after its edges have been read. It roots the tree
// at vertex 1, fills the binary-lifting table, runs the heavy-light
// decomposition and builds the segment tree over the resulting array.
void init() {
    dfs(1, 0);
    // LOG ends as the smallest k with 2^k > n.
    for (LOG = 0; (1 << LOG) <= n; ++LOG) {}
    for (int k = 1; k <= LOG; ++k)
        for (int u = 1; u <= n; ++u)
            parent[u][k] = parent[parent[u][k - 1]][k - 1];
    hld(1, 0);
    build(1, 0, generatedArrayIDX - 1);
}

// Lowest common ancestor of u and v using the binary-lifting table parent[][].
int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);
    // Lift the deeper vertex to the same depth, largest jumps first.
    for (int k = LOG; k >= 0; --k)
        if (depth[u] - (1 << k) >= depth[v]) u = parent[u][k];
    if (u == v) return u;
    // Lift both while their ancestors differ; they end as children of the LCA.
    for (int k = LOG; k >= 0; --k) {
        if (parent[u][k] != parent[v][k]) {
            u = parent[u][k];
            v = parent[v][k];
        }
    }
    return parent[u][0];
}

// Maximum of generatedArray[qfrom..qto], answered from segment-tree node idx,
// which covers [sfrom..sto]. The query range is expected to lie inside it.
int query(int idx, int sfrom, int sto, int qfrom, int qto) {
    if (qfrom == sfrom && qto == sto) return tree[idx];
    const int mid = (sfrom + sto) >> 1;
    const int lc = 2 * idx;
    const int rc = lc + 1;
    if (qto <= mid) return query(lc, sfrom, mid, qfrom, qto);
    if (qfrom > mid) return query(rc, mid + 1, sto, qfrom, qto);
    // The range straddles mid: split it between both children.
    const int leftMax = query(lc, sfrom, mid, qfrom, mid);
    const int rightMax = query(rc, mid + 1, sto, mid + 1, qto);
    return max(leftMax, rightMax);
}

// Copies generatedArray[pos] into its leaf below segment-tree node idx
// (covering [sfrom..sto]) and recomputes the maxima on the way back up.
void update(int idx, int sfrom, int sto, int pos) {
    // Walk down to the leaf responsible for pos.
    int node = idx;
    while (sfrom != sto) {
        const int mid = (sfrom + sto) >> 1;
        node *= 2;
        if (pos <= mid) {
            sto = mid;
        } else {
            sfrom = mid + 1;
            ++node;
        }
    }
    tree[node] = generatedArray[pos];
    // Every node on the path is a descendant of idx, hence numbered above it.
    while (node > idx) {
        node /= 2;
        tree[node] = max(tree[2 * node], tree[2 * node + 1]);
    }
}

// Maximum edge cost on the path from BOTTOM up to TOP. TOP is assumed to be
// an ancestor of BOTTOM, as query(u, v) guarantees by passing their LCA.
// Returns 0 when TOP == BOTTOM.
int subquery(int TOP, int BOTTOM) {
    if (TOP == BOTTOM) return 0;
    const int last = generatedArrayIDX - 1;
    int best = -1;
    // Climb chain by chain until BOTTOM is on TOP's chain.
    while (vertexChain[TOP] != vertexChain[BOTTOM]) {
        const int head = chainHead[vertexChain[BOTTOM]];
        // Slots head..BOTTOM include the edge from head to its parent.
        best = max(best, query(1, 0, last, vertexPos[head], vertexPos[BOTTOM]));
        BOTTOM = parent[head][0];
    }
    // Same chain: the edges below TOP occupy slots topPos+1..botPos.
    const int topPos = vertexPos[TOP];
    const int botPos = vertexPos[BOTTOM];
    if (topPos == botPos) return best;
    return max(best, query(1, 0, last, topPos + 1, botPos));
}

// Answers "QUERY u v": the maximum edge cost on the tree path between u and
// v (0 when u == v). The path is split at the lowest common ancestor.
int query(int u, int v) {
    const int meet = lca(u, v);
    const int fromU = subquery(meet, u);
    const int fromV = subquery(meet, v);
    return fromU < fromV ? fromV : fromU;
}

// Handles "CHANGE u c": sets the cost of the u-th input edge (1-based) to c
// and refreshes the segment tree. An edge's cost is stored in the slot of
// its child endpoint, i.e. the endpoint whose parent is the other one.
// Returns the generatedArray slot that was updated. The original version
// fell off the end of this int function without a return, which is UB.
int update(int u, int c) {
    Edge &e = edges[u - 1];
    e.cost = c;
    const int child = (e.u == parent[e.v][0]) ? e.v : e.u;
    const int pos = vertexPos[child];
    generatedArray[pos] = c;
    update(1, 0, generatedArrayIDX - 1, pos);
    return pos;
}

// Writes a random QTREE input file "qtree.in" for debugging.
//   t         - number of test cases
//   FROM, TO  - each tree gets a vertex count in [FROM, TO); FROM is used when
//               the range is empty
//   op        - number of QUERY/CHANGE commands per test case
// Edge costs are in [5, 104]; CHANGE costs are in [5, 10004]. A one-vertex
// tree has no edges, so it only receives QUERY commands.
void genRandom(int t = 1, int FROM = 4000, int TO = 4001, int op = 2000) {
    std::ofstream out("qtree.in");
    out << t << '\n';           // '\n' instead of endl: no flush per line
    std::srand(std::time(NULL));
    for (int i = 0; i < t; i++) {
        int v = TO > FROM ? std::rand() % (TO - FROM) + FROM : FROM;
        if (v < 1) v = 1;
        out << v << '\n';
        // Attach each new vertex j to a random earlier vertex -> random tree.
        std::vector<int> used;
        used.reserve(v);
        used.push_back(1);
        for (int j = 2; j <= v; j++) {
            int cost = std::rand() % 100 + 5;
            int u = used[std::rand() % used.size()];
            used.push_back(j);
            out << u << " " << j << " " << cost << '\n';
        }
        for (int j = 0; j < op; j++) {
            // CHANGE needs at least one edge (rand() % (v - 1) would divide by 0).
            if (v > 1 && std::rand() % 2 != 0) {
                int A = std::rand() % (v - 1) + 1;
                int B = std::rand() % 10000 + 5;
                out << "CHANGE " << A << " " << B << '\n';
            } else {
                int A = std::rand() % v + 1;
                int B = std::rand() % v + 1;
                out << "QUERY " << A << " " << B << '\n';
            }
        }
        out << "DONE" << '\n';
    }
    out.close();
}

//#define DEBUG2

// Reads T test cases in SPOJ QTREE format and prints the answer to every
// QUERY line, one per line. Build with -DDEBUG to only write a random input
// file, or with -DDEBUG2 to cross-check each answer against slow().
int main() {
    #ifdef DEBUG
        genRandom();
        return 0;
    #endif
    int t;
    if (scanf("%d", &t) != 1) return 0;
    while (t--) {
        if (scanf("%d", &n) != 1) break;
        for (int i = 1; i < n; i++) {
            int a, b;
            long c;
            if (scanf("%d%d%ld", &a, &b, &c) != 3) return 0;  // truncated input
            edges[i-1] = Edge(a, b, c);   // edge i (1-based) lives at index i-1
            nb[a].push_back(i-1);
            nb[b].push_back(i-1);
        }
        init();
        int a, b, c;
        char buf[7];                      // longest command is "CHANGE" + NUL
        // Stop at "DONE", or at end of input so a missing DONE cannot loop forever.
        while (scanf("%6s", buf) == 1 && buf[0] != 'D') {
            if (buf[0] == 'Q') {
                scanf("%d%d", &a, &b);
                printf("%d\n", query(a, b));
                #ifdef DEBUG2
                    assert(query(a,b) == slow(a,0,0,b));
                #endif
            } else if (buf[0] == 'C') {
                scanf("%d%d", &a, &c);
                update(a, c);
            }
        }
        reset();
    }
    return 0;
}