-- | An implementation of Tarjan's UNION-FIND algorithm.  (Robert E
-- Tarjan. \"Efficiency of a Good But Not Linear Set Union Algorithm\", JACM
-- 22(2), 1975)
--
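-- The algorithm implements three operations efficiently (all amortised
-- @O(1)@ for all practical purposes; strictly, the amortised bound is
-- the very slowly growing inverse Ackermann function):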
--
--  1. Check whether two elements are in the same equivalence class.
--
--  2. Create a union of two equivalence classes.
--
--  3. Look up the descriptor of the equivalence class.
-- 
-- The implementation is based on mutable references.  Each
-- equivalence class has exactly one member that serves as its
-- representative element.  Every element either is the representative
-- element of its equivalence class or points to another element in
-- the same equivalence class.  Equivalence testing thus consists of
-- following the pointers to the representative elements and then
-- comparing these for identity.
--
-- The algorithm performs lazy path compression.  That is, whenever we
-- walk along a path longer than 1 we automatically update the
-- pointers along the path to point directly to the representative
-- element.  Consequently, until the next union, future lookups will
-- have a path length of at most 1.
--
{-# OPTIONS_GHC -funbox-strict-fields #-}
module Data.UnionFind.IO
  ( Point, fresh, repr, union, union', equivalent, redundant,
    descriptor, setDescriptor, modifyDescriptor )
where

import Data.IORef
import Control.Monad ( when )
import Control.Applicative

-- | The abstract type of an element of the sets we work on.  It is
-- parameterised over the type of the descriptor.
--
-- Two points are equal ('==') exactly when they are the same point:
-- the derived instance compares the underlying 'IORef's, and 'IORef'
-- equality is reference identity.
newtype Point a = Pt (IORef (Link a)) deriving Eq

-- | The contents of a point's mutable cell.  A point either is the
-- representative of its equivalence class (and then holds the class's
-- 'Info'), or links to another point of the same class.
data Link a
    = Info {-# UNPACK #-} !(IORef (Info a))
      -- ^ This is the descriptive element of the equivalence class.
    | Link {-# UNPACK #-} !(Point a)
      -- ^ Pointer to some other element of the equivalence class.
    deriving Eq

-- | The bookkeeping data held by the representative of an equivalence
-- class.
data Info a = MkInfo
  { weight :: {-# UNPACK #-} !Int
    -- ^ The size of the equivalence class, used by 'union'.
  , descr  :: a
    -- ^ The user-supplied descriptor of the equivalence class.
  } deriving Eq

-- | /O(1)/. Create a fresh point and return it.  A fresh point is in
-- the equivalence class that contains only itself.
fresh :: a -> IO (Point a)
fresh desc = do
  -- A singleton class has size 1 and the point is its own
  -- representative, so its cell holds the class 'Info' directly.
  info <- newIORef (MkInfo { weight = 1, descr = desc })
  l <- newIORef (Info info)
  return (Pt l)

-- | /O(1)/. @repr point@ returns the representative point of
-- @point@'s equivalence class.
--
-- This method performs the path compression: after it returns, every
-- point visited along the way links directly to the representative.
repr :: Point a -> IO (Point a)
repr point@(Pt l) = do
  link <- readIORef l
  case link of
    Info _ -> return point
    Link pt' -> do
      pt'' <- repr pt'
      -- If @pt'@ is not itself the representative, make @point@ link
      -- straight to the representative @pt''@.  This is the path
      -- compression step.  When @pt'' == pt'@, @point@ already holds
      -- @Link pt''@, so the write is skipped.
      when (pt'' /= pt') $ writeIORef l (Link pt'')
      return pt''

-- | Return the reference to the point's equivalence class's
-- descriptor.
--
-- Paths of length 0 and 1 are resolved directly without writing
-- anything.  Longer paths go through 'repr', which compresses them,
-- and the lookup is then retried from the representative.
descrRef :: Point a -> IO (IORef (Info a))
descrRef point@(Pt link_ref) = do
  link <- readIORef link_ref
  case link of
    Info info -> return info
    Link (Pt link'_ref) -> do
      link' <- readIORef link'_ref
      case link' of
        Info info -> return info
        Link _    -> descrRef =<< repr point

-- | /O(1)/. Return the descriptor associated with argument point's
-- equivalence class.
descriptor :: Point a -> IO a
descriptor point = descr <$> (readIORef =<< descrRef point)

-- | /O(1)/. Replace the descriptor of the point's equivalence class
-- with the second argument.
setDescriptor :: Point a -> a -> IO ()
setDescriptor point new_descr = do
  r <- descrRef point
  -- Strict modify: the updated 'Info' record is built right away
  -- instead of leaving a record-update thunk inside the reference.
  modifyIORef' r $ \i -> i { descr = new_descr }

-- | /O(1)/. Apply the function to the descriptor of the point's
-- equivalence class and store the result as the new descriptor.
modifyDescriptor :: Point a -> (a -> a) -> IO ()
modifyDescriptor point f = do
  r <- descrRef point
  -- Strict modify so that repeated calls do not pile up record-update
  -- thunks in the reference.  The descriptor field itself stays lazy,
  -- so @f@ is applied on demand, exactly as before.
  modifyIORef' r $ \i -> i { descr = f (descr i) }

-- | /O(1)/. Join the equivalence classes of the points.  The resulting
-- equivalence class will get the descriptor of the second argument.
-- If both points are already in the same class, nothing changes.
union :: Point a -> Point a -> IO ()
union p1 p2 = union' p1 p2 keepSecond
  where
    -- Discard the first class's descriptor and keep the second one.
    keepSecond _ d2 = return d2

-- | Like 'union', but sets the descriptor returned from the callback.
--
-- The callback receives the descriptor of the first argument's class,
-- then that of the second argument's class.  The intention is to keep
-- the descriptor of the second argument to the callback, but the
-- callback might adjust the information of the descriptor or perform
-- side effects.  The callback is not invoked when both points are
-- already in the same equivalence class.
union' :: Point a -> Point a -> (a -> a -> IO a) -> IO ()
union' p1 p2 update = do
  point1@(Pt link_ref1) <- repr p1
  point2@(Pt link_ref2) <- repr p2
  -- Joining a class with itself would create a cyclic link, so only
  -- distinct representatives are merged.
  when (point1 /= point2) $ do
    -- Both points are representatives, so 'descrRef' returns their
    -- 'Info' references immediately.  This avoids a partial pattern
    -- match on the 'Link' cells.
    info_ref1 <- descrRef point1
    info_ref2 <- descrRef point2
    MkInfo w1 d1 <- readIORef info_ref1
    MkInfo w2 d2 <- readIORef info_ref2
    -- NOTE(review): w1/w2 are read before the callback runs.  A
    -- callback that itself unions these classes would make them stale.
    d2' <- update d1 d2
    let merged = MkInfo (w1 + w2) d2'
    -- Make the smaller tree a subtree of the bigger one.  The idea is
    -- this: we increase the path length of one set by one.  Assuming
    -- all elements are accessed equally often, the penalty is smaller
    -- if we do it for the smaller set of the two.
    if w1 >= w2
      then do
        writeIORef link_ref2 (Link point1)
        writeIORef info_ref1 merged
      else do
        writeIORef link_ref1 (Link point2)
        writeIORef info_ref2 merged

-- | /O(1)/. Return @True@ if both points belong to the same
-- equivalence class, i.e. if they have the same representative.
equivalent :: Point a -> Point a -> IO Bool
equivalent p1 p2 = (==) <$> repr p1 <*> repr p2

-- | /O(1)/. Returns @True@ for all but one element of an equivalence
-- class.  That is, if @ps = [p1, .., pn]@ are all in the same
-- equivalence class, then the following assertion holds.
--
-- > do rs <- mapM redundant ps
-- >    assert (length (filter (==False) rs) == 1)
--
-- The element for which the function returns @False@ is the current
-- representative, which may change after a 'union'.  Which element
-- that is remains unspecified, so be really careful when using this.
redundant :: Point a -> IO Bool
redundant (Pt link_r) = do
  link <- readIORef link_r
  return $ case link of
    Info _ -> False
    Link _ -> True