diff --git a/.gitignore b/.gitignore index 46d8b9a..1a6706d 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ pom.xml.asc /.lein-* /.nrepl-port *~ +/.cljs_rhino_repl \ No newline at end of file diff --git a/src/active/clojure/record.cljc b/src/active/clojure/record.cljc index 1e63a1d..652f62d 100644 --- a/src/active/clojure/record.cljc +++ b/src/active/clojure/record.cljc @@ -9,6 +9,13 @@ (when-not (instance? type rec) (throw (js/Error. (str "Wrong record type passed to accessor." rec type)))))) + +(defrecord RecordMeta + ;; unresolved store for record related symbols. May not leak outside this + ;; namespace. Contains ns to allow post-macro qualification; see `record-meta` function. + [predicate constructor ordered-accessors ns]) + + #?(:clj (defmacro define-record-type "Attach doc properties to the type and the field names to get reasonable docstrings." @@ -75,7 +82,21 @@ ?field-names (map first ?field-triples) reference (fn [name] (str "[[" (ns-name *ns*) "/" name "]]")) - ?docref (str "See " (reference ?constructor) ".")] + ?docref (str "See " (reference ?constructor) ".") + + ;; we need to internalize symbols to ns resolve them + _ (intern *ns* ?predicate) + _ (intern *ns* ?constructor) + + record-meta (->RecordMeta + (resolve ?predicate) (resolve ?constructor) + (mapv (fn [constr] + (let [accessor (second (first (filter #(= (first %) constr) ?field-triples)))] + (intern *ns* accessor) + (resolve accessor))) + ?constructor-args) + *ns*)] + (let [?field-names-set (set ?field-names)] (doseq [?constructor-arg ?constructor-args] (when-not (contains? ?field-names-set ?constructor-arg) @@ -85,21 +106,23 @@ (defrecord ~?type [~@(map first ?field-triples)] ~@?opt+specs) - (def ~(document-with-arglist ?predicate '[thing] (str "Is object a `" ?type "` record? " ?docref)) + + (def ~(vary-meta (document-with-arglist ?predicate '[thing] (str "Is object a `" ?type "` record? " ?docref)) + assoc :meta record-meta) (fn [x#] (instance? ~?type x#))) (def ~(document-with-arglist ?constructor - (vec ?constructor-args) - (str "Construct a `" ?type "`" - (name-doc ?type) - " record.\n" - (apply str - (map (fn [[?field ?accessor ?lens]] - (str "\n`" ?field "`" (name-doc ?field) ": access via " (reference ?accessor) - (if ?lens - (str ", lens " (reference ?lens)) - ""))) - ?field-triples)))) + (vec ?constructor-args) + (str "Construct a `" ?type "`" + (name-doc ?type) + " record.\n" + (apply str + (map (fn [[?field ?accessor ?lens]] + (str "\n`" ?field "`" (name-doc ?field) ": access via " (reference ?accessor) + (if ?lens + (str ", lens " (reference ?lens)) + ""))) + ?field-triples)))) (fn [~@?constructor-args] (new ~?type ~@(map (fn [[?field _]] @@ -132,3 +155,13 @@ ?field-triples))))))) '())))) ?field-triples))))) + + +(defn predicate->record-meta [predicate] + ;; Expects a namespace resolved predicate + ;; if the predicate meta contains UnresolvedRecordMeta it returns a RecordMeta + ;; record with resolved values. Else nil. + (:meta (meta predicate))) + +(defn record-type-predicate? [foo] + (instance? RecordMeta (predicate->record-meta foo))) diff --git a/src/active/clojure/sum_type.clj b/src/active/clojure/sum_type.clj new file mode 100644 index 0000000..997fd68 --- /dev/null +++ b/src/active/clojure/sum_type.clj @@ -0,0 +1,304 @@ +(ns active.clojure.sum-type + (:require [active.clojure.record :as record])) + +(record/define-record-type SumTypeMeta + ;; A SumTypeMeta represents a node of a tree structure. + ;; Types represent children of the node, again being a SumTypeMeta or a record/RecordMeta. + (make-sum-type-meta predicate types) sum-type-meta? + [predicate sum-type-meta-predicate + types sum-type-meta-types]) + +(record/define-record-type RecordMeta + ;; Record meta record-type. Since in record definition no record-types + ;; are available we translate it. + (make-record-meta predicate constructor ordered-accessors) record-meta? + [predicate record-meta-predicate + constructor record-meta-constructor + ordered-accessors record-meta-ordered-accessors]) + + +(defn predicate->sum-type-meta [predicate] (:meta (meta predicate))) +(defn sum-type-predicate? [type-or-predicate] + (sum-type-meta? (predicate->sum-type-meta type-or-predicate))) + +;; a clause is one of the following: +;; - ClauseWithPredicate, describing a matching clause based on a prediate +;; - DefaultClause, describing a matching clause based on the special form :default +;; - ClauseWithExtraction, describing a matching clause based on a constructor + +(record/define-record-type CaluseWithPredicate + (make-clause-with-predicate predicate clause) clause-with-predicate? + [predicate clause-with-predicate-predicate + clause clause-with-predicate-clause]) + +(record/define-record-type DefaultClause + (make-default-clause clause) default-clause? + [clause default-clause-clause]) + +(record/define-record-type ClauseWithExtraction + (make-clause-with-extraction + predicate constructor-call clause ordered-accessors) clause-with-extraction? + [predicate clause-with-extraction-predicate + constructor-call clause-with-extraction-constructor-call + clause clause-with-extraction-clause + ordered-accessors clause-with-extraction-ordered-accessors]) + + + +(defn- has-predicate? [predicate type-tree] + (cond + (sum-type-meta? type-tree) + (or + (= predicate (sum-type-meta-predicate type-tree)) + (some identity (map #(has-predicate? predicate %) (sum-type-meta-types type-tree)))) + + (record-meta? type-tree) + (= (record-meta-predicate type-tree) predicate))) + +(defn collect-leafs [type-tree] + ;; collects all leafs from a type-tree, that is, all record-types + (cond + (record-meta? type-tree) [type-tree] + (sum-type-meta? type-tree) (mapcat collect-leafs (sum-type-meta-types type-tree)))) + + +(defn remove-predicate [type-tree predicate] + ;; removes all types having the given predicate as their predicate from the type-tree + (cond + (record-meta? type-tree) + (when-not (= predicate (record-meta-predicate type-tree)) + type-tree) + + (sum-type-meta? type-tree) + (when-not (= predicate (sum-type-meta-predicate type-tree)) + (make-sum-type-meta + (sum-type-meta-predicate type-tree) + (remove nil? (map #(remove-predicate % predicate) + (sum-type-meta-types type-tree))))))) + + +(defn check-predicates! [type-tree predicates has-default] + + ;; checks if all predicates fullfil the following requirements: + ;; + ;; 1) A predicate is corresponds to one of the types in the tree. The only exception is :default + ;; 2) A predicate is reachable. A predicate is defined unreachable if another + ;; predicate matches its value before it is executed. + ;; 3) A matching is exhaustive. That is, that every input-value has at least one + ;; predicate in the type tree that evaluates to true. If default is given this check is + ;; skipped. + ;; + ;; To fullfil the last two requirements types are removed iteratively by their predicate, in + ;; `predicates` given order. They idea is that if a sum-type node is removed, all its children + ;; are removed from the tree, too. + ;; + ;; Before each deletion it is checked if the predicate is present in the tree (requirement 2) + ;; After that, the corresponding type is removed. If leafs exist after all predicates are handled + ;; the matching is not exhaustive. + ;; + ;; Throws an IllegalArgumentException if one the requirements fails. + + (let [missing-predicates + (-> (reduce (fn [working-type-tree predicate] + + ;; requirement 1 + (when-not (or (= :default predicate)(has-predicate? predicate type-tree)) + (throw (IllegalArgumentException. (str "The following predicate is not of any type in " + (sum-type-meta-predicate type-tree) + ": " predicate)))) + + ;; requirement 2 + ;; TODO: Make this a warning? How does a warning look like? + (when-not (or (= :default predicate)(has-predicate? predicate working-type-tree)) + (throw (IllegalArgumentException. (str "The following predicate will never be reached: " + predicate)))) + + (remove-predicate working-type-tree predicate)) + type-tree predicates) + (collect-leafs))] + + ;; requirement 3 + (when-not (or has-default (empty? missing-predicates)) + (throw (Exception. (str "Non exhaustive match would fail on the following type(s): " + (mapv record-meta-predicate missing-predicates))))))) + + +(defn- ->record-meta [untyped] + (make-record-meta + (:predicate untyped) + (:constructor untyped) + (:ordered-accessors untyped))) + + +(defmacro define-sum-type [type-name predicate predicates] + ;; TODO doc + (let [qualified-predicates (doall (map #(ns-resolve *ns* %) predicates)) + _ (intern *ns* predicate)] + (when-not (every? identity + (map #(or (record/record-type-predicate? %) (sum-type-predicate? %)) + qualified-predicates)) + (throw (IllegalArgumentException. (str "Predicates of active.clojure.record or active.clojure.sum-type " + "required, found: " (pr-str (map record/predicate->record-meta qualified-predicates)))))) + + `(do + (defn ~(vary-meta predicate + assoc :meta + (make-sum-type-meta + (resolve predicate) + (map #(cond + (record/record-type-predicate? %) + (->record-meta (record/predicate->record-meta %)) + + (sum-type-predicate? %) + (predicate->sum-type-meta %)) + qualified-predicates))) + [arg#] + (boolean (some identity (map (fn [pred#] (pred# arg#)) ~predicates)))) + ))) + + +(defn- find-meta-by-constructor [constructor type-tree] + ;; searches for meta in type-tree by constructor + ;; returns nil if not present, else a RecordMeta + (cond + (sum-type-meta? type-tree) + (some identity (map #(find-meta-by-constructor constructor %) + (sum-type-meta-types type-tree))) + + (record-meta? type-tree) + (when (= (record-meta-constructor type-tree) constructor) + type-tree))) + +(defn parse-clause-with-extraction [condition clause sum-type-meta] + ;; makes a clause-with-extraction for constructor based matching + ;; throws if the number of matching arguments differs from constructor definition + ;; throws if not every constructor argument is a symbol (`_` are ignored). + ;; throws if constructor is not found in sum-type + (if-let [meta (find-meta-by-constructor (first condition) sum-type-meta)] + (let [predicate (record-meta-predicate meta) + ordered-accessors (record-meta-ordered-accessors meta)] + + ;; check correct number of arguments in constructor + (when-not (= (count ordered-accessors) (count (rest condition))) + (throw (IllegalArgumentException. + (str (first condition) " requires " (count ordered-accessors) " arguments: " condition)))) + + ;; check if every constructor argument is a symbol + (when-not (every? symbol? (rest condition)) + (throw (IllegalArgumentException. + (str "Every argument in constructor matching must be a symbol: " condition)))) + + (make-clause-with-extraction predicate condition clause ordered-accessors)) + + ;; if constructor not found in meta + (throw (IllegalArgumentException. (str "Constructor not found in sum-type types: " (first condition)))))) + + +(defn- parse-clause-with-predicate [predicate clause sum-type-meta] + ;; makes a clause-with-predicate for predicate based matching + ;; throws if predicate is not present in sum-type + (cond + (= :default predicate) + (make-default-clause clause) + + (has-predicate? predicate sum-type-meta) + (make-clause-with-predicate predicate clause) + + :else + (throw (IllegalArgumentException. + (str "Predicate " predicate " not found in sum-type type: " (:predicate sum-type-meta)))))) + +(defn parse-clauses [paired-clauses sum-type-meta] + ;; parses all clauses-forms to one of the following: + ;; clause-with-extraction for constructor based matching + ;; clause-with-predicate for predicate based matching + ;; default for :default special form + (mapv + (fn [[condition clause]] + (if (list? condition) + ;; (list (ns-resolve *ns* (first condition)) (rest condition)) + (parse-clause-with-extraction (cons (ns-resolve *ns* (first condition)) (rest condition)) + clause sum-type-meta) + (parse-clause-with-predicate (if (= :default condition) + :default + (ns-resolve *ns* condition)) + clause sum-type-meta))) + + paired-clauses)) + + +(defn clause->predicate [clause] + ;; extracts the predicate for a given clause + (cond + (clause-with-extraction? clause) + (clause-with-extraction-predicate clause) + + (clause-with-predicate? clause) + (clause-with-predicate-predicate clause) + + (default-clause? clause) + :default)) + + + +(defn- expand-clause-with-extraction-forms [clause arg] + ;; creates forms for constructor based cond case + (let [predicate (clause-with-extraction-predicate clause) + constructor-args (rest (clause-with-extraction-constructor-call clause)) + implied-clause (clause-with-extraction-clause clause) + ordered-accessors (clause-with-extraction-ordered-accessors clause) + args-to-accessors (->> (map vector constructor-args ordered-accessors) + ;; filter _ + (filter (fn [[ binding _]] (not= '_ binding))) + (mapcat (fn [[binding accessor]] + [binding (list accessor arg)])))] + [(list predicate arg) (list 'let (vec args-to-accessors) implied-clause)])) + +(defn- expand-clause-with-predicate-forms [clause arg] + ;; creates forms for predicate cond case + (list + (list (clause-with-predicate-predicate clause) arg) + (clause-with-predicate-clause clause))) + +(defn- expand-default-clause [clause] + ;; Creates forms for default cond case + (list :default (default-clause-clause clause))) + +(defn expand-clause-forms [clause arg] + ;; creates form for clause depending on type + (cond + (clause-with-extraction? clause) + (expand-clause-with-extraction-forms clause arg) + + (clause-with-predicate? clause) + (expand-clause-with-predicate-forms clause arg) + + (default-clause? clause) + (expand-default-clause clause))) + + +(defn has-default? [parsed-clauses] + ;; checks if clauses contains default clause and if it is the last clause + ;; throws if position of default (last) is violated + ;; returns true if default is found iff in last position, false else + (if (some default-clause? parsed-clauses) + (do + (if-not (default-clause? (last parsed-clauses)) + (throw (IllegalArgumentException. "Default clause only allowed as last clause")) + true)) + false)) + +(defmacro match [sum-type-predicate arg & clauses] + (let [sum-type-meta (predicate->sum-type-meta (ns-resolve *ns* sum-type-predicate)) + paired-clauses (partition 2 clauses) + parsed-clauses (parse-clauses paired-clauses sum-type-meta) + default (has-default? parsed-clauses) + predicates (mapv clause->predicate parsed-clauses) + expanded-clauses-forms (mapcat #(expand-clause-forms % arg) parsed-clauses)] + + (do + (check-predicates! sum-type-meta predicates default) + `(do + (when-not (~sum-type-predicate ~arg) + (throw (IllegalArgumentException. (str "Matching argument not of type " ~sum-type-predicate ": " ~arg)))) + ~(cons 'cond expanded-clauses-forms))))) diff --git a/test/active/clojure/sum_type_data_test.clj b/test/active/clojure/sum_type_data_test.clj new file mode 100644 index 0000000..df89b1b --- /dev/null +++ b/test/active/clojure/sum_type_data_test.clj @@ -0,0 +1,19 @@ +(ns active.clojure.sum-type-data-test + (:require [active.clojure.record :refer (define-record-type)] + [active.clojure.sum-type :refer (define-sum-type)])) + + +(define-record-type Circle + (make-circle radius) circle? + [radius circle-radius]) + +(define-record-type Square + (make-square height width) square? + [height square-height + width square-width]) + + +(define-sum-type Forms forms? [circle? square?]) + + + diff --git a/test/active/clojure/sum_type_test.clj b/test/active/clojure/sum_type_test.clj new file mode 100644 index 0000000..797525d --- /dev/null +++ b/test/active/clojure/sum_type_test.clj @@ -0,0 +1,158 @@ +(ns active.clojure.sum-type-test + (:require [active.clojure.record :refer (define-record-type)] + [active.clojure.sum-type :refer (define-sum-type match)] + [clojure.test :refer :all] + [active.clojure.sum-type-data-test :as data])) + +(define-record-type Red + (make-red saturation) red? + [saturation red-saturation]) + +(define-record-type Green + (make-green saturation) green? + [saturation green-saturation]) + +(define-record-type Blue + (make-blue saturation) blue? + [saturation blue-saturation]) + + +(def red (make-red 0.4)) +(def green (make-green 1.0)) +(def blue (make-blue 0.2)) + +(define-record-type UltraViolet + (make-ultra-violet wave-length) ultra-violet? + [wave-length ultra-violet-wave-length]) + +(define-record-type InfraRed + (make-infra-red wave-length) infra-red? + [wave-length infra-red-wave-length]) + +(def ultra-violet (make-ultra-violet 42)) +(def infra-red (make-infra-red 43)) + +(define-sum-type RGBColor + rgb-color? + [red? green? blue?]) + +(define-sum-type Invisible + invisible? + [ultra-violet? infra-red?]) + +(define-sum-type All + all? + [rgb-color? invisible?]) + +(deftest sum-type-predicate + (is (rgb-color? red)) + (is (rgb-color? green)) + (is (rgb-color? blue)) + + (is (not (invisible? red))) + (is (not (invisible? green))) + (is (not (invisible? blue))) + + (is (invisible? infra-red)) + (is (invisible? ultra-violet)) + + (is (not (rgb-color? infra-red))) + (is (not (rgb-color? ultra-violet))) + + (is (all? red)) + (is (all? green)) + (is (all? blue)) + (is (all? infra-red)) + (is (all? ultra-violet))) + +(defn match-rgb [rgb-color] + (match rgb-color? rgb-color + red? (red-saturation rgb-color) + green? (green-saturation rgb-color) + blue? (blue-saturation rgb-color))) + +(deftest working-matching + (is (= (match-rgb red) 0.4)) + (is (= (match-rgb green) 1.0)) + (is (= (match-rgb blue) 0.2))) + + +(deftest working-matching-default + (letfn [(foo [rgb-color] + (match rgb-color? rgb-color + red? "Hello" + :default "Bye!"))] + (is (= (foo red) "Hello")) + (is (= (foo blue) "Bye!")))) + +(deftest wrong-argument-type + (is (thrown? Throwable (match-rgb 2)))) + + +(deftest combined-sum-type + + (letfn [(foo [color] + (match all? color + red? "Visible" + blue? "Visible" + green? "Visible" + infra-red? "Invisible" + ultra-violet? "Invisible"))] + (is (= (foo red) "Visible")) + (is (= (foo blue) "Visible")) + (is (= (foo infra-red) "Invisible"))) + + (letfn [(foo [color] + (match all? color + red? "Red" + blue? "Blue" + green? "Green" + invisible? "Opaque"))] + (is (= (foo red) "Red")) + (is (= (foo blue) "Blue")) + (is (= (foo infra-red) "Opaque")))) + + +(deftest extractor-tests + (letfn [(desaturate [color] + (match all? color + (make-red s) (make-red (/ s 2)) + (make-green s) (make-green (/ s 2)) + (make-blue s) (make-blue (/ s 2)) + invisible? color))] + (is (= (make-red 6) (desaturate (make-red 12)))) + (is (= (make-green 6)(desaturate (make-green 12)))) + (is (= (make-blue 6) (desaturate (make-blue 12)))) + (is (= (make-ultra-violet 123) (desaturate (make-ultra-violet 123)))))) + + +(deftest nested-extractor + (letfn [(crazy [color] + (match all? color + + (make-ultra-violet containing-color) + (match rgb-color? containing-color + (make-red a) (str "You got red with " a) + (make-blue b) (str "You got blue with " b) + (make-green _) (str "Oh, this is green!")) + + all? + "It wasn rgb in invisible disguise :("))] + + (is (= "You got red with green" (crazy (make-ultra-violet (make-red "green"))))) + (is (= "You got blue with green" (crazy (make-ultra-violet (make-blue "green"))))) + (is (= "Oh, this is green!" (crazy (make-ultra-violet (make-green "green"))))) + (is (= "It wasn rgb in invisible disguise :(" (crazy (make-infra-red 123)))))) + +(define-sum-type FormsAndColors forms&colors? [data/circle? data/square? rgb-color?]) + +(deftest from-other-ns + (letfn [(form-or-color [foc] + (match forms&colors? foc + data/circle? "It's a circle!" + (data/make-square a b) (str "It's a square with " a " and " b) + rgb-color? "It's a color!"))] + + (is (= "It's a circle!" (form-or-color (data/make-circle 12)))) + (is (= "It's a square with 12 and 42" (form-or-color (data/make-square 12 42)))) + (is (= "It's a color!" (form-or-color (make-red 42))))))