module CP where

import qualified Data.IntSet as IS
import qualified Data.IntMap.Strict as IM
import Data.Maybe
import Control.Monad
import Term
import qualified Rule as R
import qualified CPIndexing as CPI
import Signature
import Equation

-- | A critical peak, read as @_left <-id1- _top -id2-> _right@.
--
-- '_left' is reached from '_top' by a step with rule/equation '_id1' at the
-- overlap position (the root in the case of an overlay). '_right' is reached
-- from '_top' by a step with rule/equation '_id2' at the root.
data CP = CP
  { _left  :: Term  -- ^ result of the step with '_id1'
  , _id1   :: Int   -- ^ id of the inner (non-root, except for overlays) step
  , _id2   :: Int   -- ^ id of the root step
  , _right :: Term  -- ^ result of the root step with '_id2'
  , _top   :: Term  -- ^ the peak term both steps start from
  , _depth :: Int   -- ^ generation depth (one more than the deeper parent)
  }

-- | Extended critical pairs where @l1 = r1@ performs the inner step and
-- @l2 = r2@ performs the root step, using the given orientations.
--
-- Each result is @(left, right, top)@. Here @top@ is the instantiated @l2@.
-- @left@ is @top@ rewritten by @l1 -> r1@ at the overlap position, and
-- @right@ is the instantiated @r2@. Trivial pairs (@left == right@) are
-- dropped, as is the root overlap of an equation with a variant of itself.
--
-- Assumes that @l1 = r1@ and @l2 = r2@ are suitably renamed apart.
ecp' :: (Term -> Term -> Bool) -> (Term, Term, R.Orientation) -> (Term, Term, R.Orientation) -> [(Term, Term, Term)]
ecp' gt (l1, r1, o1) (l2, r2, o2) =
  [ (left, right, top)
  | p <- functionPositions l2
  , p /= [] || not (variant (l1, r1) (l2, r2))
  , Just sigma <- [mgu l1 (subtermAt l2 p)]
  , let inst t = substitute t sigma
  , usable o1 (inst l1) (inst r1)
  , usable o2 (inst l2) (inst r2)
  , let left = inst (replace l2 r1 p)
  , let right = inst r2
  , left /= right
  , let top = inst l2
  ]
  where
    -- An unoriented equation may only be used in this direction when its
    -- instance is not greater right-to-left under `gt`.
    usable o l r = o == R.Oriented || not (gt r l)

-- | Renames the variables of two equations apart.
--
-- Variables of the first pair become @V 0, V 1, ...@ (in ascending order of
-- their original index). Variables of the second pair continue from the
-- first free index, so the two renamed pairs share no variables.
rename :: (Term, Term) -> (Term, Term) -> ((Term, Term), (Term, Term))
rename (l1, r1) (l2, r2) = (onBoth sigma1 (l1, r1), onBoth sigma2 (l2, r2))
  where
    varsOf l r = IS.union (variables l) (variables r)
    firstVars = varsOf l1 r1
    secondVars = varsOf l2 r2
    -- map the given variable set, in order, to fresh variables from `start`
    renamingFrom start vs = IM.fromList (zip (IS.toList vs) (map V [start ..]))
    sigma1 = renamingFrom 0 firstVars
    sigma2 = renamingFrom (IS.size firstVars) secondVars
    onBoth sigma (l, r) = (substitute l sigma, substitute r sigma)

-- | Extended critical peaks between @rl1@ (inner step) and @rl2@ (root step).
--
-- * The two rules are renamed apart once, up front.
-- * Unoriented equations are tried in both directions, so overlaps between
--   flipped equations are also considered.
-- * Every resulting peak gets depth @1 + max depth1 depth2@.
ecp :: (Term -> Term -> Bool) ->  R.Rule -> R.Rule -> [CP]
ecp gt rl1 rl2 =
  [ CP { _left = left
       , _id1 = R._id rl1
       , _id2 = R._id rl2
       , _right = right
       , _top = top
       , _depth = depth
       }
  | inner <- directions l1 r1 (R._orientation rl1)
  , outer <- directions l2 r2 (R._orientation rl2)
  , (left, right, top) <- ecp' gt inner outer
  ]
  where
    ((l1, r1), (l2, r2)) = rename (R._lhs rl1, R._rhs rl1) (R._lhs rl2, R._rhs rl2)
    depth = 1 + max (R._depth rl1) (R._depth rl2)
    -- oriented rules are used left to right only; equations both ways
    directions l r o = case o of
      R.Oriented -> [(l, r, R.Oriented)]
      R.Unoriented -> [(l, r, R.Unoriented), (r, l, R.Unoriented)]

