{-# LANGUAGE GADTs, StandaloneDeriving #-}

import Control.Parallel (par)
import Data.Bits (shiftL)
import Data.Int (Int64)
import Data.List (sort)

-- A list of these defines a HORNSAT instance with variables
-- of type var.  The "Free" clauses declare variables and the
-- "Horn" clauses declare constraints:
--
--   * Horn antes (Just post): if every variable in antes is true,
--     post must be true.
--   * Horn antes Nothing: not every variable in antes is true
--     (so Horn [] Nothing can never be satisfied).
--   * Free v: declares v as a variable that may be true or false.
--
-- Each constructor carries an Eq var context, so code that pattern
-- matches on a clause can compare its variables.
-- NOTE(review): countSolutions counts 2 assignments per remaining Free
-- clause, so each variable should be declared by exactly one Free.
data Clause var where
  Horn :: (Eq var) => ![var] -> !(Maybe var) -> Clause var
  Free :: (Eq var) => !var -> Clause var
deriving instance Eq var => Eq (Clause var)

-- Orders clauses using a simple heuristic so that the best candidates
-- for elimination sort first: Horn clauses by their number of literals
-- (antecedents, plus one if there is a consequent), and every Free
-- declaration after all Horn clauses.
--
-- compare is defined directly on that weight.  Defining only (<=) left
-- compare to the class default, which answered LT in BOTH directions for
-- two distinct clauses of equal weight (making a < b and b < a both
-- true).  Here such clauses compare EQ instead; (<=) gives the same
-- answers as before.  Note that equal-weight clauses compare EQ even
-- when they are not (==), so this is a total preorder on clauses.
instance (Eq var) => Ord (Clause var) where
  compare a b = compare (weight a) (weight b)
    where
      -- Left sorts before Right, so every Horn clause precedes every Free.
      weight :: Clause v -> Either Int ()
      weight (Horn antes post) = Left (length antes + maybe 0 (const 1) post)
      weight (Free _) = Right ()

-- Counts the number of solutions to a HORNSAT instance.
--
-- Each variable is expected to be declared by exactly one Free clause;
-- Horn clauses constrain the declared variables.  The search sorts the
-- clauses (shortest Horn clause first, Free declarations last), branches
-- on the first clause, and recurses on the simplified instances.
--
-- The two sub-counts of every branch are sparked with par, so the search
-- can use several cores; enable them by compiling with options:
--    -threaded -rtsopts -with-rtsopts="-N"
--
-- The result is an Int64 and wraps around once the count reaches 2^63.
countSolutions :: (Eq var) => [Clause var] -> Int64
countSolutions cs' = case sort cs' of
    -- An empty negative clause can never be satisfied.
    (Horn [] Nothing : _) -> 0
    -- A negative clause: at least one antecedent must be false.
    (Horn antes Nothing : cs) -> remaining antes cs
    -- An implication: either post is true (the clause holds outright), or
    -- post is false and then at least one antecedent must be false.
    (Horn antes (Just post) : cs) ->
      let sol1 = countSolutions (forceTrue post cs)
          sol2
            -- post is one of its own antecedents: the clause is a
            -- tautology, so post = false constrains nothing further.
            | post `elem` antes = countSolutions (forceFalse post cs)
            | otherwise = remaining antes (forceFalse post cs)
       in sol1 `par` (sol2 `seq` sol1 + sol2)
    -- Free clauses sort after every Horn clause, so only unconstrained
    -- variables are left: 2^(length cs + 1) assignments.
    (Free _ : cs) -> shiftL 2 (length cs)
    [] -> 1
  where
    -- Solutions of cs that also satisfy "not all of antes": either the
    -- first antecedent is false, or it is true and one of the rest is false.
    remaining (ante : antes) cs =
      let sol1 = countSolutions (forceFalse ante cs)
          -- Other copies of ante are now true, so they cannot be the false
          -- antecedent; keeping them would count ante = true twice.
          sol2 = remaining (filter (/= ante) antes) (forceTrue ante cs)
       in sol1 `par` (sol2 `seq` sol1 + sol2)
    remaining [] _ = 0

    -- Simplifies the clauses under the assignment k = true; afterwards k
    -- occurs in no clause.
    forceTrue k = concatMap forceT
      where
        forceT (Horn antes (Just post))
          | post == k = []
          | otherwise = [Horn (filter (/= k) antes) (Just post)]
        forceT (Horn antes Nothing) = [Horn (filter (/= k) antes) Nothing]
        forceT (Free var) = [Free var | var /= k]

    -- Simplifies the clauses under the assignment k = false; afterwards k
    -- occurs in no clause.
    forceFalse k = concatMap forceF
      where
        forceF (Horn antes post)
          -- A false antecedent satisfies the clause.  This is checked
          -- first so a clause like Horn [k, x] (Just k) is dropped rather
          -- than left behind still mentioning the already-forced k.
          | k `elem` antes = []
          | post == Just k = [Horn antes Nothing]
          | otherwise = [Horn antes post]
        forceF (Free var) = [Free var | var /= k]

-- Generates a HORNSAT instance corresponding to the problem of
-- finding a weak factorization system on the category Hom([n], [m+1]).
--
-- Construction, as implemented below:
--   * objects are the strictly decreasing lists of length j = min n m
--     drawn from [1 .. n+m];
--   * variables are the non-identity "morphisms" (s, t): pairs of objects
--     with s <= t in every coordinate and s /= t.  Each is declared by a
--     Free clause and encoded as an Int64 by compress;
--   * comp clauses: (s, t) and (t, u) together imply (s, u);
--   * lefts clauses: (s, t) implies the morphism obtained by decrementing
--     coordinate ix of t (and of s too when s and t agree there), provided
--     t's coordinate ix stays above t's next coordinate (or above 0 for
--     the last one) and the result is not an identity.
--
-- NOTE(review): compress codes are unique only while they fit in an
-- Int64, i.e. roughly while (n+m)^(2j) < 2^63.
hom :: Int -> Int -> [Clause Int64]
hom n m = lefts ++ comp ++ free
  where
    -- Values an object coordinate may take.
    is :: [Int64]
    is = [1 .. fromIntegral (n + m)]

    j = min n m

    nm :: Int64
    nm = fromIntegral (n + m)

    -- Interleaves the coordinates of s and t as the digits (each in
    -- 1..nm) of a bijective base-nm number, giving every morphism a
    -- distinct code.
    compress :: ([Int64], [Int64]) -> Int64
    compress (ss, ts) =
      foldr (\(s, t) acc -> s + nm * (t + nm * acc)) 0 (zip ss ts)

    -- All strictly decreasing lists of length j over is, built by
    -- repeatedly prepending a value larger than the current head.
    objs :: [[Int64]]
    objs = build j
      where
        build k
          | k <= 0 = [[]]
          | otherwise = [i : x | x <- build (k - 1), i <- is, i > top x]
        top (y : _) = y
        top [] = 0

    -- (s, t) is a non-identity morphism.
    morphismPredicate :: [Int64] -> [Int64] -> Bool
    morphismPredicate s t = and (zipWith (<=) s t) && or (zipWith (/=) s t)

    mors :: [([Int64], [Int64])]
    mors = [(s, t) | s <- objs, t <- objs, morphismPredicate s t]

    free :: [Clause Int64]
    free = [Free (compress st) | st <- mors]

    comp :: [Clause Int64]
    comp =
      [ Horn [compress (s, t), compress (t, u)] (Just (compress (s, u)))
      | (s, t) <- mors
      , u <- objs
      , morphismPredicate t u
      ]

    -- Decrements the element at 0-based index k (a no-op if k is out of
    -- range, which cannot happen below).
    decAt :: Int -> [Int64] -> [Int64]
    decAt k xs = [if i == k then x - 1 else x | (i, x) <- zip [0 ..] xs]

    lefts :: [Clause Int64]
    lefts =
      [ Horn [compress (s, t)] (Just (compress (s', t')))
      | (s, t) <- mors
        -- next is t's following coordinate, or 0 after the last one.
      , (ix, (si, ti), next) <- zip3 [0 ..] (zip s t) (drop 1 t ++ [0])
        -- Decrementing must keep t (and hence s) strictly decreasing.
      , ti - 1 > next
      , let (s', t') =
              if ti > si
                then (s, decAt ix t)
                else (decAt ix s, decAt ix t)
      , s' /= t'
      ]

-- Number of solutions of the HORNSAT instance generated by hom n m.
-- NOTE(review): presumably the terms of OEIS A091378, going by the name;
-- confirm against the sequence's definition.
a091378 :: Int -> Int -> Int64
a091378 n m = countSolutions $ hom n m