diff --git a/.github/workflows/main-pull.yml b/.github/workflows/main-pull.yml index adfa3f61d..2174e042f 100644 --- a/.github/workflows/main-pull.yml +++ b/.github/workflows/main-pull.yml @@ -99,6 +99,9 @@ jobs: - name: Benchmark compiler performance run: cabal bench bench-compiler-performance + - name: Benchmark experimental compiler performance + run: cabal bench experimental-backend + - name: Benchmark polynomial multiplication run: cabal bench bench-poly-mul diff --git a/symbolic-base/src/ZkFold/Algebra/Polynomial/Multivariate/Expression.hs b/symbolic-base/src/ZkFold/Algebra/Polynomial/Multivariate/Expression.hs new file mode 100644 index 000000000..f9fb97efb --- /dev/null +++ b/symbolic-base/src/ZkFold/Algebra/Polynomial/Multivariate/Expression.hs @@ -0,0 +1,62 @@ +{-# LANGUAGE BlockArguments #-} + +module ZkFold.Algebra.Polynomial.Multivariate.Expression (Polynomial, evalPoly) where + +import Control.Applicative (Applicative (..)) +import Data.Foldable (Foldable) +import Data.Function ((.)) +import Data.Functor (Functor, (<$>)) +import Data.Traversable (Traversable) +import Numeric.Natural (Natural) + +import ZkFold.Algebra.Class + +data Polynomial a v + = PVar v + | PConst a + | Polynomial a v :+ Polynomial a v + | Polynomial a v :* Polynomial a v + deriving (Foldable, Functor, Traversable) + +evalPoly :: Algebra a b => Polynomial a v -> (v -> b) -> b +evalPoly (PVar i) x = x i +evalPoly (PConst c) _ = fromConstant c +evalPoly (p :+ q) x = evalPoly p x + evalPoly q x +evalPoly (p :* q) x = evalPoly p x * evalPoly q x + +instance Ring a => Applicative (Polynomial a) where + pure = PVar + fs <*> xs = evalPoly fs (<$> xs) + +instance Zero a => Zero (Polynomial a v) where + zero = PConst zero + +instance FromConstant c a => Scale c (Polynomial a v) + +instance AdditiveSemigroup (Polynomial a v) where + (+) = (:+) + +instance (FromConstant Natural a, Zero a) => AdditiveMonoid (Polynomial a v) + +instance Ring a => AdditiveGroup (Polynomial a v) where + negate = scale (negate one :: a) + +instance MultiplicativeSemigroup (Polynomial a v) where + (*) = (:*) + +instance MultiplicativeMonoid a => MultiplicativeMonoid (Polynomial a v) where + one = fromConstant (one :: a) + +instance FromConstant c a => FromConstant c (Polynomial a v) where + fromConstant = PConst . fromConstant + +instance MultiplicativeMonoid a => Exponent (Polynomial a v) Natural where + (^) = natPow + +instance {-# OVERLAPPING #-} FromConstant (Polynomial a v) (Polynomial a v) + +instance {-# OVERLAPPING #-} Scale (Polynomial a v) (Polynomial a v) + +instance Semiring a => Semiring (Polynomial a v) + +instance Ring a => Ring (Polynomial a v) diff --git a/symbolic-base/src/ZkFold/ArithmeticCircuit/Experimental.hs b/symbolic-base/src/ZkFold/ArithmeticCircuit/Elem.hs similarity index 96% rename from symbolic-base/src/ZkFold/ArithmeticCircuit/Experimental.hs rename to symbolic-base/src/ZkFold/ArithmeticCircuit/Elem.hs index 23decf1cb..3b0690972 100644 --- a/symbolic-base/src/ZkFold/ArithmeticCircuit/Experimental.hs +++ b/symbolic-base/src/ZkFold/ArithmeticCircuit/Elem.hs @@ -5,7 +5,7 @@ {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -Wno-orphans #-} -module ZkFold.ArithmeticCircuit.Experimental where +module ZkFold.ArithmeticCircuit.Elem where import Control.Applicative (pure) import Control.DeepSeq (NFData (..), NFData1, liftRnf, rwhnf) @@ -25,10 +25,9 @@ import Data.Ord (Ord (..)) import Data.Semigroup (Semigroup, (<>)) import Data.Semigroup.Generic (GenericSemigroupMonoid (..)) import qualified Data.Set as S -import Data.Traversable (traverse) +import Data.Traversable (Traversable, traverse) import Data.Tuple (swap, uncurry) import Data.Type.Equality (type (~)) -import Data.Typeable (Typeable) import GHC.Generics (Generic, Par1 (..), U1, (:*:) (..)) import Optics (zoom) import Prelude (error) @@ -95,10 +94,7 @@ newtype Polynomial a v = MkPolynomial --------------- Type-preserving lookup constraint representation --------------- -data LookupEntry v - = forall f. - (Functor f, Foldable f, NFData1 f, Typeable f) => - LEntry (f v) (LookupTable f) +data LookupEntry v = forall f. Traversable f => LEntry (f v) (LookupTable f) ------------- Box of constraints supporting efficient concatenation ------------ diff --git a/symbolic-base/src/ZkFold/ArithmeticCircuit/Node.hs b/symbolic-base/src/ZkFold/ArithmeticCircuit/Node.hs new file mode 100644 index 000000000..918dce3ed --- /dev/null +++ b/symbolic-base/src/ZkFold/ArithmeticCircuit/Node.hs @@ -0,0 +1,420 @@ +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE UndecidableSuperClasses #-} +{-# OPTIONS_GHC -Wno-orphans #-} + +{- HLINT ignore "Use record patterns" -} + +module ZkFold.ArithmeticCircuit.Node where + +import Control.Applicative (pure, (<*>)) +import Control.DeepSeq (NFData (..), rwhnf) +import Control.Monad (unless, (>>=)) +import Control.Monad.IO.Class (liftIO) +import Control.Monad.Reader (ReaderT (runReaderT), asks) +import Control.Monad.State (MonadState, StateT, runState, runStateT, state) +import Data.Binary (Binary) +import Data.Function (const, flip, ($), (.)) +import Data.Functor (fmap, (<$>)) +import Data.HashTable.IO (BasicHashTable) +import qualified Data.HashTable.IO as T +import Data.Kind (Type) +import Data.Maybe (Maybe (..), isJust) +import Data.Monoid (Monoid (..)) +import Data.Semigroup (Semigroup (..)) +import Data.Traversable (Traversable, traverse) +import Data.Type.Equality (type (~)) +import GHC.Err (error) +import GHC.Generics (U1, (:*:) (..)) +import GHC.Integer (Integer) +import GHC.TypeNats (KnownNat) +import Numeric.Natural (Natural) +import System.IO (IO) +import System.IO.Unsafe (unsafePerformIO) +import System.Mem.StableName (StableName, makeStableName) + +import ZkFold.Algebra.Class +import ZkFold.Algebra.Number (Prime) +import ZkFold.Algebra.Polynomial.Multivariate.Expression (Polynomial, evalPoly) +import ZkFold.ArithmeticCircuit (ArithmeticCircuit, optimize, solder) +import ZkFold.ArithmeticCircuit.Context (CircuitContext, crown, emptyContext) +import ZkFold.ArithmeticCircuit.Op +import ZkFold.ArithmeticCircuit.Var (NewVar (..), Var) +import ZkFold.ArithmeticCircuit.Witness (BooleanF, EuclideanF, OrderingF, WitnessF) +import ZkFold.Control.Conditional (Conditional (..)) +import ZkFold.Data.Bool (BoolType (..)) +import ZkFold.Data.Eq (Eq (..)) +import ZkFold.Data.Ord (IsOrdering (..), Ord (..)) +import ZkFold.Symbolic.Class (Arithmetic) +import ZkFold.Symbolic.Compat (CompatContext (..)) +import qualified ZkFold.Symbolic.Compiler as Old +import qualified ZkFold.Symbolic.Data.Class as Old +import ZkFold.Symbolic.Data.V2 (HasRep, Layout, SymbolicData (fromLayout, toLayout)) +import ZkFold.Symbolic.MonadCircuit (at, constraint, lookupConstraint, unconstrained) +import ZkFold.Symbolic.V2 (Constraint (..), Symbolic (..)) + +------------------- Experimental single-output circuit type -------------------- + +-- | @Node p s@ is a node in 'Symbolic' computation graph +-- where @p@ is a size of a base field and @s@ is a 'Sort' of a node. +data Node p (s :: Sort) where + -- | An input node + NodeInput :: NewVar -> Node p ZZp + -- | An application of an operation from 'PrimeField' class. + NodeApply :: KnownSort s => Op (Node p) s -> Node p s + -- | An application of a 'constrain' function from 'Symbolic' class. + NodeConstrain :: Constraint (Node p ZZp) -> Node p ZZp -> Node p ZZp + +instance NFData (Node p s) where + rnf = rwhnf -- GADTs are strict, so no need to eval + +instance KnownSort s => Conditional (Node p BB) (Node p s) where + bool onFalse onTrue condition = NodeApply (OpBool onFalse onTrue condition) + +instance BoolType (Node p BB) where + true = one + false = zero + not = negate + (&&) = (*) + xor = (+) + x || y = NodeApply (OpOr x y) + +instance Semigroup (Node p OO) where + x <> y = NodeApply (OpAppend x y) + +instance Monoid (Node p OO) where + mempty = zero + +instance IsOrdering (Node p OO) where + lt = fromConstant ((-1) :: Integer) + eq = zero + gt = one + +instance Eq (Node p s) where + type BooleanOf (Node p s) = Node p BB + x == y = NodeApply (OpEq x y) + x /= y = NodeApply (OpNEq x y) + +instance Ord (Node p ZZ) where + type OrderingOf (Node p ZZ) = Node p OO + compare x y = NodeApply (OpCompare x y) + ordering x y z o = NodeApply (OpOrder x y z o) + x < y = compare x y == lt + x <= y = compare x y /= gt + x >= y = compare x y /= lt + x > y = compare x y == gt + +instance + (KnownNat p, KnownNat (NumberOfBits (Node p ZZp))) + => Finite (Node p ZZp) + where + type Order (Node p ZZp) = p + +instance KnownSort s => FromConstant Natural (Node p s) where + fromConstant x = NodeApply $ OpConst (fromConstant x) + +instance KnownSort s => FromConstant Integer (Node p s) where + fromConstant x = NodeApply (OpConst x) + +instance FromConstant (Node p ZZ) (Node p ZZp) where + fromConstant x = NodeApply (OpFrom x) + +instance KnownSort s => Scale Natural (Node p s) where + scale k = NodeApply . OpScale (fromConstant k) + +instance KnownSort s => Scale Integer (Node p s) where + scale k x = NodeApply (OpScale k x) + +instance KnownSort s => Exponent (Node p s) Natural where + x ^ p = NodeApply (OpExp x p) + +instance Prime p => Exponent (Node p ZZp) Integer where + (^) = intPowF + +instance KnownSort s => Zero (Node p s) where + zero = fromConstant (0 :: Integer) + +instance KnownSort s => AdditiveSemigroup (Node p s) where + x + y = NodeApply (OpAdd x y) + +instance KnownSort s => AdditiveMonoid (Node p s) + +instance KnownSort s => AdditiveGroup (Node p s) where + negate = NodeApply . OpNeg + +instance KnownSort s => MultiplicativeSemigroup (Node p s) where + x * y = NodeApply (OpMul x y) + +instance KnownSort s => MultiplicativeMonoid (Node p s) where + one = fromConstant (1 :: Integer) + +instance KnownSort s => Semiring (Node p s) + +instance KnownSort s => Ring (Node p s) + +instance SemiEuclidean (Node p ZZ) where + div x y = NodeApply (OpDiv x y) + mod x y = NodeApply (OpMod x y) + +instance Euclidean (Node p ZZ) where + gcd x y = NodeApply (OpGcd x y) + bezoutL x y = NodeApply (OpBezoutL x y) + bezoutR x y = NodeApply (OpBezoutR x y) + +instance Prime p => Field (Node p ZZp) where + finv = NodeApply . OpInv + +instance (Prime p, KnownNat (NumberOfBits (Node p ZZp))) => PrimeField (Node p ZZp) where + type IntegralOf (Node p ZZp) = Node p ZZ + toIntegral = NodeApply . OpTo + +instance (Prime p, KnownNat (NumberOfBits (Node p ZZp))) => Symbolic (Node p ZZp) where + constrain = NodeConstrain + +------------------------- Optimized compilation function ----------------------- + +type family InputF (f :: Type) where + InputF (i a -> f) = i :*: Input f + InputF (o a) = U1 + +type family OutputF (f :: Type) where + OutputF (i a -> f) = Output f + OutputF (o a) = o + +class + ( SymbolicData (Input f) + , HasRep (Input f) a + , SymbolicData (Output f) + , Traversable (Layout (Output f) a) + ) => + SymbolicFunction (a :: Type) (f :: Type) + | f -> a + where + type Input f :: Type -> Type + type Input f = InputF f + type Output f :: Type -> Type + type Output f = OutputF f + symApply :: f -> Input f a -> Output f a + +instance + (SymbolicData o, Traversable (Layout o a), Input (o a) ~ U1, Output (o a) ~ o) + => SymbolicFunction a (o a) + where + symApply = const + +instance + (SymbolicData i, HasRep i a, SymbolicFunction a f) + => SymbolicFunction a (i a -> f) + where + symApply f (x :*: y) = symApply (f x) y + +compileV1 + :: forall a f n d + . ( Arithmetic a + , Binary a + , Old.SymbolicFunction f + , Order a ~ n + , Old.Context f ~ CompatContext (Node n ZZp) + , Old.Domain f ~ d + ) + => f + -> ArithmeticCircuit + a + (Old.Layout d n :*: Old.Payload d n) + (Old.Layout (Old.Range f) n) +compileV1 = + optimize . solder . \f (l :*: p) -> + let (output, circuit) = unsafePerformIO do + compiler <- makeCompiler + flip runStateT emptyContext + . flip runReaderT compiler + . traverse compileNode + . compatContext + . Old.arithmetize + . Old.apply f + $ Old.restore (CompatContext (NodeInput <$> l), NodeInput <$> p) + in crown circuit (toVar <$> output) + +compileV2 + :: forall a c f + . (Arithmetic a, Binary a, c ~ Node (Order a) ZZp, SymbolicFunction c f) + => f -> ArithmeticCircuit a (Layout (Input f) c) (Layout (Output f) c) +compileV2 = + optimize . solder . \(f :: f) (l :: Layout (Input f) c NewVar) -> + let (output, circuit) = unsafePerformIO do + compiler <- makeCompiler + flip runStateT emptyContext + . flip runReaderT compiler + . traverse compileNode + . toLayout + . symApply f + $ fromLayout (fmap NodeInput l) + in crown circuit (toVar <$> output) + +------------------------- Compilation internals -------------------------------- + +type StableTable k v = BasicHashTable (StableName k) v + +data Compiler a = Compiler + { constraintLog :: StableTable (Constraint (Node (Order a) ZZp)) () + , witnessExtractor :: WitnessExtractor a + } + +type CompilerM a = ReaderT (Compiler a) (StateT (CircuitContext a U1) IO) + +makeCompiler :: IO (Compiler a) +makeCompiler = Compiler <$> T.new <*> makeExtractor + +compileNode + :: forall a s + . (Arithmetic a, Binary a) + => Node (Order a) s -> CompilerM a (Witness a s) +compileNode (NodeInput v) = pure $ FieldVar (pure v) +compileNode (NodeConstrain !c n) = do + snc <- liftIO (makeStableName c) + isDone <- + asks constraintLog + >>= liftIO . \log -> T.mutate log snc \x -> + (Just (), isJust x) + unless isDone case c of + Lookup lkp ns -> do + vs <- traverse (fmap toVar . compileNode) ns + state $ runState (lookupConstraint vs lkp) + Polynomial p -> do + poly <- traverse (fmap toVar . compileNode) p + state . runState $ constraint (evalPoly @a poly) + compileNode n +compileNode (NodeApply !op) = do + sno <- liftIO (makeStableName op) + asks witnessExtractor >>= liftIO . request sno >>= \case + Just w -> pure w + Nothing -> do + w <- traverseOp compileNode op >>= opToWitness + asks witnessExtractor >>= liftIO . insertWitness sno w + pure w + +data WitnessExtractor a = WitnessExtractor + { weVars :: StableTable (Op (Node (Order a)) ZZp) (Var a) + , weInts :: StableTable (Op (Node (Order a)) ZZ) (EuclideanF a NewVar) + , weBool :: StableTable (Op (Node (Order a)) BB) (BooleanF a NewVar) + , weOrds :: StableTable (Op (Node (Order a)) OO) (OrderingF a NewVar) + } + +makeExtractor :: IO (WitnessExtractor a) +makeExtractor = WitnessExtractor <$> T.new <*> T.new <*> T.new <*> T.new + +insertWitness + :: StableName (Op (Node (Order a)) s) + -> Witness a s + -> WitnessExtractor a + -> IO () +insertWitness sn witness WitnessExtractor {..} = case witness of + FieldVar v -> T.insert weVars sn v + IntWitness w -> T.insert weInts sn w + BoolWitness w -> T.insert weBool sn w + OrdWitness w -> T.insert weOrds sn w + +request + :: forall a s + . KnownSort s + => StableName (Op (Node (Order a)) s) + -> WitnessExtractor a + -> IO (Maybe (Witness a s)) +request sn WitnessExtractor {..} = case knownSort @s of + ZZpSing -> fmap FieldVar <$> T.lookup weVars sn + ZZSing -> fmap IntWitness <$> T.lookup weInts sn + BBSing -> fmap BoolWitness <$> T.lookup weBool sn + OOSing -> fmap OrdWitness <$> T.lookup weOrds sn + +data Witness a (s :: Sort) where + FieldVar :: Var a -> Witness a ZZp + IntWitness :: EuclideanF a NewVar -> Witness a ZZ + BoolWitness :: BooleanF a NewVar -> Witness a BB + OrdWitness :: OrderingF a NewVar -> Witness a OO + +toVar :: Witness a ZZp -> Var a +toVar (FieldVar v) = v + +instance (KnownSort s, FromConstant Integer a) => FromConstant Integer (Witness a s) where + fromConstant = case knownSort @s of + ZZpSing -> FieldVar . fromConstant + ZZSing -> IntWitness . fromConstant + BBSing -> BoolWitness . fromConstant + OOSing -> OrdWitness . fromConstant + +instance Scale Integer a => Scale Integer (Witness a s) where + scale k = \case + FieldVar v -> FieldVar (scale k v) + IntWitness w -> IntWitness (scale k w) + BoolWitness w -> BoolWitness (fromConstant k && w) + OrdWitness w -> OrdWitness (fromConstant k <> w) -- TODO: wrong but unused + +instance PrimeField a => Eq (Witness a s) where + type BooleanOf (Witness a s) = BooleanF a NewVar + FieldVar u == FieldVar v = at @_ @(WitnessF a NewVar) u == at v + IntWitness v == IntWitness w = v == w + BoolWitness v == BoolWitness w = not (xor v w) + OrdWitness _ == OrdWitness _ = error "not implemented" + x /= y = not (x == y) + +opToWitness + :: forall a s m + . (Arithmetic a, Binary a, MonadState (CircuitContext a U1) m) + => Op (Witness a) s -> m (Witness a s) +opToWitness = \case + OpConst c -> pure (fromConstant c) + OpScale k w -> pure (scale k w) + OpAdd (FieldVar u) (FieldVar v) -> + state . runState $ FieldVar <$> unconstrained (at u + at v) + OpAdd (IntWitness v) (IntWitness w) -> pure $ IntWitness (v + w) + OpAdd (BoolWitness v) (BoolWitness w) -> pure $ BoolWitness (v `xor` w) + OpAdd (OrdWitness v) (OrdWitness w) -> + pure $ OrdWitness (v <> w) -- TODO wrong but unused + OpMul (FieldVar u) (FieldVar v) -> + state . runState $ FieldVar <$> unconstrained (at u * at v) + OpMul (IntWitness v) (IntWitness w) -> pure $ IntWitness (v * w) + OpMul (BoolWitness v) (BoolWitness w) -> pure $ BoolWitness (v && w) + OpMul (OrdWitness v) (OrdWitness w) -> + pure $ OrdWitness (v <> w) -- TODO wrong but unused + OpNeg w -> pure (scale (-1 :: Integer) w) + OpExp (FieldVar v) p -> + state . runState $ FieldVar <$> unconstrained (at v ^ p) + OpExp (IntWitness w) p -> pure $ IntWitness (w ^ p) + OpExp (BoolWitness w) p -> pure $ BoolWitness (fromConstant (p == 0) || w) + OpExp (OrdWitness w) _ -> pure (OrdWitness w) -- TODO wrong but unused + OpFrom (IntWitness w) -> + state . runState $ FieldVar <$> unconstrained (fromConstant w) + OpTo (FieldVar v) -> + pure $ IntWitness $ toIntegral @(WitnessF a NewVar) (at v) + OpCompare (IntWitness v) (IntWitness w) -> + pure $ OrdWitness (v `compare` w) + OpDiv (IntWitness v) (IntWitness w) -> pure $ IntWitness (v `div` w) + OpMod (IntWitness v) (IntWitness w) -> pure $ IntWitness (v `mod` w) + OpGcd (IntWitness v) (IntWitness w) -> pure $ IntWitness (v `gcd` w) + OpBezoutL (IntWitness v) (IntWitness w) -> + pure $ IntWitness (v `bezoutL` w) + OpBezoutR (IntWitness v) (IntWitness w) -> + pure $ IntWitness (v `bezoutR` w) + OpInv (FieldVar v) -> + state . runState $ FieldVar <$> unconstrained (at v) + OpEq x y -> pure $ BoolWitness (x == y) + OpNEq x y -> pure $ BoolWitness (x /= y) + OpOr (BoolWitness v) (BoolWitness w) -> pure $ BoolWitness (v || w) + OpBool (FieldVar u) (FieldVar v) (BoolWitness w) -> + state . runState $ FieldVar <$> unconstrained (bool (at u) (at v) w) + OpBool (IntWitness v) (IntWitness w) (BoolWitness b) -> + pure $ IntWitness (bool v w b) + OpBool (BoolWitness v) (BoolWitness w) (BoolWitness b) -> + pure $ BoolWitness (bool v w b) + OpBool (OrdWitness _) (OrdWitness _) (BoolWitness _) -> + pure $ OrdWitness (error "not implemented") + OpAppend (OrdWitness v) (OrdWitness w) -> pure $ OrdWitness (v <> w) + OpOrder (IntWitness u) (IntWitness v) (IntWitness w) (OrdWitness o) -> + pure $ IntWitness (ordering u v w o) + +instance {-# OVERLAPPING #-} Ring a => Scale v (Polynomial a v) + +instance {-# OVERLAPPING #-} Ring a => FromConstant v (Polynomial a v) where + fromConstant = pure diff --git a/symbolic-base/src/ZkFold/ArithmeticCircuit/Op.hs b/symbolic-base/src/ZkFold/ArithmeticCircuit/Op.hs new file mode 100644 index 000000000..2f3f103ef --- /dev/null +++ b/symbolic-base/src/ZkFold/ArithmeticCircuit/Op.hs @@ -0,0 +1,86 @@ +module ZkFold.ArithmeticCircuit.Op where + +import Control.Applicative (Applicative, pure, (<*>)) +import Data.Functor ((<$>)) +import GHC.Integer (Integer) +import GHC.Natural (Natural) + +-- | A 'PrimeField' class describes operations available between 4 types: +-- * finite field itself +-- * backing integral type +-- * booleans associated with both +-- * ordering associated with integers +-- +-- 'Sort' is a set of labels for differentiating between them. +data Sort = ZZp | ZZ | BB | OO + +-- | A way to store type-level 'Sort' on the term-level. +data SortSing (s :: Sort) where + ZZpSing :: SortSing ZZp + ZZSing :: SortSing ZZ + BBSing :: SortSing BB + OOSing :: SortSing OO + +-- | A class for communicating between term-level and type-level 'Sort's. +class KnownSort (s :: Sort) where + knownSort :: SortSing s + +instance KnownSort ZZp where + knownSort = ZZpSing + +instance KnownSort ZZ where + knownSort = ZZSing + +instance KnownSort BB where + knownSort = BBSing + +instance KnownSort OO where + knownSort = OOSing + +-- | 'Op f s' describes operations available in the 'PrimeField' class +-- where 's' is a sort of the result of an operation +-- and 'f' is a @Sort -> Type@ functor which, given a sort, +-- would return an argument to the operation of the type labeled with this sort. +data Op f (s :: Sort) where + OpConst :: KnownSort s => Integer -> Op f s + OpScale :: Integer -> f s -> Op f s + OpAdd, OpMul :: f s -> f s -> Op f s + OpNeg :: f s -> Op f s + OpExp :: f s -> Natural -> Op f s + OpFrom :: f ZZ -> Op f ZZp + OpTo :: f ZZp -> Op f ZZ + OpCompare :: f ZZ -> f ZZ -> Op f OO + OpDiv, OpMod, OpGcd, OpBezoutL, OpBezoutR :: f ZZ -> f ZZ -> Op f ZZ + OpInv :: f ZZp -> Op f ZZp + OpEq, OpNEq :: f s -> f s -> Op f BB + OpOr :: f BB -> f BB -> Op f BB + OpBool :: f s -> f s -> f BB -> Op f s + OpAppend :: f OO -> f OO -> Op f OO + OpOrder :: f ZZ -> f ZZ -> f ZZ -> f OO -> Op f ZZ + +-- | Replacement of a @Sort -> Type@ functor in 'Op', +-- possibly with side-effects. +traverseOp + :: Applicative m => (forall t. f t -> m (g t)) -> Op f s -> m (Op g s) +traverseOp f = \case + OpConst i -> pure (OpConst i) + OpScale i x -> OpScale i <$> f x + OpAdd x y -> OpAdd <$> f x <*> f y + OpMul x y -> OpMul <$> f x <*> f y + OpNeg x -> OpNeg <$> f x + OpExp x e -> (`OpExp` e) <$> f x + OpFrom x -> OpFrom <$> f x + OpTo x -> OpTo <$> f x + OpCompare x y -> OpCompare <$> f x <*> f y + OpDiv x y -> OpDiv <$> f x <*> f y + OpMod x y -> OpMod <$> f x <*> f y + OpGcd x y -> OpGcd <$> f x <*> f y + OpBezoutL x y -> OpBezoutL <$> f x <*> f y + OpBezoutR x y -> OpBezoutR <$> f x <*> f y + OpInv x -> OpInv <$> f x + OpEq x y -> OpEq <$> f x <*> f y + OpNEq x y -> OpNEq <$> f x <*> f y + OpOr x y -> OpOr <$> f x <*> f y + OpBool x y z -> OpBool <$> f x <*> f y <*> f z + OpAppend x y -> OpAppend <$> f x <*> f y + OpOrder x y z w -> OpOrder <$> f x <*> f y <*> f z <*> f w diff --git a/symbolic-base/src/ZkFold/ArithmeticCircuit/Witness.hs b/symbolic-base/src/ZkFold/ArithmeticCircuit/Witness.hs index 7f558a26a..23d939263 100644 --- a/symbolic-base/src/ZkFold/ArithmeticCircuit/Witness.hs +++ b/symbolic-base/src/ZkFold/ArithmeticCircuit/Witness.hs @@ -3,11 +3,14 @@ module ZkFold.ArithmeticCircuit.Witness where import Control.Applicative (Applicative (..)) import Control.DeepSeq (NFData (..), rwhnf) import Control.Monad (Monad (..), ap) +import Data.Bool (Bool) import Data.Function (const, (.)) import Data.Functor (Functor) import Data.Monoid (Monoid (..)) +import Data.Ord (Ordering (..)) import Data.Semigroup (Semigroup (..)) import GHC.Integer (Integer) +import GHC.Real (odd) import Numeric.Natural (Natural) import ZkFold.Algebra.Class @@ -103,6 +106,10 @@ instance BoolType (BooleanF a v) where BooleanF f || BooleanF g = BooleanF (\x -> f x || g x) BooleanF f `xor` BooleanF g = BooleanF (\x -> f x `xor` g x) +instance FromConstant Bool (BooleanF a v) where fromConstant = bool false true + +instance FromConstant Integer (BooleanF a v) where fromConstant x = fromConstant (odd x) + newtype EuclideanF a v = EuclideanF {euclideanF :: forall w. IsWitness a w => (v -> w) -> IntegralOf w} instance FromConstant Natural (EuclideanF a v) where fromConstant x = EuclideanF (fromConstant x) @@ -174,3 +181,18 @@ instance IsOrdering (OrderingF a v) where lt = OrderingF (const lt) eq = OrderingF (const eq) gt = OrderingF (const gt) + +instance FromConstant Ordering (OrderingF a v) where + fromConstant = \case + LT -> lt + EQ -> eq + GT -> gt + +intToOrdering :: Integer -> Ordering +intToOrdering x = case x `mod` 3 of + 0 -> EQ + 1 -> GT + _ -> LT + +instance FromConstant Integer (OrderingF a v) where + fromConstant = fromConstant . intToOrdering diff --git a/symbolic-base/src/ZkFold/Symbolic/MonadCircuit.hs b/symbolic-base/src/ZkFold/Symbolic/MonadCircuit.hs index b74f67d4c..8a463293f 100644 --- a/symbolic-base/src/ZkFold/Symbolic/MonadCircuit.hs +++ b/symbolic-base/src/ZkFold/Symbolic/MonadCircuit.hs @@ -1,12 +1,9 @@ module ZkFold.Symbolic.MonadCircuit where -import Control.DeepSeq (NFData1) import Control.Monad (Monad (return)) -import Data.Foldable (Foldable) import Data.Function (($), (.)) -import Data.Functor (Functor) import Data.Set (singleton) -import Data.Typeable (Typeable) +import Data.Traversable (Traversable) import GHC.Generics (Par1 (..)) import Numeric.Natural (Natural) @@ -77,11 +74,7 @@ class -- For examples of lookup constraints, see 'rangeConstraint'. -- -- NOTE: currently, provided constraints are directly fed to zkSNARK in use. - lookupConstraint - :: (NFData1 f, Foldable f, Functor f, Typeable f) - => f var - -> LookupTable f - -> m () + lookupConstraint :: Traversable f => f var -> LookupTable f -> m () -- | Creates new variable given a polynomial witness -- AND adds a corresponding polynomial constraint. diff --git a/symbolic-base/src/ZkFold/Symbolic/V2.hs b/symbolic-base/src/ZkFold/Symbolic/V2.hs index 01617a6bb..e6ffdf081 100644 --- a/symbolic-base/src/ZkFold/Symbolic/V2.hs +++ b/symbolic-base/src/ZkFold/Symbolic/V2.hs @@ -7,14 +7,15 @@ import Data.Binary (Binary) import Data.Foldable (Foldable) import Data.Functor.Rep (Rep, Representable) import Data.Set (Set) +import Data.Traversable (Traversable) import GHC.Generics (Par1, (:*:)) import Numeric.Natural (Natural) -import ZkFold.Algebra.Class (Algebra, PrimeField) +import ZkFold.Algebra.Class (Algebra, PrimeField, (-)) import ZkFold.Data.FromList (FromList) --- | @LookupTable a f@ is a type of compact lookup table descriptions using ideas from relational algebra. --- @a@ is a base field type, @f@ is a functor such that @f a@ is a type whose subset this lookup table describes. +-- | @LookupTable f@ is a type of compact @f@-ary lookup table descriptions +-- using ideas from relational algebra. data LookupTable f where -- | @Ranges@ describes a set of disjoint segments of the base field. Ranges :: Set (Natural, Natural) -> LookupTable Par1 @@ -27,9 +28,14 @@ data LookupTable f where -> LookupTable f -> LookupTable (f :*: g) +type Poly a = forall b. Algebra a b => b + data Constraint a where - Polynomial :: (forall b. Algebra a b => b) -> Constraint a - Lookup :: LookupTable f -> f a -> Constraint a + Polynomial :: Poly a -> Constraint a + Lookup :: Traversable f => LookupTable f -> f a -> Constraint a + +(=!=) :: Poly a -> Poly a -> Constraint a +p =!= q = Polynomial (p - q) -- TODO: Get rid of NFData constraint class (NFData c, PrimeField c) => Symbolic c where diff --git a/symbolic-base/symbolic-base.cabal b/symbolic-base/symbolic-base.cabal index c9791265a..7445b350d 100644 --- a/symbolic-base/symbolic-base.cabal +++ b/symbolic-base/symbolic-base.cabal @@ -130,6 +130,7 @@ library ZkFold.Algebra.EllipticCurve.PlutoEris ZkFold.Algebra.EllipticCurve.Secp256k1 ZkFold.Algebra.Polynomial.Multivariate + ZkFold.Algebra.Polynomial.Multivariate.Expression ZkFold.Algebra.Polynomial.Multivariate.Groebner ZkFold.Algebra.Polynomial.Multivariate.Monomial ZkFold.Algebra.Polynomial.Multivariate.Internal @@ -145,8 +146,10 @@ library ZkFold.ArithmeticCircuit.Children ZkFold.ArithmeticCircuit.Context ZkFold.ArithmeticCircuit.Desugaring - ZkFold.ArithmeticCircuit.Experimental + ZkFold.ArithmeticCircuit.Elem ZkFold.ArithmeticCircuit.MerkleHash + ZkFold.ArithmeticCircuit.Node + ZkFold.ArithmeticCircuit.Op ZkFold.ArithmeticCircuit.Optimization ZkFold.ArithmeticCircuit.Var ZkFold.ArithmeticCircuit.Witness @@ -319,6 +322,7 @@ library deepseq <= 1.6.0.0, distributive , generic-random < 1.6, + hashtables , infinite-list , lens , monoidal-containers , @@ -357,6 +361,7 @@ library deepseq <= 1.6.0.0, distributive , generic-random < 1.6, + hashtables , infinite-list , lens , monoidal-containers , @@ -376,7 +381,7 @@ library vector < 0.14, vector-binary-instances < 0.3, vector-split < 1.1, - openapi3 , + openapi3 , hs-source-dirs: src if arch(wasm32) ghc-options: diff --git a/symbolic-examples/bench/Experimental.hs b/symbolic-examples/bench/Experimental.hs index 3658fb5cc..69419a0c4 100644 --- a/symbolic-examples/bench/Experimental.hs +++ b/symbolic-examples/bench/Experimental.hs @@ -1,5 +1,6 @@ {-# LANGUAGE BlockArguments #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE TypeOperators #-} module Main where @@ -12,6 +13,7 @@ import Data.Functor.Rep (tabulate) import Data.Semigroup ((<>)) import Data.String (String) import Data.String qualified as String +import Data.Type.Equality (type (~)) import System.IO (IO) import Test.Tasty (testGroup) import Test.Tasty.Bench @@ -22,8 +24,10 @@ import ZkFold.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) import ZkFold.Algebra.Field (Zp) import ZkFold.ArithmeticCircuit (ArithmeticCircuit, eval) import ZkFold.ArithmeticCircuit qualified as Circuit -import ZkFold.ArithmeticCircuit.Experimental (AC, compile) +import ZkFold.ArithmeticCircuit.Elem (compile) +import ZkFold.ArithmeticCircuit.Node (compileV1) import ZkFold.Data.Binary (Binary, toByteString) +import ZkFold.Symbolic.Class (BaseField, Symbolic) import ZkFold.Symbolic.Data.Combinators (RegisterSize (Auto)) import ZkFold.Symbolic.Data.FieldElement (FieldElement) import ZkFold.Symbolic.Data.UInt (UInt) @@ -50,39 +54,54 @@ fromBinary = foldr ((+) . fromConstant . toInteger) zero . toByteString type A = Zp BLS12_381_Scalar -type C = AC A +fib100 :: Symbolic c => FieldElement c -> FieldElement c +fib100 = exampleFibonacciMod 100 -expMod :: UInt 32 Auto C -> UInt 16 Auto C -> UInt 64 Auto C -> UInt 64 Auto C +expMod + :: (Symbolic c, BaseField c ~ A) + => UInt 32 Auto c -> UInt 16 Auto c -> UInt 64 Auto c -> UInt 64 Auto c expMod = exampleUIntExpMod -fib100 :: FieldElement C -> FieldElement C -fib100 = exampleFibonacciMod 100 - main :: IO () main = defaultMain [ testGroup "MiMCHash" - [ bench "compilation" $ nf (compile @A) exampleMiMC + [ bench "compilation (Node)" $ nf (compileV1 @A) exampleMiMC + , bench "compilation (Elem)" $ nf (compile @A) exampleMiMC + , env (return $ force $ compileV1 @A exampleMiMC) $ + bench "evaluation (Node)" . nf (`eval` tabulate zero) , env (return $ force $ compile @A exampleMiMC) $ - bench "evaluation" . nf (`eval` tabulate zero) - , goldenVsString "golden stats" "stats/Experimental.MiMC" do - return $ metrics "Experimental.MiMC" (compile @A exampleMiMC) + bench "evaluation (Elem)" . nf (`eval` tabulate zero) + , goldenVsString "golden stats (Node)" "stats/Experimental.MiMC.Node" do + return $ metrics "Experimental.MiMC.Node" (compileV1 @A exampleMiMC) + , goldenVsString "golden stats (Elem)" "stats/Experimental.MiMC.Elem" do + return $ metrics "Experimental.MiMC.Elem" (compile @A exampleMiMC) ] , testGroup "Fib100" - [ bench "compilation" $ nf compile fib100 - , env (return $ force $ compile fib100) $ - bench "evaluation" . nf (`eval` tabulate fromBinary) - , goldenVsString "golden stats" "stats/Experimental.Fib100" do - return $ metrics "Experimental.Fib100" (compile fib100) + [ bench "compilation (Node)" $ nf (compileV1 @A) fib100 + , bench "compilation (Elem)" $ nf (compile @A) fib100 + , env (return $ force $ compileV1 @A fib100) $ + bench "evaluation (Node)" . nf (`eval` tabulate fromBinary) + , env (return $ force $ compile @A fib100) $ + bench "evaluation (Elem)" . nf (`eval` tabulate fromBinary) + , goldenVsString "golden stats (Node)" "stats/Experimental.Fib100.Node" do + return $ metrics "Experimental.Fib100.Node" (compileV1 @A fib100) + , goldenVsString "golden stats (Elem)" "stats/Experimental.Fib100.Elem" do + return $ metrics "Experimental.Fib100.Elem" (compile @A fib100) ] , testGroup "ExpMod" - [ bench "compilation" $ nf compile expMod - , env (return $ force $ compile expMod) $ - bench "evaluation" . nf (`eval` tabulate fromBinary) - , goldenVsString "golden stats" "stats/Experimental.ExpMod" do - return $ metrics "Experimental.ExpMod" (compile expMod) + [ bench "compilation (Node)" $ nf (compileV1 @A) expMod + , bench "compilation (Elem)" $ nf (compile @A) expMod + , env (return $ force $ compileV1 @A expMod) $ + bench "evaluation (Node)" . nf (`eval` tabulate fromBinary) + , env (return $ force $ compile @A expMod) $ + bench "evaluation (Elem)" . nf (`eval` tabulate fromBinary) + , goldenVsString "golden stats (Node)" "stats/Experimental.ExpMod.Node" do + return $ metrics "Experimental.ExpMod.Node" (compileV1 @A expMod) + , goldenVsString "golden stats (Elem)" "stats/Experimental.ExpMod.Elem" do + return $ metrics "Experimental.ExpMod.Elem" (compile @A expMod) ] ] diff --git a/symbolic-examples/stats/Experimental.ExpMod b/symbolic-examples/stats/Experimental.ExpMod.Elem similarity index 83% rename from symbolic-examples/stats/Experimental.ExpMod rename to symbolic-examples/stats/Experimental.ExpMod.Elem index 85cf69e53..960611869 100644 --- a/symbolic-examples/stats/Experimental.ExpMod +++ b/symbolic-examples/stats/Experimental.ExpMod.Elem @@ -1,4 +1,4 @@ -Experimental.ExpMod +Experimental.ExpMod.Elem Number of polynomial constraints: 6236 Number of variables: 8195 Number of lookup constraints: 3366 diff --git a/symbolic-examples/stats/Experimental.ExpMod.Node b/symbolic-examples/stats/Experimental.ExpMod.Node new file mode 100644 index 000000000..5adb0b20a --- /dev/null +++ b/symbolic-examples/stats/Experimental.ExpMod.Node @@ -0,0 +1,5 @@ +Experimental.ExpMod.Node +Number of polynomial constraints: 6167 +Number of variables: 8987 +Number of lookup constraints: 3266 +Number of lookup tables: 2 \ No newline at end of file diff --git a/symbolic-examples/stats/Experimental.Fib100 b/symbolic-examples/stats/Experimental.Fib100.Elem similarity index 83% rename from symbolic-examples/stats/Experimental.Fib100 rename to symbolic-examples/stats/Experimental.Fib100.Elem index 9c223e2b9..a1ccab67f 100644 --- a/symbolic-examples/stats/Experimental.Fib100 +++ b/symbolic-examples/stats/Experimental.Fib100.Elem @@ -1,4 +1,4 @@ -Experimental.Fib100 +Experimental.Fib100.Elem Number of polynomial constraints: 1093 Number of variables: 1194 Number of lookup constraints: 0 diff --git a/symbolic-examples/stats/Experimental.Fib100.Node b/symbolic-examples/stats/Experimental.Fib100.Node new file mode 100644 index 000000000..746abbd16 --- /dev/null +++ b/symbolic-examples/stats/Experimental.Fib100.Node @@ -0,0 +1,5 @@ +Experimental.Fib100.Node +Number of polynomial constraints: 993 +Number of variables: 297 +Number of lookup constraints: 0 +Number of lookup tables: 0 \ No newline at end of file diff --git a/symbolic-examples/stats/Experimental.MiMC b/symbolic-examples/stats/Experimental.MiMC.Elem similarity index 84% rename from symbolic-examples/stats/Experimental.MiMC rename to symbolic-examples/stats/Experimental.MiMC.Elem index 083fbbfbc..33b1a0746 100644 --- a/symbolic-examples/stats/Experimental.MiMC +++ b/symbolic-examples/stats/Experimental.MiMC.Elem @@ -1,4 +1,4 @@ -Experimental.MiMC +Experimental.MiMC.Elem Number of polynomial constraints: 1542 Number of variables: 1761 Number of lookup constraints: 0 diff --git a/symbolic-examples/stats/Experimental.MiMC.Node b/symbolic-examples/stats/Experimental.MiMC.Node new file mode 100644 index 000000000..742ea085e --- /dev/null +++ b/symbolic-examples/stats/Experimental.MiMC.Node @@ -0,0 +1,5 @@ +Experimental.MiMC.Node +Number of polynomial constraints: 1538 +Number of variables: 880 +Number of lookup constraints: 0 +Number of lookup tables: 0 \ No newline at end of file