with(combinat);

pet_cycleind_symm :=
proc(n)
local l;
option remember;
    # Cycle index Z(S_n) of the symmetric group on n points, written in the
    # global indeterminates a[1], ..., a[n] (a[k] stands for a k-cycle).
    #
    # Uses the standard recurrence
    #     Z(S_0) = 1,
    #     Z(S_n) = (1/n) * sum_{l=1}^{n} a[l] * Z(S_{n-l}),
    # and returns the result expanded, i.e. as a sum of monomials (or a single
    # monomial for n = 1, or the constant 1 for n = 0).
    # `option remember` caches earlier results, so the recursion is evaluated
    # once per n.

    if n = 0 then
        return 1;
    fi;

    expand(1/n*add(a[l]*pet_cycleind_symm(n-l), l = 1..n));
end;

pet_flatten_term :=
proc(varp)
local terml, d, cf, v, k;
    # Split a single monomial of a cycle index into its coefficient and its
    # list of indeterminate factors, with repetition.
    #
    # Example: 3*a[1]^2*a[2]  ->  [3, [a[1], a[1], a[2]]]
    # (the order of distinct factors follows the order of indets(varp)).
    #
    # Returns [cf, terml], where cf is varp divided by every indeterminate
    # raised to its degree in varp, and terml lists each indeterminate v
    # degree(varp, v) times. A constant input gives [varp, []].
    # NOTE(review): assumes varp is a monomial; a sum would not be split
    # correctly.

    terml := [];

    cf := varp;
    for v in indets(varp) do
        d := degree(varp, v);
        terml := [op(terml), seq(v, k = 1..d)];
        cf := cf/v^d;
    od;

    [cf, terml];
end;


pet_cycleind_grid :=
proc(N)
local cind;
option remember;
    # Cycle index of the dihedral group of order 8 (four rotations, four
    # reflections) acting on the N^2 cells of an N x N grid, written in the
    # global indeterminates a[1], a[2], a[4].
    # N = 1 collapses to a[1]; N = 0 collapses to 1.

    if type(N, even) then
        # identity                              : a[1]^(N^2)
        # rotation by 180 + 2 axis reflections  : a[2]^(N^2/2) each
        # 2 diagonal reflections (N fixed cells): a[1]^N*a[2]^((N^2-N)/2)
        # rotations by 90 and 270               : a[4]^(N^2/4) each
        cind :=
        1/8*(a[1]^(N^2)+3*a[2]^(N^2/2)+
             2*a[1]^N*a[2]^((N^2-N)/2) + 2*a[4]^(N^2/4));
    else
        # identity                              : a[1]^(N^2)
        # 2 axis + 2 diagonal reflections, each fixing N cells:
        #                                         a[1]^N*a[2]^((N^2-N)/2)
        # rotation by 180 (center fixed)        : a[1]*a[2]^((N^2-1)/2)
        # rotations by 90 and 270 (center fixed): a[1]*a[4]^((N^2-1)/4)
        cind :=
        1/8*(a[1]^(N^2)+4*a[1]^N*a[2]^((N^2-N)/2)+
             a[1]*a[2]^((N^2-1)/2) + 2*a[1]*a[4]^((N^2-1)/4));
    fi;

    cind;
end;

gridcols :=
proc(N, Q)
option remember;
local idx_slots, idx_cols, slot_terms, col_terms, flat_slots, flat_cols,
    res, fs, fc, cyc_a, cyc_b, len_a, len_b, p, q;
    # Number of colorings of an N x N grid with Q colors, counted up to the
    # symmetries of the grid (pet_cycleind_grid) and up to arbitrary
    # permutations of the colors (pet_cycleind_symm).
    #
    # Each pair of terms (slot term, color term) contributes
    #     coeff_slot * coeff_col * prod_{slot cycles of length L} q(L),
    # where q(L) is the total length of the color cycles whose length divides L.
    #
    # Example: gridcols(2, 2) = 4.

    if not (type(N, nonnegint) and type(Q, nonnegint)) then
        error "expected non-negative integers N and Q, got %1 and %2", N, Q;
    fi;

    idx_slots := expand(pet_cycleind_grid(N));
    idx_cols := expand(pet_cycleind_symm(Q));

    # A cycle index that reduces to a single term (e.g. a[1] for N = 1 or
    # Q = 1, or 1 for N = 0 or Q = 0) is not a sum. Iterating over its
    # operands would walk the factors of that term, so wrap it in a list.
    if type(idx_slots, `+`) then
        slot_terms := [op(idx_slots)];
    else
        slot_terms := [idx_slots];
    fi;

    if type(idx_cols, `+`) then
        col_terms := [op(idx_cols)];
    else
        col_terms := [idx_cols];
    fi;

    # Flatten every term once instead of once per pair of terms.
    flat_slots := map(pet_flatten_term, slot_terms);
    flat_cols := map(pet_flatten_term, col_terms);

    res := 0;

    for fs in flat_slots do
        for fc in flat_cols do
            p := 1;
            for cyc_a in fs[2] do
                len_a := op(1, cyc_a);
                q := 0;

                for cyc_b in fc[2] do
                    len_b := op(1, cyc_b);

                    # A color cycle of length len_b can be attached to a slot
                    # cycle of length len_a only if len_b divides len_a.
                    if len_a mod len_b = 0 then
                        q := q + len_b;
                    fi;
                od;

                p := p*q;
            od;

            res := res + p*fs[1]*fc[1];
        od;
    od;

    res;
end;