{-# LANGUAGE GADTs, StandaloneDeriving #-}

import Control.Parallel (par)
import Data.Bits (shiftL)
import Data.Int (Int64)
import Data.List (sort)

-- | One clause of a HORNSAT instance whose variables have type @var@.
-- A list of clauses describes a whole instance: 'Free' clauses declare
-- the variables and 'Horn' clauses constrain them.
data Clause var where
  -- | @Horn antes post@: if every antecedent in @antes@ is true then
  -- @post@ must be true.  With a 'Nothing' consequent the antecedents
  -- may not all be true at once.
  Horn
    :: (Eq var)
    => ![var]
    -> !(Maybe var)
    -> Clause var
  -- | @Free v@ declares @v@ as a variable of the instance.
  Free
    :: (Eq var)
    => !var
    -> Clause var

deriving instance Eq var => Eq (Clause var)

-- | Orders clauses with a simple heuristic so that the best candidates
-- for elimination sort first.  Every 'Horn' clause comes before every
-- 'Free' declaration.  'Horn' clauses are ordered by size: the number
-- of antecedents, plus one if there is a consequent.
--
-- This is a preorder, not a total order: distinct clauses of the same
-- size compare as 'EQ'.  'compare' is defined directly on the size key
-- so that it is antisymmetric.  Defining only '<=' made the default
-- 'compare' return 'LT' for both @compare x y@ and @compare y x@ when
-- @x /= y@ had equal size.  The induced '<=' agrees with the old one.
instance (Eq var) => Ord (Clause var) where
  compare x y = compare (rank x) (rank y)
    where
      -- (0, size) for Horn clauses and (1, 0) for Free declarations,
      -- so all Horn clauses precede all Free ones.
      rank :: Clause v -> (Int, Int)
      rank (Horn antes post) = (0, length antes + maybe 0 (const 1) post)
      rank (Free _) = (1, 0)

-- | Counts the satisfying assignments of a HORNSAT instance.
--
-- Assumes every variable is declared by exactly one 'Free' clause.
-- Each declared variable that no remaining 'Horn' clause constrains
-- contributes a factor of 2.  NOTE(review): undeclared or
-- doubly-declared variables make the count wrong.
--
-- The search sorts the clauses so that a smallest 'Horn' clause comes
-- first, then branches on it:
--
-- * @Horn [] Nothing@ cannot be satisfied, giving 0.
-- * @Horn antes Nothing@ needs some antecedent to be false.
-- * @Horn antes (Just post)@ holds with @post@ true, or with @post@
--   false and some antecedent false.
--
-- Tautological clauses (consequent among the antecedents) and repeated
-- antecedents are handled without double counting.
--
-- The result wraps around silently if the count exceeds
-- @maxBound :: Int64@.
--
-- The two branches of every split are evaluated in parallel with
-- 'par', so enable multiple cores by compiling with options:
--    -threaded -rtsopts -with-rtsopts="-N"
countSolutions :: (Eq var) => [Clause var] -> Int64
countSolutions cs' = case sort cs' of
    (Horn [] Nothing : _) -> 0
    (Horn antes Nothing : cs) -> remaining antes cs
    (Horn antes (Just post) : cs)
      -- If post is one of the antecedents, the clause always holds.
      -- Branching on it would count the post-false assignments twice.
      | post `elem` antes -> countSolutions cs
      | otherwise ->
          let sol1 = countSolutions (forceTrue post cs)
              sol2 = remaining antes (forceFalse post cs)
           in sol1 `par` (sol2 `seq` sol1 + sol2)
    -- Horn clauses sort before Free ones, so a leading Free means only
    -- Free declarations are left: 2 ^ (1 + length cs) solutions.
    (Free _ : cs) -> shiftL 2 (length cs)
    [] -> 1
  where
    -- Counts assignments of the clauses cs in which at least one of the
    -- given antecedents is false, branching on the first antecedent.
    remaining [] _ = 0
    remaining (ante : antes) cs =
      let sol1 = countSolutions (forceFalse ante cs)
          -- ante is now true, so drop repeated copies of it instead of
          -- later branching on it being false.
          sol2 = remaining (filter (/= ante) antes) (forceTrue ante cs)
       in sol1 `par` (sol2 `seq` sol1 + sol2)

    -- Simplifies clauses under the assignment k = True.  Clauses that
    -- conclude k are satisfied, and k drops out of the antecedents.
    forceTrue k = concatMap forceT
      where
        forceT (Horn antes post)
          | post == Just k = []
          | otherwise = [Horn (filter (/= k) antes) post]
        forceT (Free var) = [Free var | var /= k]

    -- Simplifies clauses under the assignment k = False.  A clause with
    -- k as an antecedent is satisfied; this is checked first so that
    -- tautologies are dropped too.  A consequent of k is removed.
    forceFalse k = concatMap forceF
      where
        forceF (Horn antes post)
          | k `elem` antes = []
          | post == Just k = [Horn antes Nothing]
          | otherwise = [Horn antes post]
        forceF (Free var) = [Free var | var /= k]

-- | Generates a HORNSAT instance corresponding to the problem of
-- finding a weak factorization system on the category [n] x [m].
--
-- The category is the product of the chains 1..n and 1..m: there is a
-- morphism (i, j) -> (k, l) exactly when i <= k and j <= l.  Every
-- non-identity morphism gets one variable, numbered by 'encode' and
-- declared with a 'Free' clause.  The 'Horn' clauses are, in order:
--
-- * one-step clauses ('vertical' / 'horizontal'), each letting a
--   morphism imply a neighbouring morphism;
-- * composition clauses: if (i, j) -> (k, l) and (k, l) -> (p, q) are
--   both true, so is the composite (i, j) -> (p, q).
--
-- NOTE(review): presumably a true variable means the morphism lies in
-- the left class of the factorization system; confirm.
prodcat :: Int -> Int -> [Clause Int64]
prodcat n m = stepClauses ++ compClauses ++ freeClauses
  where
    rows, cols :: [Int64]
    rows = [1 .. fromIntegral n]
    cols = [1 .. fromIntegral m]

    nn, mm :: Int64
    nn = fromIntegral n
    mm = fromIntegral m

    -- Injective numbering of the morphism (i, j) -> (k, l).
    encode :: (Int64, Int64, Int64, Int64) -> Int64
    encode (i, j, k, l) = i + nn * (j + mm * (k + nn * l))

    -- Every (i, j, k, l) with i <= k and j <= l, i outermost, l innermost.
    quads :: [(Int64, Int64, Int64, Int64)]
    quads =
      [ (i, j, k, l)
      | i <- rows
      , j <- cols
      , k <- filter (>= i) rows
      , l <- filter (>= j) cols
      ]

    isIdentity (i, j, k, l) = (i, j) == (k, l)

    stepClauses = concatMap (\quad -> vertical quad ++ horizontal quad) quads

    -- (i,j)->(k,l) implies (i,j)->(k,l-1); when j == l it instead
    -- implies (i,j-1)->(k,l-1).  Omitted when the implied morphism
    -- would be an identity or would leave the grid.
    vertical (i, j, k, l)
      | l > j, (k, l - 1) /= (i, j) =
          [Horn [encode (i, j, k, l)] (Just (encode (i, j, k, l - 1)))]
      | j == l, j > 1, (k, l - 1) /= (i, j - 1) =
          [Horn [encode (i, j, k, l)] (Just (encode (i, j - 1, k, l - 1)))]
      | otherwise = []

    -- (i,j)->(k,l) implies (i,j)->(k-1,l); when i == k it instead
    -- implies (i-1,j)->(k-1,l).  Omitted under the same conditions.
    horizontal (i, j, k, l)
      | k > i, (k - 1, l) /= (i, j) =
          [Horn [encode (i, j, k, l)] (Just (encode (i, j, k - 1, l)))]
      | i == k, i > 1, (k - 1, l) /= (i - 1, j) =
          [Horn [encode (i, j, k, l)] (Just (encode (i - 1, j, k - 1, l)))]
      | otherwise = []

    -- Closure under composition of two non-identity morphisms.
    compClauses =
      [ Horn [encode (i, j, k, l), encode (k, l, p, q)] (Just (encode (i, j, p, q)))
      | (i, j, k, l) <- quads
      , p <- filter (>= k) rows
      , q <- filter (>= l) cols
      , (i, j) /= (k, l)
      , (k, l) /= (p, q)
      ]

    freeClauses = [Free (encode quad) | quad <- quads, not (isIdentity quad)]

-- | Counts the solutions of the HORNSAT instance that 'prodcat' builds
-- for the category [n] x [m].  NOTE(review): the name suggests OEIS
-- sequence A092450; confirm.
a092450 :: Int -> Int -> Int64
a092450 n m = countSolutions $ prodcat n m