11{-# LANGUAGE BlockArguments #-}
22{-# LANGUAGE CPP #-}
33{-# LANGUAGE DataKinds #-}
4+ {-# LANGUAGE GADTs #-}
5+ {-# LANGUAGE LambdaCase #-}
46{-# LANGUAGE NamedFieldPuns #-}
57{-# LANGUAGE PatternSynonyms #-}
68{-# LANGUAGE TupleSections #-}
9+ {-# LANGUAGE TypeApplications #-}
710
811module IfSat.Plugin
912 ( plugin )
1013 where
1114
1215-- base
1316import Control.Monad
14- ( filterM )
17+ ( when )
1518import Data.Foldable
16- ( for_ )
19+ ( traverse_ )
1720import Data.Maybe
1821 ( catMaybes , mapMaybe )
1922
@@ -35,18 +38,23 @@ import GHC.Tc.Types
3538 ( TcM )
3639import GHC.Tc.Types.Constraint
3740 ( isEmptyWC , CtEvidence (.. ), ctEvEvId )
41+ import GHC.Tc.Utils.Monad
42+ ( TcRef )
3843import GHC.Tc.Utils.TcType
39- ( MetaDetails (.. ), metaTyVarRef
40- , tyCoVarsOfTypeList
41- )
44+ ( MetaDetails (.. ), metaTyVarRef )
4245import GHC.Tc.Utils.TcMType
43- ( isUnfilledMetaTyVar , newTcEvBinds )
46+ ( newTcEvBinds , isUnfilledMetaTyVar )
4447
4548-- ghc-tcplugin-api
4649import GHC.TcPlugin.API
4750import GHC.TcPlugin.API.Internal
4851 ( unsafeLiftTcM )
4952
53+ -- transformers
54+ import Control.Monad.Trans.Class ( lift )
55+ import Control.Monad.Trans.Writer.CPS ( WriterT )
56+ import qualified Control.Monad.Trans.Writer.CPS as Writer
57+
5058-- if-instance
5159import IfSat.Plugin.Compat
5260 ( wrapTcS , getRestoreTcS )
@@ -151,9 +159,7 @@ solveWanted defs@( PluginDefs { orClass } ) givens wanted
151159
152160 -- Keep track of the current solver state in order to backtrack
153161 -- in the event that our attempt at solving 'ct_l' fails.
154- ct_l_unfilled_metas <- wrapTcS
155- $ filterM isUnfilledMetaTyVar
156- $ tyCoVarsOfTypeList ct_l_ty
162+ ct_l_unfilled <- wrapTcS $ unfilledRefsOfType ct_l_ty
157163 restoreTcS <- getRestoreTcS
158164
159165 -- Try to solve 'ct_l', using both Givens and top-level instances.
@@ -183,11 +189,8 @@ solveWanted defs@( PluginDefs { orClass } ) givens wanted
183189 -- Reset the solver state to before we attempted to solve 'ct_l',
184190 -- and undo any type variable unifications that happened.
185191 restoreTcS
186- wrapTcS $ for_ ct_l_unfilled_metas \ meta ->
187- writeTcRef ( metaTyVarRef meta ) Flexi
188- ct_r_unfilled_metas <- wrapTcS
189- $ filterM isUnfilledMetaTyVar
190- $ tyCoVarsOfTypeList ct_r_ty
192+ wrapTcS $ traverse_ unfillMutableRef ct_l_unfilled
193+ ct_r_unfilled <- wrapTcS $ unfilledRefsOfType ct_r_ty
191194
192195 -- Try to solve 'ct_r', using both Givens and top-level instances.
193196 residual_ct_r <- solveSimpleWanteds ( unitBag ct_r )
@@ -212,8 +215,7 @@ solveWanted defs@( PluginDefs { orClass } ) givens wanted
212215 -- Reset the solver state to before we attempted to solve 'ct_r',
213216 -- and undo any type variable unifications that happened.
214217 restoreTcS
215- wrapTcS $ for_ ct_r_unfilled_metas \ meta ->
216- writeTcRef ( metaTyVarRef meta ) Flexi
218+ wrapTcS $ traverse_ unfillMutableRef ct_r_unfilled
217219
218220 pure Nothing
219221 pure $ ( , wanted ) <$> mb_wanted_evTerm
@@ -306,6 +308,45 @@ dispatchFalseEvTerm defs@( PluginDefs { orClass } ) givens ct_l_ty ct_r_ty ct_r_
306308 )
307309 ]
308310
311+ -- | A mutable reference that was originally unfilled
312+ data UnfilledRef
313+ -- | A metavariable that was originally unfilled
314+ = UnfilledMeta ! ( TcRef MetaDetails )
315+ -- | A coercion hole that was originally unfilled
316+ | UnfilledCoHole ! ( TcRef ( Maybe Coercion ) )
317+
318+ -- | Gather all the unfilled mutable references of a type: unfilled
319+ -- metavariables and unfilled coercion holes.
320+ unfilledRefsOfType :: TcType -> TcM [UnfilledRef ]
321+ unfilledRefsOfType = Writer. execWriterT . go_ty
322+ where
323+ (go_ty, _go_tys, _go_co, _go_cos) =
324+ mapTyCo @ ( WriterT [UnfilledRef ] TcM ) $
325+ TyCoMapper
326+ { tcm_tyvar = \ _ tv -> do
327+ unfilled_meta <- lift $ isUnfilledMetaTyVar tv
328+ when unfilled_meta $
329+ Writer. tell $ [ UnfilledMeta $ metaTyVarRef tv ]
330+ return $ mkTyVarTy tv
331+ , tcm_tycobinder = \ _ tcv _ftf k -> k () tcv
332+ , tcm_tycon = return
333+ , tcm_covar = \ _ cv -> return $ mkCoVarCo cv
334+ , tcm_hole = \ _ hole@ (CoercionHole { ch_ref = hole_ref }) -> do
335+ hole_contents <- lift $ readTcRef hole_ref
336+ case hole_contents of
337+ Nothing ->
338+ Writer. tell $ [ UnfilledCoHole hole_ref ]
339+ Just {} ->
340+ return ()
341+ return $ mkHoleCo hole
342+ }
343+
344+ -- | Restore a mutable reference to the unfilled state.
345+ unfillMutableRef :: UnfilledRef -> TcM ()
346+ unfillMutableRef = \ case
347+ UnfilledMeta ref -> writeTcRef ref Flexi
348+ UnfilledCoHole hole -> writeTcRef hole Nothing
349+
309350-- The type @IsSat ct ~ b@.
310351sat_eqTy :: PluginDefs -> Type -> Bool -> Type
311352sat_eqTy ( PluginDefs { isSatTyCon } ) ct_ty booly
@@ -383,9 +424,7 @@ isSatRewriter ( PluginDefs { isSatTyCon } ) givens [ct_ty] = do
383424
384425 -- Keep track of the current solver state in order to undo any
385426 -- side-effects after calling 'solveSimpleWanteds' on 'ct'.
386- ct_unfilled_metas <- wrapTcS
387- $ filterM isUnfilledMetaTyVar
388- $ tyCoVarsOfTypeList ct_ty
427+ ct_unfilled <- wrapTcS $ unfilledRefsOfType ct_ty
389428 restoreTcS <- getRestoreTcS
390429
391430 -- Try to solve 'ct', using both Givens and top-level instances.
@@ -396,8 +435,7 @@ isSatRewriter ( PluginDefs { isSatTyCon } ) givens [ct_ty] = do
396435 -- Reset the solver state to before we attempted to solve 'ct',
397436 -- and undo any type variable unifications that happened.
398437 restoreTcS
399- wrapTcS $ for_ ct_unfilled_metas \ meta ->
400- writeTcRef ( metaTyVarRef meta ) Flexi
438+ wrapTcS $ traverse_ unfillMutableRef ct_unfilled
401439
402440 let
403441 is_sat :: Bool
0 commit comments