-- | The extended critical peak obtained by applying @l1 = r1@ at position @p@
-- of @l2@ and @l2 = r2@ at the root.
--
-- Each argument tuple is @(lhs, rhs, orientation, id, depth)@. The two
-- equations are renamed apart here, so callers need not rename them first
-- (TODO: renaming on every call is not optimal).
--
-- Returns 'Nothing' in four cases:
--
-- * the overlap is at the root between variants;
-- * @l1@ does not unify with the subterm of @l2@ at @p@;
-- * an unoriented equation would be used in a decreasing direction under @gt@;
-- * the resulting pair is trivial.
ecpAt :: (Term -> Term -> Bool) -> Position ->
          (Term, Term, R.Orientation, Int, Int) -> (Term, Term, R.Orientation, Int, Int) ->
          Maybe CP
ecpAt gt p (l1, r1, o1, i1, d1) (l2, r2, o2, i2, d2) = do
  guard (p /= [] || not (variant (l1', r1') (l2', r2')))
  sigma <- mgu l1' (subtermAt l2' p)
  let subst t = substitute t sigma
  guard (o1 == R.Oriented || not (gt (subst r1') (subst l1')))
  guard (o2 == R.Oriented || not (gt (subst r2') (subst l2')))
  -- The peak is the instantiated root-step lhs. Both steps start from it:
  -- rule i1 at position p, rule i2 at the root. (subst l1' is only the
  -- subterm at p, so it is not the peak unless p == [].)
  let top = subst l2'
      left = subst (replace l2' r1' p)
      right = subst r2'
  guard (left /= right)
  return (CP {
    _left = left,
    _id1 = i1,
    _id2 = i2,
    _right = right,
    _top = top,
    _depth = max d1 d2 + 1
  })
  where
    ((l1', r1'), (l2', r2')) = rename (l1, r1) (l2, r2)

-- | All extended critical peaks between @rl@ and the rules found through the
-- indices @idx1@ and @idx2@.
--
-- Index matches are resolved to rules via @rules@, which must contain every
-- id the indices return ('IM.!' fails otherwise). A 'Left' match uses the
-- matched rule left to right; a 'Right' match uses it right to left, as an
-- unoriented equation.
--
-- Results are produced in this order:
--
-- 1. (only if @rl@ is unoriented) @rl@ right to left as the inner step, below
--    or at the root of an indexed rule (overlays included);
-- 2. (only if @rl@ is unoriented) @rl@ right to left as the root step, with an
--    indexed rule at a non-root function position of its rhs (no overlays);
-- 3. @rl@ left to right as the inner step (overlays included);
-- 4. @rl@ left to right as the root step, at non-root positions of its lhs.
ecpWithIndex :: Signature -> (Term -> Term -> Bool) -> IM.IntMap R.Rule -> CPI.Index -> CPI.Index -> R.Rule -> [CP]
ecpWithIndex sig gt rules idx1 idx2 rl = backwardPeaks ++ forwardPeaks
  where
    backwardPeaks
      | R.oriented rl = []
      | otherwise =
          innerPairs (R._rhs rl) (tupleRL rl) ++ rootPairs (R._rhs rl) (tupleRL rl)
    forwardPeaks =
      innerPairs (R._lhs rl) (tupleLR rl) ++ rootPairs (R._lhs rl) (tupleLR rl)

    -- rl (rewriting from `side`) is the inner step; the indexed rule is the
    -- root step, overlapped at the position reported by the index
    innerPairs side step =
      mapMaybe (\m -> ecpAt gt (matchPos m) step (resolve m)) (retrieve side)

    -- rl (rewriting from `side`) is the root step; an indexed rule is applied
    -- at the non-root function position p of `side`
    rootPairs side step =
      [ cp
      | (p, u) <- nonRootFunctionPositions' side
      , m <- retrieveRoot u
      , Just cp <- [ecpAt gt p (resolve m) step]
      ]

    resolve m = orient m (rules IM.! matchId m)
    matchId = either fst fst
    matchPos = either snd snd
    tupleLR rule = (R._lhs rule, R._rhs rule, R._orientation rule, R._id rule, R._depth rule)
    tupleRL rule = (R._rhs rule, R._lhs rule, R.Unoriented, R._id rule, R._depth rule)
    orient m rule = either (const (tupleLR rule)) (const (tupleRL rule)) m
    retrieve u = CPI.retrieve sig u idx1 ++ CPI.retrieve sig u idx2
    retrieveRoot u = CPI.retrieveRoot sig u idx1 ++ CPI.retrieveRoot sig u idx2
