diff --git a/README.md b/README.md index 280a115..fae8348 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![build](https://github.com/bgavran/TypeSafe_Tensors/actions/workflows/build.yml/badge.svg)](https://github.com/bgavran/TypeSafe_Tensors/actions/workflows/build.yml) -> TLDR; numpy, but with types, first-class axes, and tensors over structured data +> TLDR; NumPy reimagined with dependent types: native support for trees, braching and interaction TensorType is a framework for pure functional tensor processing, implemented in Idris 2. It * is **type-safe**: tensor shapes, indexing and contractions are checked at compile time @@ -35,13 +35,30 @@ import Data.Tensor Now you can construct tensors directly: ```idris -t0 : Tensor ["j" ~~> 3, "k" ~~> 4] Double -t0 = ># [ [0, 1, 2, 3] - , [4, 5, 6, 7] - , [8, 9, 10, 11]] +myVector : Tensor ["i" ~~> 5] Double +myVector = ># [1,2,3,4,5] ``` -This declares the type of a `3 x 4` matrix with axes named "j" and "k", and uses `>#` to populate it with values. `>#` behaves like a constructor: it takes a concrete value and turns it into the tensor of the appropriate shape (it should be visually read as a 'map' (`>`) into a 'tensor' (`#`)). +This is a vector of size `5`, with its axis named "i". + +```idris +myMatrix : Tensor ["j" ~~> 3, "k" ~~> 4] Double +myMatrix = ># [ [0, 1, 2, 3] + , [4, 5, 6, 7] + , [8, 9, 10, 11]] +``` + +This is a matrix with dimensions `3 x 4` with its axes named "j" and "k". In both examples them `>#` is used to populate the tensor with concrete values. + +If you load these up in the REPL (`pack repl examples/BasicExamples.idr`), you can print them: +```idris +BasicExamples> :exec printLn myVector +[1.0 2.0 3.0 4.0 5.0] +BasicExamples> :exec printLn myMatrix +[[ 0.0 1.0 2.0 3.0] + [ 4.0 5.0 6.0 7.0] + [ 8.0 9.0 10.0 11.0]] +``` You can use functions analogous to NumPy's, such as `np.arange` and `np.reshape`: @@ -53,7 +70,16 @@ t2 : Tensor ["i" ~~> 2, "j" ~~> 3] Double t2 = reshape t1 ``` -where the difference from NumPy is that these operations are typechecked - meaning they will fail _at compile-time_ if you supply an array with the wrong shape. +These do what you might expect if you're familiar with NumPy: +```idris +BasicExamples> :exec printLn t1 +[0.0 1.0 2.0 3.0 4.0 5.0] +BasicExamples> :exec printLn t2 +[[0.0 1.0 2.0] + [3.0 4.0 5.0]] +``` + +The difference from NumPy is that these operations are typechecked - meaning they will fail _at compile-time_ if you supply an array with the wrong shape. ```idris failing failConcrete : Tensor ["j" ~~> 3, "k" ~~> 4] Double @@ -67,7 +93,7 @@ failing They will also fail if you inconsistently bind axis names, for instance if you bind the same name to two different sizes: -``` +```idris failing failBinding : Tensor ["j" ~~> 3, "j" ~~> 4] Double failBinding = ># [ [0, 1, 2, 3] @@ -79,10 +105,10 @@ You can perform all sorts of familiar numeric operations: ```idris exampleSum : Tensor ["j" ~~> 3, "k" ~~> 4] Double -exampleSum = t0 + t0 +exampleSum = myMatrix + myMatrix exampleOp : Tensor ["j" ~~> 3, "k" ~~> 4] Double -exampleOp = abs (- (t0 * t0) <&> (+7)) +exampleOp = abs (- (myMatrix * myMatrix) <&> (+7)) ``` including standard linear algebra: @@ -92,38 +118,46 @@ dotExample : Tensor [] Double dotExample = dot t1 (t1 <&> (+5)) matMulExample : Tensor ["i" ~~> 2, "k" ~~> 4] Double -matMulExample = matMul t2 t0 +matMulExample = matMul t2 myMatrix transposeExample : Tensor ["k" ~~> 4, "j" ~~> 3] Double -transposeExample = transposeMatrix t0 +transposeExample = transposeMatrix myMatrix +``` + +```idris +BasicExamples> :exec printLn dotExample +130.0 +BasicExamples> :exec printLn matMulExample +[[20.0 23.0 26.0 29.0] + [56.0 68.0 80.0 92.0]] ``` -which all have their types checked at compile-time. For instance, you can't add tensors of different shapes, perform matrix multiplication if the dimensions of matrices don't match, or do any of these if you mislabel an axis. +All of these types are checked before you run your program. For instance, you can't add tensors of different shapes, perform matrix multiplication if the dimensions of matrices don't match, or do any of these if you mislabel an axis. ```idris failing sumFail : Tensor ["j" ~~> 3, "k" ~~> 4] Double - sumFail = t0 + t1 + sumFail = myMatrix + t1 failing matMulFail : Tensor ["i" ~~> 7] Double - matMulFail = matMul t0 t1 + matMulFail = matMul myMatrix t1 ``` Like in NumPy, you can safely index into tensors, set values of tensors, and perform slicing: ```idris -||| Retrieves the value of t0 at location [1, 2] +||| Retrieves the value of `myMatrix` at location [1, 2] indexExample : Double -indexExample = t0 @@ [1, 2] +indexExample = myMatrix @@ [1, 2] -||| Sets the value of t0 at location [1, 3] to 99 +||| Sets the value of `myMatrix` at location [1, 3] to 99 setExample : Tensor ["j" ~~> 3, "k" ~~> 4] Double -setExample = set t0 [1, 3] 99 +setExample = set myMatrix [1, 3] 99 ||| Takes the first two rows, and 1st column of t0 sliceExample : Tensor ["j" ~~> 2, "k" ~~> 1] Double -sliceExample = take [2, 1] t0 +sliceExample = take [2, 1] myMatrix ``` which will all fail if you go out of bounds: @@ -134,17 +168,17 @@ failing failing sliceFail : Tensor ["j" ~~> 10, "k" ~~> 2] Double - sliceFail = take [10, 2] t0 + sliceFail = take [10, 2] myMatrix ``` -**All of the above also works with non-rectangular tensors.** These are tensors whose shape is a tree, an inductive type, and even a continuation, rather than a rectangular grid. -That is, instead of binding an axis name to a number, we bind it to something called a "container", by using `~>` instead of `~~>`. -As a matter of fact, `~~>` behind the scenes desugars to `~>`, and we have been using this all along. -Let's see `t0` in this new form: +**All of the above also works with non-rectangular tensors.** These are tensors whose shape is not a list of numbers, but a list of *containers*. A container can model a tree, an inductive type, a continuation, and all sorts of different things. +As a matter of fact, our previous syntax for defining axes (`~~>`) was behind the scenes desugaring to a container-based one (which is `~>`). +We have been using it all along. +Let's see `myMatrix` in this new form: ```idris -t0Again : Tensor ["j" ~> Vect 3, "k" ~> Vect 4] Double -t0Again = t0 +myMatrixAgain : Tensor ["j" ~> Vect 3, "k" ~> Vect 4] Double +myMatrixAgain = myMatrix ``` Here `Vect` does not refer to `Vect` from `Data.Vect`, but rather the `Vect` container implemented [here](https://github.com/bgavran/TensorType/blob/main/src/Data/Container/Base/Object/Instances.idr#L68). @@ -159,30 +193,40 @@ t1again = arange The real power of container tensors comes from using other containers in the place of `Vect`. Here is a container `BinTree` of binary trees recast as a tree-tensor: ```idris -{- - 60 - / \ - 7 2 - / \ -(-42) 46 --} treeExample1 : Tensor ["myTree" ~> BinTree] Double treeExample1 = ># Node 60 (Node 7 (Leaf (-42)) (Leaf 46)) (Leaf 2) ``` -Unlike `Vect`, this container has a branching shape rather than a linear one. +We can print it out in the same way, and observe its branching shape: +```idris +BasicExamples> :exec printLn treeExample1 +60.0 +│ +├─ 7.0 +│ │ +│ ├─ -42.0 +│ │ +│ └─ 46.0 +│ +└─ 2.0 +``` + Here is another tree-tensor with a different shape: ```idris -{- - 5 - / \ -100 4 --} treeExample2 : Tensor ["myTree" ~> BinTree] Double treeExample2 = ># Node 5 (Leaf 100) (Leaf 4) ``` +```idris +BasicExamples> :exec printLn treeExample2 +5.0 +│ +├─ 100.0 +│ +└─ 4.0 +``` + Perhaps surprisingly, all linear algebra operations follow smoothly. The example below is the _dot product of trees_. The fact that these trees don't have the same number of elements is irrelevant; what matters is that the container defining them (`BinTree`) is the same. ```idris @@ -190,35 +234,52 @@ dotProductTree : Tensor [] Double dotProductTree = dot treeExample1 treeExample2 ``` +```idris +BasicExamples> :exec printLn dotProductTree +1408.0 +``` + We can do much more. Here's a tree-tensor with values only on its leaves: ```idris -{- - * - / \ - * 2 - / \ -(-42) 46 --} treeLeafExample : Tensor ["myTree" ~> BinTreeLeaf] Double treeLeafExample = ># Node' (Node' (Leaf (-42)) (Leaf 46)) (Leaf 2) ``` +```idris +BasicExamples> :exec printLn treeLeafExample +· +│ +├─ · +│ │ +│ ├─ -42.0 +│ │ +│ └─ 46.0 +│ +└─ 2.0 +``` + and here's a tree-tensor with values only on its nodes: ```idris -{- - 60 - / \ - 7 * - / \ - * * --} treeNodeExample : Tensor ["myTree" ~> BinTreeNode] Double treeNodeExample = ># Node 60 (Node 7 Leaf' Leaf') Leaf' ``` -This can get complex and nested, as `treeExample3` and `treeExample4` show. But it is still fully type-checked and works as you'd expect. +```idris +BasicExamples> :exec printLn treeNodeExample +60.0 +│ +├─ 7.0 +│ │ +│ ├─ · +│ │ +│ └─ · +│ +└─ · +``` + +This can get complex and nested, as `treeExample3` and `treeExample4` show. ```idris treeExample3 : Tensor ["myTree" ~> BinTreeNode, "j" ~> Vect 2] Double @@ -235,15 +296,41 @@ treeExample4 = ># (Node (Leaf [178, -43, 63]) Leaf' Leaf') ``` + +```idris +BasicExamples> :exec printLn treeExample4 +╔════════════════╗ +║· ║ +║│ ║ +║├─ [1.0 2.0 3.0]║ +║│ ║ +║└─ [4.0 5.0 6.0]║ +╚════════════════╝ +│ +├─ · +│ +└─ ╔═══════════════════╗ + ║[178.0 -43.0 63.0]║ + ╚═══════════════════╝ + │ + ├─ · + │ + └─ · +``` + +But all of this is still fully type-checked and works as you'd expect. For instance, we can index into `treeExample1`: ```idris {- -We can index into any of these structures - 60 - / \ - 7 2 <---- indexing here is okay - / \ -(-42) 46 +60.0 +│ +├─ 7.0 +│ │ +│ ├─ -42.0 +│ │ +│ └─ 46.0 +│ +└─ 2.0 <---- indexing here is okay -} indexTreeExample1 : Double indexTreeExample1 = treeExample1 @@ [GoRight AtLeaf] @@ -254,11 +341,17 @@ This will fail _at compile-time_ if you try to index outside of the tree structu ```idris failing {- - 60 - / \ - 7 2 - / \ \ - (-42) 46 X <---- indexing here throws an error + 60.0 + │ + ├─ 7.0 + │ │ + │ ├─ -42.0 + │ │ + │ └─ 46.0 + │ + └─ 2.0 + │ + └─ X <---- indexing here throws an error -} indexTreeExample1Fail : Double indexTreeExample1Fail = treeExample1 @@ [GoRight (GoRight AtLeaf)] @@ -268,17 +361,15 @@ Likewise, you can perform reshapes, views, reversals, sorting and traversals of Here is the in-order traversal of `treeExample1` from above. ```idris -{- - 60 - / \ - 7 2 - / \ -(-42) 46 --} traversalExample : Tensor ["myList" ~> List] Double traversalExample = restructure (wrapIntoVector inorder) treeExample1 ``` +```idris +BasicExamples> :exec printLn traversalExample +[-42.0, 7.0, 46.0, 60.0, 2.0] +``` + All of these can be used to define novel network architectures, see [src/Architectures](https://github.com/bgavran/TensorType/tree/main/src/NN/Architectures) for examples. ## Installation instructions @@ -291,14 +382,13 @@ It is recommended to manage the installation of this package (and generally, Idr 3. That's it! **To use TensorType in your project:** -1. Run `pack query tensortype` in the command-line to check whether your pack database is synced. If you don't see `tensortype` printed as output, you may need to run `pack update-db` first. -2. Add `tensortype` to the `depends` argument in your project's `.ipkg` file. (See `examples/tensortype-examples.ipkg` for an example) -3. Include `import Data.Tensor` at the top of your source files. -4. That's it! +1. Add `tensortype` to the `depends` argument in your project's `.ipkg` file. (See `examples/tensortype-examples.ipkg` for an example) +2. Include `import Data.Tensor` at the top of your source files. +3. That's it! ## Aim of TensorType -Attempts to bring deep learning to statically typed languages have struggled with expressiveness and ergonomics, typically only replicating what exists without imagining what can be. TensorType aims to do both: provide a practical, type-safe tensor library, and enable fundamentally new capabilities. Specifically: +Attempts to bring deep learning to statically typed languages have struggled with expressiveness and ergonomics, typically only replicating what exists without imagining what can be. TensorType aims to do both: provide a practical, type-safe tensor library, and enable fundamentally new capabilities. Specifically, the aim is to: > Enable type-driven development of structured neural network architectures. @@ -314,7 +404,7 @@ This especially holds for non-rectangular tensors, which are at the moment only TensorType's implementation hinges on three interdependent components: -* **Containers** for **well-typed indexing of non-cubical tensors**: they allow us to validate that an index into a generalised tensor is not out of bounds at compile-time. Doing this with cubical containers is easy since they expose the size information at the type level (i.e. `Tensor ["i" ~> Vect 2] Double`), but once we move on to the world of applicative functors this is no longer the case. Checking that an index into a `Tensor ["b" ~> BinTreeNode] Double` is not out of bounds is only possible if the underlying functor additionally comes equipped with the data of the valid set of "shapes" and the valid "positions" for that shape. This is equivalent to asking that the functor is polynomial, or that the functor is an extension of a container. +* **Containers** for **well-typed indexing of non-cubical tensors**: they allow us to validate that an index into a generalised tensor is not out of bounds at compile-time. Doing this with cubical containers is easy since they expose the size information at the type level (i.e. `Tensor ["i" ~> Vect 2] Double`), but once we move on to the world of applicative functors this is no longer the case. Checking that an index into a `Tensor ["b" ~> BinTreeNode] Double` is not out of bounds is only possible if the underlying functor additionally comes equipped with the data of the valid set of "shapes" and the valid "positions" for that shape. This is equivalent to asking that the functor is polynomial, or equivalently that the functor is the extension of some container. * **Applicative functors** for **generalised linear algebra**: they allow us to perform generalised linear algebra operations as described in the [Applicative Programming with Naperian Functors](https://www.cs.ox.ac.uk/people/jeremy.gibbons/publications/aplicative.pdf) paper. * **Dependent lenses** for **reshaping and traversing operations**: they allow us to define morphisms of containers, and therefore generalised tensor reshaping operations that do not operate on the content of the data, only the shape. These include views, reshapes, and traversals, and many other operations that appear in libraries like NumPy. diff --git a/examples/Attention.idr b/examples/Attention.idr index 0d39215..f352ed4 100644 --- a/examples/Attention.idr +++ b/examples/Attention.idr @@ -9,50 +9,60 @@ Attention example Will run self attention as usual, on matrices, and then on trees -------------------------------------------------------------------------------} +||| We start by dealing with ordinary matrices, and define the matrix axes + +||| The length of the input sequence +SeqLen : Axis +SeqLen = "seqLen" ~~> 3 + +||| The number of tokens for each element in the sequence +NumTokens : Axis +NumTokens = "numTokens" ~~> 4 + ||| We'll first instantiate self attention as a parametric map on matrices -SelfAttentionMat : {n, d : Nat} -> - {default False causalMask : Bool} -> - Tensor ["seqLen" ~~> n, "numTokens" ~~> d] Double -\-> - Tensor ["seqLen" ~~> n, "numTokens" ~~> d] Double +SelfAttentionMat : {default False causalMask : Bool} -> + Tensor [SeqLen, NumTokens] Double -\-> + Tensor [SeqLen, NumTokens] Double SelfAttentionMat {causalMask} = case causalMask of False => SelfAttention softargmaxImpl True => SelfAttention {causalMask=Attention.causalMask} softargmaxImpl ||| Let's fix a simple input matrix -inputMatrix : Tensor ["seqLen" ~~> 3, "numTokens" ~~> 2] Double -inputMatrix = ># [ [1, 3] - , [2, -3] - , [0, 0.3]] +inputMatrix : Tensor [SeqLen, NumTokens] Double +inputMatrix = ># [ [1, 3, 3, 2] + , [2, -3, 2, 1] + , [0, 0.3, 10, 9]] ||| Let's fix attention parameters for the query, key and value matrices. ||| For instance, a matrix of ones, a triangular matrix, and a matrix of threes -params : {d : Nat} -> SelfAttentionParams ("numTokens" ~~> d) {a=Double} +params : SelfAttentionParams NumTokens {a=Double} params = MkSAParams ones tri (ones <&> (*3)) ||| Now we can run self attention on the input matrix -||| This value can be inspected in REPL, or otherwise -outputMatrix : Tensor ["seqLen" ~~> 3, "numTokens" ~~> 2] Double +||| This value can be inspected in REPL via `:exec printLn outputMatrix` +outputMatrix : Tensor [SeqLen, NumTokens] Double outputMatrix = Run (SelfAttentionMat {causalMask=True}) inputMatrix params -||| Now we'll instantiate self attention as a parametric map on trees and use -||| container tensors for this. Here we'll study attention where the input -||| structure isn't a sequence, but a tree, but we'll keep the feature structure -||| as a sequence -||| That is, instead of `CTensor [Vect n, Vect d] Double` -||| we'll have `CTensor [BinTreeLeaf, Vect d] Double` -SelfAttentionTree : {d : Nat} -> - Tensor ["inputStructure" ~> BinTreeLeaf, "numTokens" ~> Vect d] Double -\-> - Tensor ["inputStructure" ~> BinTreeLeaf, "numTokens" ~> Vect d] Double +||| Now we instantiate self attention not operating on a vector of vectors +||| (i.e. a matrix), but a *tree* of vectors. We'll first define the axis +InputStructure : Axis +InputStructure = "inputStructure" ~> BinTreeLeaf + +||| We keep features as a vector, and define the self-attention map +||| Notably, we do not use any causal mask +SelfAttentionTree : Tensor [InputStructure, NumTokens] Double -\-> + Tensor [InputStructure, NumTokens] Double SelfAttentionTree = SelfAttention softargmaxImpl ||| We fix a simple input tree ||| Notably, the set of parameters can be the same as the one for matrices -inputTree : Tensor ["inputStructure" ~> BinTreeLeaf, "numTokens" ~> Vect 2] Double -inputTree = ># Node' (Node' (Leaf [1, -1]) - (Leaf [0.5, 1.2])) - (Leaf [-0.3, 1.2]) +inputTree : Tensor [InputStructure, NumTokens] Double +inputTree = ># Node' (Node' (Leaf [1, -1, 3, 2]) + (Node' (Leaf [0.5, 1.2, 2, 1]) + (Leaf [0.3, 4, 10, 9]))) + (Leaf [-0.3, 1.2, -13, -0.3]) ||| We can run self attention on the tree, and inspect the result -outputTree : Tensor ["inputStructure" ~> BinTreeLeaf, "numTokens" ~> Vect 2] Double +outputTree : Tensor [InputStructure, NumTokens] Double outputTree = Run SelfAttentionTree inputTree params \ No newline at end of file diff --git a/examples/BasicExamples.idr b/examples/BasicExamples.idr index 1072599..c2fb0a4 100644 --- a/examples/BasicExamples.idr +++ b/examples/BasicExamples.idr @@ -8,23 +8,28 @@ import Data.Tensor -- Examples of standard, cubical tensors ---------------------------------------- -||| Now you can construct Tensors directly: -t0 : Tensor ["j" ~~> 3, "k" ~~> 4] Double -t0 = ># [ [0, 1, 2, 3] - , [4, 5, 6, 7] - , [8, 9, 10, 11]] +||| This is a vector of size `5`, with its axis named "i". +myVector : Tensor ["i" ~~> 5] Double +myVector = ># [1,2,3,4,5] + +||| This is a matrix of size `3x4`, with its axes named "j" and "k". +myMatrix : Tensor ["j" ~~> 3, "k" ~~> 4] Double +myMatrix = ># [ [0, 1, 2, 3] + , [4, 5, 6, 7] + , [8, 9, 10, 11]] {-------------------- Here `>#` behaves like a constructor: it takes a concrete value and turns it into the tensor of the appropriate shape (It should be visually read as a 'map' (`>`) into 'tensor' (`#`)). You can also use functions analogous to numpy's, such as `np.arange` and `np.reshape`: --------------------} -t1 : Tensor ["i" ~~> 6] Double +t1 : Tensor ["l" ~~> 6] Double t1 = arange t2 : Tensor ["i" ~~> 2, "j" ~~> 3] Double t2 = reshape t1 + {- where the difference between numpy is that these operations are typechecked - meaning they fail at compile-time if you supply an array with the wrong shape. @@ -49,20 +54,20 @@ failing ||| You can perform all sorts of familiar numeric operations: exampleSum : Tensor ["j" ~~> 3, "k" ~~> 4] Double -exampleSum = t0 + t0 +exampleSum = myMatrix + myMatrix exampleOp : Tensor ["j" ~~> 3, "k" ~~> 4] Double -exampleOp = abs (- (t0 * t0) <&> (+7)) +exampleOp = abs (- (myMatrix * myMatrix) <&> (+7)) ||| including standard linear algebra dotExample : Tensor [] Double dotExample = dot t1 (t1 <&> (+5)) matMulExample : Tensor ["i" ~~> 2, "k" ~~> 4] Double -matMulExample = matMul t2 t0 +matMulExample = matMul t2 myMatrix transposeExample : Tensor ["k" ~~> 4, "j" ~~> 3] Double -transposeExample = transposeMatrix t0 +transposeExample = transposeMatrix myMatrix {-------------------- which all have their types checked at compile-time. For instance, you can't @@ -71,25 +76,25 @@ dimensions of matrices don't match. --------------------} failing sumFail : Tensor ["j" ~~> 3, "k" ~~> 4] Double - sumFail = t0 + t1 + sumFail = myMatrix + t1 failing matMulFail : Tensor ["i" ~~> 7] Double - matMulFail = matMul t0 t1 + matMulFail = matMul myMatrix t1 ||| Like in numpy, you can safely index into tensors, set values of tensors, and perform slicing: ||| This retrieves the value of t- at location [1,2] indexExample : Double -indexExample = t0 @@ [1, 2] +indexExample = myMatrix @@ [1, 2] -- TODO needs to be fixed --- ||| Sets the value of t0 at location [1, 3] to 99 +-- ||| Sets the value of `myMatrix` at location [1, 3] to 99 -- setExample : Tensor [3, 4] --- setExample = set t0 [1, 3] 99 +-- setExample = set `myMatrix` [1, 3] 99 -- ||| Takes the first two rows, and 1st column of t0 -- sliceExample : Tensor ["j" ~~> 2, "k" ~~> 1] Double --- sliceExample = take [2, 1] t0 +-- sliceExample = take [2, 1] `myMarix` -- Which will all fail if you go out of bounds failing @@ -98,16 +103,16 @@ failing failing sliceFail : Tensor ["j" ~~> 10, "k" ~~> 2] Double - sliceFail = take [10, 2] t0 + sliceFail = take [10, 2] myMatrix {--------------------------------------- **And most importantly, you can do all of this with *non-cubical* tensors.** These describe tensors whose shape isn't rectangular/cubical, but can be branching/recursive/higher-order. That is, instead of binding an axis name to a number, we bind it to something called a "container", by using `~>` instead of `~~>`. As a matter of fact, `~~>` behind the scenes desugars to `~>`, and we have been using this all along. -Let's see `t0` in this new form: +Let's see `myMatrix` in this new form: ---------------------------------------} -t0Again : Tensor ["j" ~> Vect 3, "k" ~> Vect 4] Double -t0Again = t0 +myMatrixAgain : Tensor ["j" ~> Vect 3, "k" ~> Vect 4] Double +myMatrixAgain = myMatrix {-------------------- Here `Vect` does not refer to `Vect` from `Data.Vect`, but rather the `Vect` container implemented [here](https://github.com/bgavran/TensorType/blob/main/src/Data/Container/Object/Instances.idr#L68). @@ -139,6 +144,13 @@ Unlike `Vect`, this container allows us to store an arbitrary number of elements treeExample2 : Tensor ["myTree" ~> BinTree] Double treeExample2 = ># Node 5 (Leaf 100) (Leaf 4) + +listEx : List' (Tensor ["myTree" ~> BinTree] Double) +listEx = ># [ treeExample1, treeExample2 ] + +vectEx : Vect' 2 (Vect' 3 Double) +vectEx = ># [ ># [1, 2, 3], ># [4, 5, 6.000290940000203] ] + {-------------------- Perhaps surpisingly, all linear algebra operations follow smoothly. The example below is the _dot product of trees_. The fact that these trees don't have the same number of elements is irrelevant; what matters is that the container defining them (`BinTree`) is the same. --------------------} @@ -183,6 +195,13 @@ treeExample4 = ># Leaf' (Node (Leaf [178, -43, 63]) Leaf' Leaf') + +treeExample5 : Tensor ["myTree" ~> BinTreeLeaf, "v" ~~> 2] Double +treeExample5 = ># Node' (Node' (Leaf [1, -1]) + (Node' (Leaf [0.5, 1.2]) + (Leaf [0.3, -0.2]))) + (Leaf [-0.3, 1.2]) + {-------------------- For instance, we can index into `treeExample1`: 60 @@ -218,5 +237,5 @@ Here is the in-order traversal of `treeExample1` from above. Can also use Utils.Traversals.inorder --------------------} -traversalExample : Tensor ["myTree" ~> List] Double +traversalExample : Tensor ["flattenedTree" ~> List] Double traversalExample = restructure (wrapIntoVector inorder) treeExample1 \ No newline at end of file diff --git a/src/Control/Monad/Sample/Instances.idr b/src/Control/Monad/Sample/Instances.idr index f3fc4fc..f2a73b1 100644 --- a/src/Control/Monad/Sample/Instances.idr +++ b/src/Control/Monad/Sample/Instances.idr @@ -27,9 +27,11 @@ public export ||| Computes the cumulative distribution, samples randomly, finds the right bin public export MonadSample IO where - sample @{ItIsSucc} (MkDist xs) = do - let dist : Tensor ["dist" ~~> i] Double := (softargmaxImpl {i="dist" ~~> i}) (># xs) - cumSum : Tensor ["dist" ~~> i] Double := cumulativeSum dist + sample {i = S j} (MkDist xs) = do + let dist : Tensor ["dist" ~~> S j] Double + dist = softargmaxImpl {i="dist" ~~> S j} (># xs) + cumSum : Tensor ["dist" ~~> S j] Double + cumSum = cumulativeSum dist r <- randomRIO (0.0, 1.0) case findBin (#> cumSum) r of Nothing => pure FZ -- should never happen! diff --git a/src/Data/CT/Category/Instances.idr b/src/Data/CT/Category/Instances.idr index be48f96..0ab9211 100644 --- a/src/Data/CT/Category/Instances.idr +++ b/src/Data/CT/Category/Instances.idr @@ -3,6 +3,7 @@ module Data.CT.Category.Instances import Data.CT.Category.Definition import Data.CT.Functor.Definition +import Data.ComMonoid import Data.Container.Base import Data.Container.Additive @@ -34,4 +35,9 @@ AddDLens = MkCat AddCont (=%>) ||| Category of additive dependent charts public export AddDChart : Cat -AddDChart = MkCat AddCont (=&>) \ No newline at end of file +AddDChart = MkCat AddCont (=&>) + +||| Category of commutative monoids and commutative monoid homomorphisms +public export +ComMon : Cat +ComMon = MkCat ComMonoid ComMonoidHomo diff --git a/src/Data/CT/DependentAction/Instances.idr b/src/Data/CT/DependentAction/Instances.idr index 0acc587..86a8ee3 100644 --- a/src/Data/CT/DependentAction/Instances.idr +++ b/src/Data/CT/DependentAction/Instances.idr @@ -34,7 +34,7 @@ namespace Cont public export DPairCont : DepAct DLens (FamDLens {c=DLens}) DPairCont = MkDepAct $ \c => MkFunctor - (DepHancockProduct c) + (DPairTensor c) (\r => !% \(x ** p) => ((x ** (r x).fwd p) ** \(x', p') => (x', (r x).bwd p p'))) @@ -49,6 +49,6 @@ namespace AddCont public export DPairAddCont : DepAct AddDLens (FamAddDLens {c=AddDLens}) DPairAddCont = MkDepAct $ \c => MkFunctor - (DepHancockProduct c) + (DPair c) (\r => !%+ \(x ** p) => ((x ** (r x).fwd p) ** \(x', p') => (x', (r x).bwd p p'))) \ No newline at end of file diff --git a/src/Data/CT/DependentPara/Instances.idr b/src/Data/CT/DependentPara/Instances.idr index a0ea2ab..1341191 100644 --- a/src/Data/CT/DependentPara/Instances.idr +++ b/src/Data/CT/DependentPara/Instances.idr @@ -194,7 +194,7 @@ namespace DependentParametricDependentLenses public export composePara : a =\\=> b -> b =\\=> c -> a =\\=> c composePara (MkPara p f) (MkPara q g) = MkPara - (\x => DepHancockProduct (p x) (\ps => q (f.fwd (x ** ps)))) + (\x => DPair (p x) (\ps => q (f.fwd (x ** ps)))) (!%+ \(x ** (ps ** qs)) => (g.fwd (f.fwd (x ** ps) ** qs) ** \cPos => let (bPos, qPos) = g.bwd (f.fwd (x ** ps) ** qs) cPos @@ -211,12 +211,12 @@ namespace DependentParametricDependentLenses ||| a non-dependent (constant) parameter. public export data IsNotDependent : DParaAddDLens a b -> Type where - MkNonDep : (p : AddCont) -> (f : DepHancockProduct a (const p) =%> b) -> + MkNonDep : (p : AddCont) -> (f : DPair a (const p) =%> b) -> IsNotDependent {a=a} {b=b} (MkPara (\_ => p) f) public export GetNonDep : (pf : DParaAddDLens a b) -> - IsNotDependent pf => (pc : AddCont ** DepHancockProduct a (const pc) =%> b) + IsNotDependent pf => (pc : AddCont ** DPair a (const pc) =%> b) GetNonDep _ @{MkNonDep pc f} = (pc ** f) public export @@ -237,11 +237,11 @@ namespace DependentParametricDependentLenses composeNTimes 1 f = f -- to get rid of the annoying Unit parameter composeNTimes (S k) f = composePara f (composeNTimes k f) - ||| Convert a morphism from product container to one from DepHancockProduct - ||| This witnesses the isomorphism (a >< p) ≅ DepHancockProduct a (const p) + ||| Convert a morphism from product container to one from DPair + ||| This witnesses the isomorphism (a >< p) ≅ DPair a (const p) public export fromNonDepProduct : {0 a, p, b : AddCont} -> - (a >< p) =%> b -> DepHancockProduct a (const p) =%> b + (a >< p) =%> b -> DPair a (const p) =%> b fromNonDepProduct f = !%+ \(x ** p') => (%!) f (x, p') public export diff --git a/src/Data/CT/Functor/Instances.idr b/src/Data/CT/Functor/Instances.idr index 6fca2c6..b233445 100644 --- a/src/Data/CT/Functor/Instances.idr +++ b/src/Data/CT/Functor/Instances.idr @@ -87,6 +87,7 @@ namespace Type TypeDFun fam = (x : a) -> fam x namespace Cont + ||| TODO probably name clash with other "Indexed container" public export IndexedCont : Cont -> Type IndexedCont c = c.Shp -> Cont @@ -180,42 +181,4 @@ trivialFam = \_ => ((_ : ()) !> Void) -- Family that assigns Bool shapes with no positions public export boolFam : {c : Cont} -> IndexedCont c -boolFam = \_ => ((_ : Bool) !> Void) - --------------------------------------------------------------------------------- --- COMPARISON WITH CHARTS/LENSES --- --- A chart `c =&> d` is a MORPHISM in Poly, involving both shapes AND positions. --- A family `c.Shp -> Cont` is just a FUNCTION on shapes. --- --- Charts give richer structure (relating positions), but families are simpler --- and sufficient for forming dependent sums and products. --- --- You CAN extract a family from certain charts: --------------------------------------------------------------------------------- - --- The "universe" container: shapes are types, positions are their elements -public export -ContUniverse : Cont -ContUniverse = (t : Type) !> t - --- From a chart to Universe, extract just the shape-level data as a family --- (This loses the position-level information from the chart!) -public export -famFromChart : (c : Cont) -> (f : c =&> ContUniverse) -> IndexedCont c -famFromChart c f = \s => ((_ : f.fwd s) !> Void) - -- We get family shapes from f.fwd, but lose the f.bwd data - -- The chart's f.bwd : (s : c.Shp) -> c.Pos s -> f.fwd s - -- doesn't fit into ContFam's structure - --------------------------------------------------------------------------------- --- WHY Poly ISN'T LOCALLY CARTESIAN CLOSED --- --- In a locally cartesian closed category, for any morphism f : A -> B, --- the pullback functor f* : C/B -> C/A has a right adjoint Πf. --- --- In Poly, this fails for general dependent lenses. The right adjoint --- only exists for CARTESIAN morphisms (where the backward map is an iso). --- --- Reference: von Glehn thesis, Section 4.3 --------------------------------------------------------------------------------- +boolFam = \_ => ((_ : Bool) !> Void) \ No newline at end of file diff --git a/src/Data/ComMonoid.idr b/src/Data/ComMonoid.idr index c8552c2..2b8f6f0 100644 --- a/src/Data/ComMonoid.idr +++ b/src/Data/ComMonoid.idr @@ -51,10 +51,15 @@ namespace NotExposingType ComMonoid : Type ComMonoid = (t : Type ** ComMonoid t) + ||| Not encoding the rules for now public export - record ComMonoidHomo (c, d : ComMonoid) where - constructor MkComMonoidHomo - underlyingMap : c.fst -> d.fst - plusPreserve : (x, y : c.fst) -> - underlyingMap (c.snd.plus x y) = d.snd.plus (underlyingMap x) (underlyingMap y) - neutralPreserve : underlyingMap c.snd.neutral = d.snd.neutral + ComMonoidHomo : ComMonoid -> ComMonoid -> Type + ComMonoidHomo (t ** _) (t' ** _) = t -> t' + + -- public export + -- record ComMonoidHomo (c, d : ComMonoid) where + -- constructor MkComMonoidHomo + -- underlyingMap : c.fst -> d.fst + -- plusPreserve : (x, y : c.fst) -> + -- underlyingMap (c.snd.plus x y) = d.snd.plus (underlyingMap x) (underlyingMap y) + -- neutralPreserve : underlyingMap c.snd.neutral = d.snd.neutral diff --git a/src/Data/Container/Additive.idr b/src/Data/Container/Additive.idr index 2741044..a6b8ac0 100644 --- a/src/Data/Container/Additive.idr +++ b/src/Data/Container/Additive.idr @@ -11,6 +11,7 @@ import public Data.Container.Additive.Object.Definition import public Data.Container.Additive.Morphism.Definition import public Data.Container.Additive.Extension.Definition import public Data.Container.Additive.Product.Definitions +import public Data.Container.Additive.Properties.Definitions import public Data.Container.Additive.Object.Instances import public Data.Container.Additive.Morphism.Instances \ No newline at end of file diff --git a/src/Data/Container/Additive/Extension/Definition.idr b/src/Data/Container/Additive/Extension/Definition.idr index 52de334..0cdc19c 100644 --- a/src/Data/Container/Additive/Extension/Definition.idr +++ b/src/Data/Container/Additive/Extension/Definition.idr @@ -10,4 +10,7 @@ Ext c x = Ext (UC c) x ||| Can be represented as a derivative public export Path : AddCont -> Type -Path c = (x : c.Shp ** c.Pos x) \ No newline at end of file +Path c = (x : c.Shp ** c.Pos x) + +ghh : Ext UnitCont Double +ghh = () <| ?rrr diff --git a/src/Data/Container/Additive/Morphism/Definition.idr b/src/Data/Container/Additive/Morphism/Definition.idr index aa3e403..389827a 100644 --- a/src/Data/Container/Additive/Morphism/Definition.idr +++ b/src/Data/Container/Additive/Morphism/Definition.idr @@ -71,7 +71,7 @@ namespace DependentLenses lensInputs : {c, d : AddCont} -> c =%> d -> AddCont lensInputs lens = MkAddCont (lensInputs (ULens lens)) - {mon=(MkI @{\s => ?lensInputsMon_rhs})} + {mon=(MkI $ \s => ?lensInputsMon_rhs)} namespace DependentCharts diff --git a/src/Data/Container/Additive/Morphism/Instances.idr b/src/Data/Container/Additive/Morphism/Instances.idr index c5451a2..e347d04 100644 --- a/src/Data/Container/Additive/Morphism/Instances.idr +++ b/src/Data/Container/Additive/Morphism/Instances.idr @@ -11,6 +11,7 @@ import Data.Container.Additive.Object.Definition import Data.Container.Additive.Object.Instances import Data.Container.Additive.Morphism.Definition import Data.Container.Additive.Product.Definitions +import Data.Container.Additive.Properties.Definitions import Data.Container.Additive.Quantifiers @@ -19,7 +20,7 @@ import Control.Monad.Sample.Definition import Misc -%hide Data.Container.Base.Object.Definition.Const +%hide Data.Container.Base.Object.Instances.Const %hide Data.Vect.Quantifiers.All.index public export @@ -295,6 +296,6 @@ coAlgMorphism c d = c.carrier =%> d.carrier convert : FCoAlgCont List -> AddCont convert (MkFCoAlgCont carrier coalg) = MkAddCont carrier - {mon=(MkI @{\s => MkComMonoid + {mon=(MkI $ \s => MkComMonoid (\l, r => coalg s [l, r]) - (coalg s [])})} \ No newline at end of file + (coalg s []))} \ No newline at end of file diff --git a/src/Data/Container/Additive/Object/Definition.idr b/src/Data/Container/Additive/Object/Definition.idr index dd2a3ca..a92fe80 100644 --- a/src/Data/Container/Additive/Object/Definition.idr +++ b/src/Data/Container/Additive/Object/Definition.idr @@ -28,7 +28,7 @@ public export ||| Underlying monoid structure of positions public export UMon : (c : AddCont) -> (s : c.Shp) -> ComMonoid (c.Pos s) -UMon (MkAddCont c @{MkI @{m}}) s = m s +UMon (MkAddCont c @{MkI m}) s = m s public export (.Plus) : (c : AddCont) -> (s : c.Shp) -> (c.Pos s -> c.Pos s -> c.Pos s) @@ -36,20 +36,4 @@ public export public export (.Zero) : (c : AddCont) -> (s : c.Shp) -> c.Pos s -(.Zero) c s = neutral (UMon c s) - -||| Convenience datatype storing the property that -||| an additive container `c` has an interface `i` on its positions -public export -data InterfaceOnPositions : (c : AddCont) -> (i : Type -> Type) -> Type where - ||| For every shape s the set of positions c.Pos s has that interface - MkI : (p : (s : c.Shp) -> i (c.Pos s)) => - InterfaceOnPositions c i - - -namespace Flat - public export - data IsFlat : AddCont -> Type where - MkIsFlat : (p : Type) -> (mon : ComMonoid p) => IsFlat (MkAddCont (Const p)) - - --flatEq : IsFlat c => c = MkAddCont (Const c.Shp) \ No newline at end of file +(.Zero) c s = neutral (UMon c s) \ No newline at end of file diff --git a/src/Data/Container/Additive/Object/Instances.idr b/src/Data/Container/Additive/Object/Instances.idr index 41b32b8..a08cbd8 100644 --- a/src/Data/Container/Additive/Object/Instances.idr +++ b/src/Data/Container/Additive/Object/Instances.idr @@ -7,17 +7,27 @@ import Data.Container.Base import Data.ComMonoid import Data.Container.Additive.Object.Definition import Data.Container.Additive.Extension.Definition +import Data.Container.Additive.Product.Definitions ||| Scalar additive container +||| This is equivalent to `!! UnitCont` +||| Unit of the categorical product public export Scalar : AddCont Scalar = MkAddCont Scalar +||| Empty additive container +||| Unit of the coproduct +||| Initial container +public export +Empty : AddCont +Empty = MkAddCont Empty @{MkI absurd} + ||| Constant additive container, positions not dependent on shapes ||| Allows the backward part to be different than forward one public export Const : Type -> ComMonoid -> AddCont -Const a (t ** m) = MkAddCont (Const2 a t) @{MkI @{\_ => m}} +Const a (t ** m) = MkAddCont (Const2 a t) @{MkI $ \_ => m} public export TrivialPos : Type -> AddCont diff --git a/src/Data/Container/Additive/Product/Definitions.idr b/src/Data/Container/Additive/Product/Definitions.idr index adaca1a..507f3c6 100644 --- a/src/Data/Container/Additive/Product/Definitions.idr +++ b/src/Data/Container/Additive/Product/Definitions.idr @@ -9,7 +9,6 @@ import Data.Num import Data.Container.Additive.Object.Definition import Data.Container.Additive.Morphism.Definition import Data.Container.Additive.Extension.Definition -import Data.Container.Additive.Object.Instances import Data.Container.Base.Quantifiers import Data.Container.Additive.Quantifiers @@ -40,13 +39,15 @@ Compared to ordinary containers, for additive containers: -------------------------------------------------------------------------------} ||| Hancock tensor product here becomes the categorical product +||| Monoid with Scalar namespace Product + ||| Binary version of product public export (><) : AddCont -> AddCont -> AddCont c >< d = MkAddCont (UC c >< UC d) - @{MkI @{\sh => MkComMonoid (\l, r => + @{MkI $ \sh => MkComMonoid (\l, r => (c.Plus (fst sh) (fst l) (fst r), d.Plus (snd sh) (snd l) (snd r))) - (c.Zero (fst sh), d.Zero (snd sh))}} + (c.Zero (fst sh), d.Zero (snd sh))} ||| Can also use the product operator public export @@ -59,7 +60,7 @@ namespace Product AllAll : List AddCont -> AddCont AllAll xs = MkAddCont ((shapes : All (.Shp) xs) !> AllPos shapes) - @{MkI @{allPosComMonoid}} + @{MkI allPosComMonoid} namespace Vect ||| N-ary version of hancock product @@ -67,7 +68,7 @@ namespace Product AllAll : Vect n AddCont -> AddCont AllAll xs = MkAddCont ((shapes : All (.Shp) xs) !> AllPos shapes) - @{MkI @{allPosComMonoid}} + @{MkI allPosComMonoid} namespace Morphism public export @@ -76,30 +77,31 @@ namespace Product (><) f g = !%+ \(c, d) => ((f.fwd c, g.fwd d) ** \(c', d') => (f.bwd c c', g.bwd d d')) - ||| Dependent Hancock (tensor) product of additive containers. - ||| This is the analogue of DPair for containers: - ||| Given a container `pc` and a family `qc : pc.Shp -> AddCont`, + ||| Dependent pair type for additive containers + ||| Can be thought of as the dependent tensor product of containers + ||| Given a container `s` and a family `p : s.Shp -> Cont`, ||| form the container whose shapes are dependent pairs of shapes ||| and positions are pairs of positions. public export - DepHancockProduct : (pc : AddCont) -> (qc : pc.Shp -> AddCont) -> AddCont - DepHancockProduct pc qc = MkAddCont - (DepHancockProduct (UC pc) (UC . qc)) - @{MkI @{\(ps ** qs) => MkComMonoid + DPair : (pc : AddCont) -> (qc : pc.Shp -> AddCont) -> AddCont + DPair pc qc = MkAddCont + (DPairTensor (UC pc) (UC . qc)) + @{MkI $ \(ps ** qs) => MkComMonoid (\(pcPos1, qcPos1), (pcPos2, qcPos2) => (plus (UMon pc ps) pcPos1 pcPos2, plus (UMon (qc ps) qs) qcPos1 qcPos2)) - (neutral (UMon pc ps), neutral (UMon (qc ps) qs))}} + (neutral (UMon pc ps), neutral (UMon (qc ps) qs))} ||| Same as in ordinary containers +||| Monoid with Empty namespace Coproduct ||| Coproduct public export (>+<) : AddCont -> AddCont -> AddCont c >+< d = MkAddCont (UC c >+< UC d) - @{MkI @{\case + @{MkI $ \case Left cs => MkComMonoid (plus (UMon c cs)) (neutral (UMon c cs)) - Right ds => MkComMonoid (plus (UMon d ds)) (neutral (UMon d ds))}} + Right ds => MkComMonoid (plus (UMon d ds)) (neutral (UMon d ds))} namespace Morphism public export @@ -115,7 +117,7 @@ namespace Coproduct Any : List AddCont -> AddCont Any xs = MkAddCont ((shapes : Any (.Shp) xs) !> AnyShpPos shapes) - @{MkI @{anyShpPosComMonoid}} + @{MkI anyShpPosComMonoid} namespace Vect ||| N-ary version of coproduct @@ -123,7 +125,7 @@ namespace Coproduct Any : Vect n AddCont -> AddCont Any xs = MkAddCont ((shapes : Any (.Shp) xs) !> AnyShpPos shapes) - @{MkI @{anyShpPosComMonoid}} + @{MkI anyShpPosComMonoid} ||| With an ordinary container `c`, the Pi and Sigma type simple are the ||| dependent function ((s : c.Shp) -> c.Pos s) and the dependent pair @@ -166,7 +168,7 @@ public export (!!) : Cont -> AddCont (!!) c = MkAddCont (List c) - @{MkI @{\_ => listIsMonoid}} + @{MkI $ \_ => listIsMonoid} export prefix 9 !! @@ -222,7 +224,7 @@ namespace MonoidalClosure InternalLensAdditive : AddCont -> AddCont -> AddCont InternalLensAdditive c d = MkAddCont ((l : c =%> d) !> List (Path (lensInputs l))) - @{MkI @{\_ => listIsMonoid}} + @{MkI $ \_ => listIsMonoid} public export curry : {c : AddCont} -> (c >< d) =%> e -> c =%> (InternalLensAdditive d e) @@ -260,7 +262,7 @@ public export List : AddCont -> AddCont List c = MkAddCont (List (UC c)) - @{MkI @{allIsComonoid}} + @{MkI allIsComonoid} namespace Morphism diff --git a/src/Data/Container/Additive/Properties/Definitions.idr b/src/Data/Container/Additive/Properties/Definitions.idr new file mode 100644 index 0000000..63aae6f --- /dev/null +++ b/src/Data/Container/Additive/Properties/Definitions.idr @@ -0,0 +1,23 @@ +module Data.Container.Additive.Properties.Definitions + +import Data.Container.Base +import Data.Container.Additive.Object.Definition + +import Data.ComMonoid + + +||| Convenience datatype storing the property that +||| an additive container `c` has an interface `i` on its positions +public export +data InterfaceOnPositions : (c : AddCont) -> (i : Type -> Type) -> Type where + ||| For every shape s the set of positions c.Pos s has that interface + MkI : (p : (s : c.Shp) -> i (c.Pos s)) => + InterfaceOnPositions c i + + +namespace Flat + public export + data IsFlat : AddCont -> Type where + MkIsFlat : (p : Type) -> (mon : ComMonoid p) => IsFlat (MkAddCont (Const p)) + + --flatEq : IsFlat c => c = MkAddCont (Const c.Shp) \ No newline at end of file diff --git a/src/Data/Container/Applicative.idr b/src/Data/Container/Applicative.idr index c72a924..2df96d0 100644 --- a/src/Data/Container/Applicative.idr +++ b/src/Data/Container/Applicative.idr @@ -4,7 +4,7 @@ import public Data.Container.Base import public Data.Container.Applicative.Object.Instances import public Data.Container.Applicative.Extension.Instances -import public Data.Container.Applicative.Concrete.Instances +import public Data.Container.Applicative.Properties.Instances import public Data.Container.Applicative.Product.Interfaces import public Data.Container.Applicative.TreeUtils \ No newline at end of file diff --git a/src/Data/Container/Applicative/Concrete/Instances.idr b/src/Data/Container/Applicative/Properties/Instances.idr similarity index 94% rename from src/Data/Container/Applicative/Concrete/Instances.idr rename to src/Data/Container/Applicative/Properties/Instances.idr index 39c4a22..0ef6570 100644 --- a/src/Data/Container/Applicative/Concrete/Instances.idr +++ b/src/Data/Container/Applicative/Properties/Instances.idr @@ -1,4 +1,4 @@ -module Data.Container.Applicative.Concrete.Instances +module Data.Container.Applicative.Properties.Instances import Data.Container.Base import Data.Container.Applicative.Object.Instances @@ -30,9 +30,8 @@ toRoseTreeSame (NodeS (len <| content) <| contentAt) <$> (\i => content i <| contentAt . SubTree i) <$> positionsCont) - public export -FromConcrete RoseTree where +IsConcrete RoseTree where concreteType = RoseTreeSame concreteFunctor = %search fromConcreteTy = fromRoseTreeSame diff --git a/src/Data/Container/Base.idr b/src/Data/Container/Base.idr index 1e08529..f78b441 100644 --- a/src/Data/Container/Base.idr +++ b/src/Data/Container/Base.idr @@ -9,11 +9,12 @@ import public Data.Container.Base.Definitions import public Data.Container.Base.Instances -- for manipulating concrete tree instances -import public Data.Tree +import public Data.Trees import public Data.Functor.Algebra -- import public Data.Functor.Naperian import public Data.Container.Base.TreeUtils +import public Data.Container.Base.Display2D.Display2D -- where to put -- temp/misc import public Data.Container.SubTerm diff --git a/src/Data/Container/Base/Concrete/Definition.idr b/src/Data/Container/Base/Concrete/Definition.idr deleted file mode 100644 index d345322..0000000 --- a/src/Data/Container/Base/Concrete/Definition.idr +++ /dev/null @@ -1,32 +0,0 @@ -module Data.Container.Base.Concrete.Definition - -import Data.Container.Base.Object.Definition -import Data.Container.Base.Extension.Definition - --- todo rename to `IsConcrete? -||| Many Idris' datatypes are already concrete, inductive -||| representations of particular containers -||| It is useful to be easily able to convert between them -public export -interface FromConcrete (cont : Cont) where - constructor MkConcrete - concreteType : Type -> Type - concreteFunctor : Functor concreteType - fromConcreteTy : concreteType a -> Ext cont a - toConcreteTy : Ext cont a -> concreteType a - - -public export -data AllConcrete : List Cont -> Type where - Nil : AllConcrete [] - Cons : (firstConcrete : FromConcrete c) => - (restConcrete : AllConcrete cs) => - AllConcrete (c :: cs) - - --- public export --- fromConcreteMap : {cont1, cont2 : Cont} -> --- (fc1 : FromConcrete cont1) => (fc2 : FromConcrete cont2) => --- (concreteType @{fc1} a -> concreteType @{fc2} b) -> --- cont1 `fullOf` a -> cont2 `fullOf` b --- fromConcreteMap f = fromConcrete @{fc2} . f . toConcrete @{fc1} diff --git a/src/Data/Container/Base/Concrete/Instances.idr b/src/Data/Container/Base/Concrete/Instances.idr deleted file mode 100644 index 9e13113..0000000 --- a/src/Data/Container/Base/Concrete/Instances.idr +++ /dev/null @@ -1,183 +0,0 @@ -module Data.Container.Base.Concrete.Instances - -import Data.Fin -import Data.Vect -import Data.List -import Data.DPair - -import Data.Container.Base.Object.Definition -import Data.Container.Base.Extension.Definition -import Data.Container.Base.Concrete.Definition -import Data.Container.Base.Object.Instances -import Data.Container.Base.Extension.Instances - -import Data.Container.Base.TreeUtils - --- import public Data.Functor.Naperian -import public Data.Tree - -import Misc - -%hide Data.Vect.fromList - -namespace ConversionFunctions - public export - toScalar : a -> Scalar' a - toScalar a = () <| (\_ => a) - - public export - extract : Scalar' a -> a - extract (() <| f) = f () - - public export - fromMaybe : Maybe a -> Maybe' a - fromMaybe Nothing = (False <| absurd) - fromMaybe (Just a) = (True <| \_ => a) - - public export - toMaybe : Maybe' a -> Maybe a - toMaybe (False <| absurd) = Nothing - toMaybe (True <| f) = Just (f ()) - - public export - fromList : List a -> List' a - fromList [] = (0 <| absurd) - fromList (x :: xs) = let (l <| c) = fromList xs - in (S l <| cons x c) - - public export - toList : List' a -> List a - toList (0 <| _) = [] - toList ((S k) <| ind) = head ind :: toList (k <| tail ind) - - public export - fromVect : Vect n a -> Vect' n a - fromVect v = () <| \i => index i v - - public export - toVect : {n : Nat} -> Vect' n a -> Vect n a - toVect (_ <| index) = Vect.Fin.tabulate index - - public export - fromBinTreeSame : BinTreeSame a -> BinTree' a - fromBinTreeSame (Leaf x) = LeafS <| \_ => x - fromBinTreeSame (Node x lt rt) = - let fblt = fromBinTreeSame lt - fbrt = fromBinTreeSame rt - in NodeS (shapeExt fblt) (shapeExt fbrt) <| \case - AtNode => x - GoLeft posL => index fblt posL - GoRight posR => index fbrt posR - - public export - toBinTreeSame : BinTree' a -> BinTreeSame a - toBinTreeSame (LeafS <| index) = Leaf (index AtLeaf) - toBinTreeSame (NodeS lt rt <| index) = - Node (index AtNode) - (toBinTreeSame (lt <| index . GoLeft)) - (toBinTreeSame (rt <| index . GoRight)) - - - public export - fromTreeHelper : BinTreePosNode LeafS -> a - fromTreeHelper AtNode impossible - fromTreeHelper (GoLeft x) impossible - fromTreeHelper (GoRight x) impossible - - public export - fromBinTreeNode : BinTreeNode a -> BinTreeNode' a - fromBinTreeNode (Leaf ()) = LeafS <| fromTreeHelper - fromBinTreeNode (Node node leftTree rightTree) - = let fblt = fromBinTreeNode leftTree - fbrt = fromBinTreeNode rightTree - in (NodeS (shapeExt fblt) (shapeExt fbrt) <| \case - AtNode => node - GoLeft posL => index fblt posL - GoRight posR => index fbrt posR) - - public export - toBinTreeNode : BinTreeNode' a -> BinTreeNode a - toBinTreeNode (LeafS <| index) = Leaf () - toBinTreeNode (NodeS lt rt <| index) = - Node (index AtNode) - (toBinTreeNode (lt <| index . GoLeft)) - (toBinTreeNode (rt <| index . GoRight)) - - public export - fromBinTreeLeaf : BinTreeLeaf a -> BinTreeLeaf' a - fromBinTreeLeaf (Leaf leaf) = LeafS <| \_ => leaf - fromBinTreeLeaf (Node node lt rt) = - let fblt = fromBinTreeLeaf lt - fbrt = fromBinTreeLeaf rt - in NodeS (shapeExt fblt) (shapeExt fbrt) <| \case - GoLeft posL => index fblt posL - GoRight posR => index fbrt posR - - public export - toBinTreeLeaf : BinTreeLeaf' a -> BinTreeLeaf a - toBinTreeLeaf (LeafS <| content) = Leaf (content AtLeaf) - toBinTreeLeaf (NodeS l r <| content) = - Node' (toBinTreeLeaf (l <| content . GoLeft)) - (toBinTreeLeaf (r <| content . GoRight)) - - -- ||| Indexing an element of `xs` and then applying `f` to it is the same as - -- ||| mapping `f` over xs, and then indexing the result - -- public export - -- mapIndexPreserve : {0 f : a -> b} -> - -- (xs : List a) -> - -- (i : Fin (length (f <$> xs))) -> - -- f (index' xs (rewrite sym (lengthMap {f=f} xs) in i)) - -- = index' (f <$> xs) i - -- mapIndexPreserve (x :: xs) FZ = Refl - -- mapIndexPreserve (x :: xs) (FS j) = mapIndexPreserve xs j - - -public export -FromConcrete Scalar where - concreteType = id - concreteFunctor = MkFunctor id - fromConcreteTy = pure - toConcreteTy = extract - -public export -FromConcrete Maybe where - concreteType = Maybe - concreteFunctor = %search - fromConcreteTy = fromMaybe - toConcreteTy = toMaybe - -public export -FromConcrete List where - concreteType = List - concreteFunctor = %search -- TODO how to find the result of the search and place it here directly? - fromConcreteTy = fromList - toConcreteTy = toList - -public export -{n : Nat} -> FromConcrete (Vect n) where - concreteType = Vect n - concreteFunctor = %search - fromConcreteTy = fromVect - toConcreteTy = toVect - -public export -FromConcrete BinTree where - concreteType = BinTreeSame - concreteFunctor = %search - fromConcreteTy = fromBinTreeSame - toConcreteTy = toBinTreeSame - -public export -FromConcrete BinTreeNode where - concreteType = BinTreeNode - concreteFunctor = %search - fromConcreteTy = fromBinTreeNode - toConcreteTy = toBinTreeNode - -public export -FromConcrete BinTreeLeaf where - concreteType = BinTreeLeaf - concreteFunctor = %search - fromConcreteTy = fromBinTreeLeaf - toConcreteTy = toBinTreeLeaf - diff --git a/src/Data/Container/Base/Definitions.idr b/src/Data/Container/Base/Definitions.idr index 13a4e70..aeb2d91 100644 --- a/src/Data/Container/Base/Definitions.idr +++ b/src/Data/Container/Base/Definitions.idr @@ -4,5 +4,5 @@ module Data.Container.Base.Definitions import public Data.Container.Base.Object.Definition import public Data.Container.Base.Extension.Definition import public Data.Container.Base.Morphism.Definition -import public Data.Container.Base.Concrete.Definition +import public Data.Container.Base.Properties.Definitions import public Data.Container.Base.Product.Definitions diff --git a/src/Data/Container/Base/Display2D/CharacterMap.idr b/src/Data/Container/Base/Display2D/CharacterMap.idr new file mode 100644 index 0000000..f7b4669 --- /dev/null +++ b/src/Data/Container/Base/Display2D/CharacterMap.idr @@ -0,0 +1,77 @@ +module Data.Container.Base.Display2D.CharacterMap + +public export +padCharacter : Char +padCharacter = ' ' + +public export +record Tree where + constructor MkTree + horizontal : Char + branchMid : Char + branchLast : Char + vertical : Char + gap : Char + placeholder : Char + +||| Used for drawing a box border +public export +record Box where + constructor MkBox + topLeft : Char + topRight : Char + bottomLeft : Char + bottomRight : Char + horizontal : Char + vertical : Char + +||| Used for rendering list-like syntax +public export +record ListSyntax where + constructor MkListSyntax + left : Char + right : Char + separator : Char + +||| Used for rendering pair-like syntax +public export +record PairSyntax where + constructor MkPairSyntax + left : Char + right : Char + separator : Char + +public export +SingleLineTree : Tree +SingleLineTree = MkTree + '─' + '├' + '└' + '│' + ' ' + '\x00B7' + +||| Double-line, to separate from tree lines +public export +DoubleLineBox : Box +DoubleLineBox = MkBox + '╔' + '╗' + '╚' + '╝' + '═' + '║' + +public export +AsciiListSyntax : ListSyntax +AsciiListSyntax = MkListSyntax + '[' + ']' + ',' + +public export +AsciiPairSyntax : PairSyntax +AsciiPairSyntax = MkPairSyntax + '(' + ')' + ',' \ No newline at end of file diff --git a/src/Data/Container/Base/Display2D/Display2D.idr b/src/Data/Container/Base/Display2D/Display2D.idr new file mode 100644 index 0000000..4960633 --- /dev/null +++ b/src/Data/Container/Base/Display2D/Display2D.idr @@ -0,0 +1,558 @@ +module Data.Container.Base.Display2D.Display2D + +import Data.Fin +import Data.List +import Data.Vect + + +import Data.Container.Base.Object.Definition +import Data.Container.Base.Extension.Definition +import Data.Container.Base.Product.Definitions + +import Data.Container.Base.Instances + +import public Data.Container.Base.Display2D.CharacterMap +import Data.ScientificNotation + +import Data.Container.Base.TreeUtils +import Misc + +%hide Syntax.WithProof.prefix.(@@) + +{------------------------------------------------------------------------------- +Machinery for rendering values as rectangular 2D character grids. + +Layout combinators (`besideAllGap`, `aboveAllSep`, `addBorderToGrid`, ...) +build new grids by (1) computing the output shape from the inputs at the data +level, and (2) building a new index function that dispatches into the inputs. +Both steps are O(1). Lookup cost in a deeply composed grid is proportional +to the composition depth; `gridRows` (used by `showGrid`) is the one-shot +materialisation point that walks each cell exactly once. + +As tensor utilities are added, functionality within this file will be rewritten +-------------------------------------------------------------------------------} + +||| Approximate line-width budget for printed cubical tensors. When a row of +||| fixed-width cells does not fit on a single line, the renderer breaks it +||| onto continuation lines so output stays within roughly this many columns. +||| Matches NumPy's default `linewidth = 75`. +public export +defaultLineWidth : Nat +defaultLineWidth = 75 + +{------------------------------------------------------------------------------- +Core: Grid type, accessors, basic constructors +-------------------------------------------------------------------------------} + +||| A `Grid a` is a rectangular 2D matrix. Unlike `List >@ List`, which +||| produces ragged arrays, this produces rectangular ones +public export +Grid : Type -> Type +Grid = Ext (List >< List) + +public export +gridHeight : Grid a -> Nat +gridHeight ((h, _) <| _) = h + +public export +gridWidth : Grid a -> Nat +gridWidth ((_, w) <| _) = w + +||| Look up a cell. Cost depends on how `g` was built; for grids made from +||| chained combinators the cost is proportional to the composition depth. +public export +gridIndex : (g : Grid a) -> Fin (gridHeight g) -> Fin (gridWidth g) -> a +gridIndex ((_, _) <| f) i j = f (i, j) + +||| Build a grid from a height, width, and 2D index function. O(1). +public export +mkGrid : (h, w : Nat) -> (Fin h -> Fin w -> a) -> Grid a +mkGrid h w f = (h, w) <| uncurry f + + +{------------------------------------------------------------------------------- +Materialisation +-------------------------------------------------------------------------------} + +||| Convert a `Grid` into a list-of-rows view. Walks the grid once, so this +||| is the natural one-shot materialisation point for delayed grids. +public export +gridRows : Grid a -> List (List a) +gridRows g = + toList' $ Fin.tabulate {len = gridHeight g} $ \i => + toList' $ Fin.tabulate {len = gridWidth g} $ \j => + gridIndex g i j + +||| Format a `Grid Char` as a multi-line string. Trailing pad characters on +||| each line are stripped, matching NumPy's printing conventions. +public export +showGrid : Grid Char -> String +showGrid = + concat . intersperse "\n" . map (pack . dropFromEnd padCharacter) . gridRows + + +{------------------------------------------------------------------------------- +Primitive constructors +-------------------------------------------------------------------------------} + +public export +emptyGrid : Grid a +emptyGrid = mkGrid 0 0 absurd + +public export +singleValue : a -> Grid a +singleValue v = mkGrid 1 1 (\_, _ => v) + +public export +uniformCol : a -> Nat -> Grid a +uniformCol v h = mkGrid h 1 (\_, _ => v) + +public export +blankCol : Nat -> Grid Char +blankCol = uniformCol padCharacter + +||| One-row grid built from a list of cells. +public export +rowGrid : List a -> Grid a +rowGrid xs = mkGrid 1 (length xs) (\_, j => index j (fromList xs)) + +||| 1-wide column carrying `marker` on its top row and `padCharacter` elsewhere. +||| `emptyGrid` if the height is zero. +public export +topMarkerCol : (marker : Char) -> (h : Nat) -> Grid Char +topMarkerCol _ Z = emptyGrid +topMarkerCol c (S k) = mkGrid (S k) 1 $ \i, _ => case i of + FZ => c + _ => padCharacter + +||| 1-wide column carrying `marker` on its bottom row and `padCharacter` +||| elsewhere. `emptyGrid` if the height is zero. +public export +bottomMarkerCol : (marker : Char) -> (h : Nat) -> Grid Char +bottomMarkerCol _ Z = emptyGrid +bottomMarkerCol c (S k) = mkGrid (S k) 1 $ \i, _ => case i == last of + True => c + False => padCharacter + + +{------------------------------------------------------------------------------- +Multiway layout combinators + +Each takes O(1) to construct and produces a delayed grid whose index function +linearly scans the inputs. For typical fan-outs (2..16) this is plenty fast. +-------------------------------------------------------------------------------} + +||| Locate global position `i` in a sequence of items each of width `size g`, +||| with `gap` blank positions between consecutive items. Returns the item +||| together with a properly bounded local index, or `Nothing` if `i` lands in +||| a gap or past the end. +||| +||| Total: recursion is structural on the list of grids. +locateChunk : (size : Grid a -> Nat) -> + (gap : Nat) -> + (gs : List (Grid a)) -> + Nat -> + Maybe (g : Grid a ** Fin (size g)) +locateChunk _ _ [] _ = Nothing +locateChunk size gap (g :: gs) i = case natToFin i (size g) of + Just j => Just (g ** j) + Nothing => let sz = size g + in case i < sz + gap of + True => Nothing + False => locateChunk size gap gs (minus i (sz + gap)) + +||| Place grids side-by-side with `gap` blank columns between each pair. The +||| row dimension is the max of the children's heights; shorter children are +||| padded on the missing rows. +public export +besideAllGap : (padValue : a) -> (gap : Nat) -> List (Grid a) -> Grid a +besideAllGap _ _ [] = emptyGrid +besideAllGap _ _ [g] = g +besideAllGap pad gap grids@(_ :: _) = + let maxHeight = max (gridHeight <$> grids) + sumWidths = List.sum (gridWidth <$> grids) + gap * pred (length grids) + in mkGrid maxHeight sumWidths $ \i, j => fromMaybe pad $ do + (g ** j') <- locateChunk gridWidth gap grids (finToNat j) + i' <- natToFin (finToNat i) (gridHeight g) + pure (gridIndex g i' j') + +||| Place grids side-by-side with no gap. +public export +besideAll : (padValue : a) -> List (Grid a) -> Grid a +besideAll pad = besideAllGap pad 0 + +||| Stack grids vertically, inserting `sep` blank rows between successive items. +||| The column dimension is the max of the children's widths; narrower children +||| are padded on the missing columns. +public export +aboveAllSep : (padValue : a) -> (sep : Nat) -> List (Grid a) -> Grid a +aboveAllSep _ _ [] = emptyGrid +aboveAllSep _ _ [g] = g +aboveAllSep pad sep grids@(_ :: _) = + let sumHeights = List.sum (gridHeight <$> grids) + sep * pred (length grids) + maxWidth = max (gridWidth <$> grids) + in mkGrid sumHeights maxWidth $ \i, j => fromMaybe pad $ do + (g ** i') <- locateChunk gridHeight sep grids (finToNat i) + j' <- natToFin (finToNat j) (gridWidth g) + pure (gridIndex g i' j') + +||| Stack grids vertically with no separator. +public export +aboveAll : (padValue : a) -> List (Grid a) -> Grid a +aboveAll pad = aboveAllSep pad 0 + +||| Left-pad a grid with `padCharacter` columns to reach total width `w`. +||| No-op if `gridWidth g >= w`. +public export +padGridLeft : (w : Nat) -> Grid Char -> Grid Char +padGridLeft w g = besideAll padCharacter + [mkGrid (gridHeight g) (w `minus` gridWidth g) (\_, _ => padCharacter), g] + +{------------------------------------------------------------------------------- +Borders +-------------------------------------------------------------------------------} + +||| Given a `k : Nat`, models a three-way classification of `Fin (S (S k))` into +||| first, last, or middle. +data Side : (k : Nat) -> Type where + AtStart : Side k + AtEnd : Side k + AtMid : Fin k -> Side k + +side : {k : Nat} -> Fin (S (S k)) -> Side k +side FZ = AtStart +side (FS x) = maybe AtEnd AtMid (strengthen x) + +||| Wrap a grid in a 1-character border specified by `box`. +public export +addBorderToGrid : (box : Box) => Grid Char -> Grid Char +addBorderToGrid g = mkGrid (2 + gridHeight g) (2 + gridWidth g) $ \i, j => + case (side i, side j) of + (AtMid i', AtMid j') => gridIndex g i' j' + (AtMid _, _) => box.vertical + (_, AtMid _) => box.horizontal + (AtStart, AtStart) => box.topLeft + (AtStart, AtEnd) => box.topRight + (AtEnd, AtStart) => box.bottomLeft + (AtEnd, AtEnd) => box.bottomRight + +||| Wrap a grid in a border if it has more than one row +public export +wrapNonEmpty : Box => Grid Char -> Grid Char +wrapNonEmpty g = applyWhen (gridHeight g > 1) addBorderToGrid g + +||| Given a list of grids, make a function that wraps a grid in a border if +||| any of the grids in the list has more than one row +public export +wrapAllIfAnyNonEmpty : Box => + List (Grid Char) -> (Grid Char -> Grid Char) +wrapAllIfAnyNonEmpty grids = + applyWhen (any (\g => gridHeight g > 1) grids) addBorderToGrid + + +{------------------------------------------------------------------------------- +List/pair brackets and joins +-------------------------------------------------------------------------------} + +||| Join a list of grids horizontally with a separator between them +public export +horizontalListJoin : (listSyntax : ListSyntax) => List (Grid Char) -> Grid Char +horizontalListJoin [] = emptyGrid +horizontalListJoin [g] = g +horizontalListJoin gs = besideAll padCharacter (intersperse sep gs) + where sep = besideAll padCharacter + [singleValue listSyntax.separator, singleValue padCharacter] + +||| Wrap a list of grids vertically in `[` ... `]` brackets with `,` markers in +||| front of subsequent items, as follows: +||| +||| [item1 +||| ,item2 +||| ,item3] +||| +||| `nSep` is the number of blank rows between items: +public export +wrapListBrackets : (listSyntax : ListSyntax) => + (nSep : Nat) -> List (Grid Char) -> Grid Char +wrapListBrackets _ [] = besideAll padCharacter + [singleValue listSyntax.left, singleValue listSyntax.right] +wrapListBrackets nSep (x :: xs) = + let body = aboveAllSep padCharacter nSep (x :: xs) + leftCol = aboveAll padCharacter $ + topMarkerCol listSyntax.left (gridHeight x) :: + concatMap (\g => [ blankCol nSep + , topMarkerCol listSyntax.separator (gridHeight g) + ]) xs + rightCol = bottomMarkerCol listSyntax.right (gridHeight body) + in besideAll padCharacter [leftCol, body, rightCol] + + +{------------------------------------------------------------------------------- +NumPy-style row wrapping for long innermost rows +-------------------------------------------------------------------------------} + +||| Render a list of equal-width cells as `[ c0 c1 ... cN ]`, breaking the +||| output onto continuation lines when the single-line layout would exceed +||| `lineBudget` columns. Continuation lines are indented by one space so +||| wrapped cells align under `c0`; the closing `]` sits next to the last cell +||| of the final chunk, so a shorter final chunk leaves trailing whitespace on +||| its line (NumPy-style). +||| +||| The single-line layout falls out automatically when `cellsPerLine >= length +||| children`, so no separate "flat" code path is needed. +public export +wrappedInnerRow : (listSyntax : ListSyntax) => + (lineBudget, gap : Nat) -> List (Grid Char) -> Grid Char +wrappedInnerRow _ _ [] = besideAll padCharacter + [singleValue listSyntax.left, singleValue listSyntax.right] +wrappedInnerRow lineBudget gap children@(c :: _) = + let cellsPerLine = max 1 $ (lineBudget `minus` 2 + gap) `div` (gridWidth c + gap) + chunks = chunksOf cellsPerLine children + nChunks = length chunks + chunkRow : Nat -> List (Grid Char) -> Grid Char + chunkRow i cs = + let leftC = if i == 0 then listSyntax.left else padCharacter + rightC = if S i == nChunks then listSyntax.right else padCharacter + in besideAll padCharacter + [ singleValue leftC + , besideAllGap padCharacter gap cs + , singleValue rightC ] + in aboveAll padCharacter (mapWithIndex chunkRow chunks) + + +{------------------------------------------------------------------------------- +Display2D interface: rendering types as 2D character grids +-------------------------------------------------------------------------------} + +||| Display a type as a 2D grid of characters. +||| For `a` being an extension of a container, this can be expresse in terms of +||| a container morphism... if we use `Maybe ` +public export +interface Display2D (0 a : Type) where + constructor MkDisplay2D + display2D : a -> Grid Char + +public export +Display2D (Grid Char) where + display2D = id + +||| One-row grid from a `Show` instance. +public export +display2DFromShow : Show a => a -> Grid Char +display2DFromShow x = rowGrid (unpack (show x)) + +||| One-row grid from a `ScientificDisplay` instance. +public export +display2DFromSci : ScientificDisplay a => a -> Grid Char +display2DFromSci x = rowGrid (unpack (showSci x)) + +public export Display2D Int where display2D = display2DFromShow +public export Display2D Integer where display2D = display2DFromSci +public export Display2D Double where display2D = display2DFromSci +public export Display2D Nat where display2D = display2DFromSci +public export Display2D Bool where display2D = display2DFromShow +public export Display2D () where display2D = display2DFromShow +public export Display2D Char where display2D = singleValue +public export Display2D String where display2D s = rowGrid (unpack s) + + +{------------------------------------------------------------------------------- +Display2D instances for container extensions +-------------------------------------------------------------------------------} + +public export +Display2D a => Display2D (Scalar' a) where + display2D (() <| index) = display2D (index ()) + +public export +Display2D a => Display2D (Pair' a) where + display2D (() <| index) = besideAll padCharacter + [ singleValue (left AsciiPairSyntax) + , display2D (index False) + , singleValue (separator AsciiPairSyntax) + , display2D (index True) + , singleValue (right AsciiPairSyntax) ] + +public export +Display2D a => Display2D (List' a) where + display2D (_ <| index) = besideAll padCharacter + [ singleValue (left AsciiListSyntax) + , horizontalListJoin {listSyntax = AsciiListSyntax} + (display2D <$> toList' (tabulate index)) + , singleValue (right AsciiListSyntax) ] + + +{------------------------------------------------------------------------------- +Tree helpers +-------------------------------------------------------------------------------} + +||| 2-column tree-branch prefix: first row is `connector` + `treeHorizontal`; +||| subsequent rows are `continuation` + `treeGap`. +treeBranchPrefix : (tree : Tree) => + (connector, continuation : Char) -> (height : Nat) -> Grid Char +treeBranchPrefix _ _ Z = emptyGrid +treeBranchPrefix conn cont (S k) = mkGrid (S k) 2 $ \i, j => + case (i, j) of + (FZ, FZ) => conn + (FZ, _ ) => tree.horizontal + (_ , FZ) => cont + (_ , _ ) => tree.gap + +||| Prepend a tree-branch prefix and a 1-column gap to a grid. +addBranch : Tree => + (connector, continuation : Char) -> Grid Char -> Grid Char +addBranch conn cont g = besideAll padCharacter + [treeBranchPrefix conn cont (gridHeight g), blankCol (gridHeight g), g] + +||| Full-width row between sibling subtrees: `vertical` then spaces. +treeSiblingGapRow : (tree : Tree) => (w : Nat) -> Grid Char +treeSiblingGapRow Z = emptyGrid +treeSiblingGapRow (S k) = mkGrid 1 (S k) $ \_, j => case j of + FZ => tree.vertical + _ => padCharacter + +||| Lay out a tree node: root above, then left/right subtrees with branches. +displayNodeWithBranches : (tree : Tree) => + (root, left, right : Grid Char) -> Grid Char +displayNodeWithBranches root left right = + let leftB = addBranch tree.branchMid tree.vertical left + rightB = addBranch tree.branchLast tree.gap right + top = aboveAll padCharacter + [root, treeSiblingGapRow (maximum (gridWidth root) (gridWidth leftB)), leftB] + in aboveAll padCharacter [top, treeSiblingGapRow (gridWidth top), rightB] + + +{------------------------------------------------------------------------------- +BinTree (values on both nodes and leaves) +-------------------------------------------------------------------------------} + +collectBinTreeValues : Display2D a => + BinTree' a -> + (List (Grid Char), List (Grid Char)) +collectBinTreeValues (LeafS <| index) = ([], [display2D (index AtLeaf)]) +collectBinTreeValues (NodeS l r <| index) = + case collectBinTreeValues (l <| (index . GoLeft)) of + (ln, ll) => case collectBinTreeValues (r <| (index . GoRight)) of + (rn, rl) => (display2D (index AtNode) :: (ln ++ rn), ll ++ rl) + +displayBinTreeWith : Display2D a => (tree : Tree) => + (nodeBox, leafBox : Grid Char -> Grid Char) -> + BinTree' a -> Grid Char +displayBinTreeWith _ leafBox (LeafS <| index) + = leafBox (display2D (index AtLeaf)) +displayBinTreeWith nodeBox leafBox (NodeS l r <| index) = + displayNodeWithBranches (nodeBox (display2D (index AtNode))) + (displayBinTreeWith nodeBox leafBox (l <| (index . GoLeft))) + (displayBinTreeWith nodeBox leafBox (r <| (index . GoRight))) + +displayBinTree : Display2D a => (tree : Tree) => (box : Box) => + BinTree' a -> Grid Char +displayBinTree t = + let (nodeEGs, leafEGs) = collectBinTreeValues t + in displayBinTreeWith (wrapAllIfAnyNonEmpty nodeEGs) + (wrapAllIfAnyNonEmpty leafEGs) t + +public export +Display2D a => Display2D (BinTree' a) where + display2D = displayBinTree {tree = SingleLineTree, box = DoubleLineBox} + + +{------------------------------------------------------------------------------- +BinTreeLeaf (values only on leaves) +-------------------------------------------------------------------------------} + +collectLeafValues : Display2D a => + BinTreeLeaf' a -> + List (Grid Char) +collectLeafValues (LeafS <| index) = [display2D (index AtLeaf)] +collectLeafValues (NodeS l r <| index) = + collectLeafValues (l <| index . GoLeft) ++ + collectLeafValues (r <| index . GoRight) + +displayBinTreeLeafWith : Display2D a => (tree : Tree) => + (leafBox : Grid Char -> Grid Char) -> + BinTreeLeaf' a -> Grid Char +displayBinTreeLeafWith box (LeafS <| index) = box (display2D (index AtLeaf)) +displayBinTreeLeafWith box (NodeS l r <| index) = + displayNodeWithBranches (singleValue tree.placeholder) + (displayBinTreeLeafWith box (l <| index . GoLeft)) + (displayBinTreeLeafWith box (r <| index . GoRight)) + +||| Uniform boxing for BinTreeLeaf: if any leaf is multi-line, box all +displayBinTreeLeaf : Display2D a => (tree : Tree) => (box : Box) => + BinTreeLeaf' a -> Grid Char +displayBinTreeLeaf t = + displayBinTreeLeafWith (wrapAllIfAnyNonEmpty (collectLeafValues t)) t + +public export +Display2D a => Display2D (Ext BinTreeLeaf a) where + display2D = displayBinTreeLeaf {tree = SingleLineTree, box = DoubleLineBox} + + +{------------------------------------------------------------------------------- +BinTreeNode (values only on nodes) +-------------------------------------------------------------------------------} + +collectNodeValues : Display2D a => + BinTreeNode' a -> List (Grid Char) +collectNodeValues (LeafS <| _) = [] +collectNodeValues (NodeS l r <| index) = + let leftValues = collectNodeValues (l <| index . GoLeft) + rightValues = collectNodeValues (r <| index . GoRight) + in display2D (index AtNode) :: (leftValues ++ rightValues) + +displayBinTreeNodeWith : Display2D a => (tree : Tree) => + (nodeBox : Grid Char -> Grid Char) -> + BinTreeNode' a -> Grid Char +displayBinTreeNodeWith _ (LeafS <| _) = singleValue tree.placeholder +displayBinTreeNodeWith box (NodeS l r <| index) = + displayNodeWithBranches (box (display2D (index AtNode))) + (displayBinTreeNodeWith box (l <| index . GoLeft)) + (displayBinTreeNodeWith box (r <| index . GoRight)) + +displayBinTreeNode : Display2D a => (tree : Tree) => (box : Box) => + BinTreeNode' a -> Grid Char +displayBinTreeNode t = + displayBinTreeNodeWith (wrapAllIfAnyNonEmpty (collectNodeValues t)) t + +public export +Display2D a => Display2D (Ext BinTreeNode a) where + display2D = displayBinTreeNode {tree = SingleLineTree, box = DoubleLineBox} + +public export +{n : Nat} -> Display2D a => Display2D (Ext (Vect n) a) where + display2D (() <| index) = display2D {a = Ext List a} (n <| index) + +public export +showViaDisplay2D : Display2D a => a -> String +showViaDisplay2D = showGrid . display2D + +public export +Display2D a => Show (Scalar' a) where + show = showViaDisplay2D + +public export +Display2D a => Show (Pair' a) where + show = showViaDisplay2D + +public export +Display2D a => Show (List' a) where + show = showViaDisplay2D + +public export +{n : Nat} -> Display2D a => Show (Vect' n a) where + show = showViaDisplay2D + +-- Technocally should not need assert_total here, but its hard to convince +-- the typechecker +public export +Display2D a => Show (BinTree' a) where + show t = assert_total $ showViaDisplay2D t + +public export +Display2D a => Show (BinTreeLeaf' a) where + show t = assert_total $ showViaDisplay2D t + +public export +Display2D a => Show (BinTreeNode' a) where + show t = assert_total $ showViaDisplay2D t \ No newline at end of file diff --git a/src/Data/Container/Base/Extension/Definition.idr b/src/Data/Container/Base/Extension/Definition.idr index bf8f8aa..1cb8052 100644 --- a/src/Data/Container/Base/Extension/Definition.idr +++ b/src/Data/Container/Base/Extension/Definition.idr @@ -1,6 +1,7 @@ module Data.Container.Base.Extension.Definition import Data.Container.Base.Object.Definition +import Data.Container.Base.Morphism.Definition import Misc @@ -29,19 +30,21 @@ Functor (Ext c) where ||| Composition of extensions is a functor public export -Functor ((Ext d) . (Ext c)) where +Functor (Ext d . Ext c) where map f e = (map f) <$> e -||| The `index` field of an extension defines a "getter" for a container -||| This is the container setter +||| Ext is a functor of type Cont -> [Type, Type] +||| On objects it maps a container to a polynomial functor +||| On morphisms it maps a dependent lens to a natural transformation +||| This is the action on morphisms public export -set : {0 c : Cont} -> InterfaceOnPositions c Eq => - (e : Ext c x) -> c.Pos (shapeExt e) -> x -> Ext c x -set {c=(s !> p)} @{MkI} (sh <| contentAt) i x - = sh <| updateAt contentAt (i, x) +extMap : c =%> d -> Ext c a -> Ext d a +extMap f (sh <| index) = let (y ** ky) = (%!) f sh + in y <| (index . ky) + namespace ExtProofs - ||| Map ing over an extension preserves its shape + ||| Mapping over an extension preserves its shape public export mapShapeExt : {0 c : Cont} -> {0 f : a -> b} -> diff --git a/src/Data/Container/Base/Extension/Instances.idr b/src/Data/Container/Base/Extension/Instances.idr index 244acb1..629e5b0 100644 --- a/src/Data/Container/Base/Extension/Instances.idr +++ b/src/Data/Container/Base/Extension/Instances.idr @@ -4,8 +4,10 @@ import Data.DPair import Data.Vect import Data.Container.Base.Object.Definition -import Data.Container.Base.Object.Instances import Data.Container.Base.Extension.Definition +import Data.Container.Base.Properties.Definitions + +import Data.Container.Base.Object.Instances -- import Data.Functor.Naperian import Misc @@ -111,4 +113,13 @@ IsNaperian c => Applicative (Ext c) where ||| i.e. containers with a unit shape public export positionsCont : {0 c : Cont} -> {sh : c.Shp} -> Ext c (c.Pos sh) -positionsCont = sh <| id \ No newline at end of file +positionsCont = sh <| id + + +||| The `index` field of an extension defines a "getter" for a container +||| This is the container setter +public export +set : {0 c : Cont} -> InterfaceOnPositions c Eq => + (e : Ext c x) -> c.Pos (shapeExt e) -> x -> Ext c x +set {c=(s !> p)} @{MkI _} (sh <| contentAt) i x + = sh <| updateAt contentAt (i, x) diff --git a/src/Data/Container/Base/InstanceInterfaces.idr b/src/Data/Container/Base/InstanceInterfaces.idr index 730e129..7d21e82 100644 --- a/src/Data/Container/Base/InstanceInterfaces.idr +++ b/src/Data/Container/Base/InstanceInterfaces.idr @@ -8,12 +8,12 @@ import Data.Finite import Data.Container.Base.Object.Definition import Data.Container.Base.Extension.Definition -import Data.Container.Base.Concrete.Definition +import Data.Container.Base.Properties.Definitions import Data.Container.Base.Object.Instances import Data.Container.Base.Extension.Instances -import Data.Container.Base.Concrete.Instances +import Data.Container.Base.Properties.Instances -import Data.Tree +import Data.Trees import Data.Container.Base.TreeUtils import Data.Functor.Algebra import Misc @@ -21,29 +21,30 @@ import Misc %hide Prelude.toList +-- the idea is that this file will slowly be made obsolete as more and more +-- things are implemented in terms of containers + + ||| Any finite container (i.e. whose each set of positions is finite) can be ||| given an algebra instance simply by summing up all the concrete values public export algebraFinite : - (c : Cont) -> (isFinite : IsFinite c) => + (0 c : Cont) -> (isFinite : IsFinite c) => (0 a : Type) -> Num a => Algebra (Ext c) a -algebraFinite c {isFinite = MkI @{p}} _ +algebraFinite c {isFinite = MkI p} _ = MkAlgebra $ \(shp <| content) => reduce $ values @{p shp} <&> content + namespace VectInstances public export {n : Nat} -> Eq x => Eq (Vect' n x) where v == v' = (toVect v) == (toVect v') - public export - {n : Nat} -> Show x => Show (Vect' n x) where - show v = show (toVect v) + -- public export + -- {n : Nat} -> Show x => Show (Vect' n x) where + -- show v = show (toVect v) - public export - {n : Nat} -> Foldable (Vect' n) where - foldr f z v = foldr f z (toVect v) - public export {n : Nat} -> Num a => Algebra (Vect' n) a where reduce v = reduce (toVect v) @@ -77,14 +78,10 @@ namespace ListInstances Eq a => Eq (List' a) where l == l' = assert_total ((toList l) == (toList l')) - ||| Is there a different way to convince Idris' totality checker? - public export - Show a => Show (List' a) where - show x = assert_total (show (toList x)) - - public export - Foldable List' where - foldr f z v = foldr f z (toList v) + -- ||| Is there a different way to convince Idris' totality checker? + -- public export + -- Show a => Show (List' a) where + -- show x = assert_total (show (toList x)) public export Num a => Algebra List' a where @@ -111,10 +108,10 @@ namespace BinTreeInstances Eq a => Eq (BinTree' a) where t == t' = assert_total (toBinTreeSame t == toBinTreeSame t') - ||| Is there a different way to convince Idris' totality checker? - public export - Show a => Show (BinTree' a) where - show = assert_total (show . toBinTreeSame) + -- ||| Is there a different way to convince Idris' totality checker? + -- public export + -- Show a => Show (BinTree' a) where + -- show = assert_total (show . toBinTreeSame) ||| Summing up nodes and leaves of the tree given by the Num a structure public export @@ -132,21 +129,16 @@ namespace BinTreeLeafInstances Eq a => Eq (BinTreeLeaf' a) where t == t' = assert_total (toBinTreeLeaf t == toBinTreeLeaf t') - ||| Is there a different way to convince Idris' totality checker? - public export - Show a => Show (BinTreeLeaf' a) where - show = assert_total (show . toBinTreeLeaf) + -- ||| Is there a different way to convince Idris' totality checker? + -- public export + -- Show a => Show (BinTreeLeaf' a) where + -- show = assert_total (show . toBinTreeLeaf) ||| Summing up leaves of the tree given by the Num a structure public export Num a => Algebra BinTreeLeaf' a where reduce = reduce {f=BinTreeLeaf} . toBinTreeLeaf - ||| Requires making a choice on which subtree to process first - public export - Foldable BinTreeLeaf' where - foldr f z t = foldr {t=BinTreeLeaf} f z (toBinTreeLeaf t) - namespace BinTreeNodeInstances ||| Is there a different way to convince Idris' totality checker? @@ -154,13 +146,12 @@ namespace BinTreeNodeInstances Eq a => Eq (BinTreeNode' a) where t == t' = assert_total (toBinTreeNode t == toBinTreeNode t') - ||| Is there a different way to convince Idris' totality checker? - public export - Show a => Show (BinTreeNode' a) where - show = assert_total (show . toBinTreeNode) + -- ||| Is there a different way to convince Idris' totality checker? + -- public export + -- Show a => Show (BinTreeNode' a) where + -- show = assert_total (show . toBinTreeNode) ||| Summing up nodes of the tree given by the Num a structure public export Num a => Algebra BinTreeNode' a where - reduce = reduce {f=BinTreeNode} . toBinTreeNode - + reduce = reduce {f=BinTreeNode} . toBinTreeNode \ No newline at end of file diff --git a/src/Data/Container/Base/Instances.idr b/src/Data/Container/Base/Instances.idr index 39cc971..dabda99 100644 --- a/src/Data/Container/Base/Instances.idr +++ b/src/Data/Container/Base/Instances.idr @@ -4,7 +4,7 @@ module Data.Container.Base.Instances import public Data.Container.Base.Object.Instances import public Data.Container.Base.Extension.Instances import public Data.Container.Base.Morphism.Instances -import public Data.Container.Base.Concrete.Instances +import public Data.Container.Base.Properties.Instances import public Data.Container.Base.InstanceInterfaces import public Data.Container.Base.Product.Interfaces diff --git a/src/Data/Container/Base/Morphism/Definition.idr b/src/Data/Container/Base/Morphism/Definition.idr index 5b43872..fc36da7 100644 --- a/src/Data/Container/Base/Morphism/Definition.idr +++ b/src/Data/Container/Base/Morphism/Definition.idr @@ -138,6 +138,7 @@ namespace Cartesian (:&) : c1 =:> c2 -> c1 =&> c2 (:&) (!: f) = !& \x => let (y ** ky) = f x in (y ** forward ky) + ||| Similar to the extension of a container. Following some ideas in ||| Diegetic open games (https://arxiv.org/abs/2206.12338) ||| Is this recovered via container composition when `r` is a some container? diff --git a/src/Data/Container/Base/Morphism/Instances.idr b/src/Data/Container/Base/Morphism/Instances.idr index f026afd..9435c2c 100644 --- a/src/Data/Container/Base/Morphism/Instances.idr +++ b/src/Data/Container/Base/Morphism/Instances.idr @@ -3,12 +3,15 @@ module Data.Container.Base.Morphism.Instances import Data.Fin import Data.Fin.Split import Data.Vect +import Data.List.Elem import Data.List.Quantifiers import Data.Container.Base.Object.Definition import Data.Container.Base.Morphism.Definition import Data.Container.Base.Extension.Definition +import Data.Container.Base.Properties.Definitions import Data.Container.Base.Product.Definitions + import Data.Container.Base.Object.Instances import Data.Container.Base.Quantifiers @@ -21,46 +24,46 @@ import Data.Num import Data.Layout import Misc -||| "State" as defined in https://arxiv.org/abs/2403.13001 and open games -||| Given a shape of any container, state can be defined -public export -State : Cont -> Type -State c = Scalar =%> c - -public export -Costate : Cont -> Type -Costate c = c =%> Scalar +namespace State + ||| "State" as defined in https://arxiv.org/abs/2403.13001 and open games + ||| Given a shape of any container, state can be defined + public export + State : Cont -> Type + State c = Scalar =%> c -public export -toState : {0 c : Cont} -> (x : c.Shp) -> State c -toState x = !% \() => (x ** \_ => ()) + public export + toState : {0 c : Cont} -> (x : c.Shp) -> State c + toState x = !% \() => (x ** \_ => ()) -public export -fromState : {0 c : Cont} -> - State c -> - c.Shp -fromState f = f.fwd () + public export + fromState : {0 c : Cont} -> + State c -> + c.Shp + fromState f = f.fwd () -public export -fromCostate : {0 c : Cont} -> - Costate c -> - (x : c.Shp) -> c.Pos x -fromCostate f x = f.bwd x () +namespace Costate + public export + Costate : Cont -> Type + Costate c = c =%> Scalar -public export -toCostate : {0 c : Cont} -> - ((x : c.Shp) -> c.Pos x) -> - Costate c -toCostate s = !% \x => (() ** \() => s x) + public export + fromCostate : {0 c : Cont} -> + Costate c -> + (x : c.Shp) -> c.Pos x + fromCostate f x = f.bwd x () + + public export + toCostate : {0 c : Cont} -> + ((x : c.Shp) -> c.Pos x) -> + Costate c + toCostate s = !% \x => (() ** \() => s x) public export -fromNapCostateToState : {0 c : Cont} -> - Costate (Nap c.Shp) -> State c +fromNapCostateToState : Costate (Nap c.Shp) -> State c fromNapCostateToState f = toState (f.bwd () ()) public export -fromStateToNapCostate : {0 c : Cont} -> - State c -> Costate (Nap c.Shp) +fromStateToNapCostate : State c -> Costate (Nap c.Shp) fromStateToNapCostate f = toCostate f.fwd public export @@ -68,115 +71,157 @@ pushDown : Cont -> Cont pushDown c = Const2 Unit c.Shp public export -pushIntoContinuation : {d, p, l : Cont} -> +pushIntoContinuation : {0 d, p, l : Cont} -> (d >< p =%> l) -> (p =%> (pushDown d) >@ l) pushIntoContinuation f = !% \p => (() <| \d => f.fwd (d, p) ** \(d ** l') => snd $ f.bwd (d, p) l') +namespace CategoricalProduct + public export + terminal : c =%> UnitCont + terminal = !% \_ => (() ** absurd) + + namespace HancockTensorProduct public export - leftUnit : (Scalar >< c) =%> c + leftUnit : Scalar >< c =%> c leftUnit = !% \((), s) => (s ** \p => ((), p)) public export - rightUnit : (c >< Scalar) =%> c + rightUnit : c >< Scalar =%> c rightUnit = !% \(x, ()) => (x ** \x' => (x', ())) public export - leftUnitInv : c =%> (Scalar >< c) + leftUnitInv : c =%> Scalar >< c leftUnitInv = !% \x => (((), x) ** \((), x') => x') public export - rightUnitInv : c =%> (c >< Scalar) + rightUnitInv : c =%> c >< Scalar rightUnitInv = !% \x => ((x, ()) ** \(x', ()) => x') public export - assocL : ((a >< b) >< c) =%> (a >< (b >< c)) + assocL : (a >< b) >< c =%> a >< (b >< c) assocL = !% \((a, b), c) => ((a, (b, c)) ** \(a', (b', c')) => ((a', b'), c')) public export - assocR : (a >< (b >< c)) =%> ((a >< b) >< c) + assocR : a >< (b >< c) =%> (a >< b) >< c assocR = !% \(a, (b, c)) => (((a, b), c) ** \((a', b'), c') => (a', (b', c'))) public export - swap : (a >< b) =%> (b >< a) + swap : a >< b =%> b >< a swap = !% \(a, b) => ((b, a) ** \(b', a') => (a', b')) namespace CompositionProduct public export - leftUnit : (Scalar >@ c) =%> c + leftUnit : Scalar >@ c =%> c leftUnit = !% \(() <| cShp) => (cShp () ** \c' => (() ** c')) public export - rightUnit : (c >@ Scalar) =%> c + rightUnit : c >@ Scalar =%> c rightUnit = !% \(s <| _) => (s ** \cp => (cp ** ())) public export - leftUnitInv : c =%> (Scalar >@ c) + leftUnitInv : c =%> Scalar >@ c leftUnitInv = !% \x => (() <| (\_ => x) ** \(() ** c') => c') public export - rightUnitInv : c =%> (c >@ Scalar) + rightUnitInv : c =%> c >@ Scalar rightUnitInv = !% \s => (s <| const () ** fst) namespace Coproduct public export - elim : {c : Cont} -> - (c >+< c) =%> c + elim : c >+< c =%> c elim = !% \case - (Left x) => (x ** id) - (Right y) => (y ** id) + Left x => (x ** id) + Right y => (y ** id) + + public export + initial : Empty =%> c + initial = !% absurd -||| Interaction between composition and tensor product -public export -duoidal : ((c >@ d) >< (e >@ f)) =%> ((c >< e) >@ (d >< f)) -duoidal = !% \((sc <| idxC), (se <| idxE)) => - ((sc, se) <| \(cp, ep) => (idxC cp, idxE ep) ** - \((cp, ep) ** (dp, fp)) => ((cp ** dp), (ep ** fp))) -||| Specific distributive law we need -public export -distribute : ((c >< e) =%> s) -> - ((c >< (e >@ g)) =%> (s >@ g)) -distribute f = (rightUnitInv >< id {a=e >@ g}) - %>> duoidal {d = Scalar} - %>> (f >@ leftUnit) - -||| Ext is a functor of type Cont -> [Type, Type] -||| On objects it maps a container to a polynomial functor -||| On morphisms it maps a dependent lens to a natural transformation -||| This is the action on morphisms -public export -extMap : {0 c, d : Cont} -> - c =%> d -> - Ext c a -> Ext d a -extMap f (sh <| index) = let (y ** ky) = (%!) f sh - in y <| (index . ky) +namespace CartesianClosure + ||| The following is the proof that for any container `c` there is an + ||| isomorphism in `Cont` between `c` and `CartesianClosure UnitCont c` + ||| This holds in any monoidal closed category: `X ≅ [I, X]` + namespace StateIsomorphismProof + stateToCartClosureFw : c =%> (CartesianClosure UnitCont c) + stateToCartClosureFw = !% \cShp => (!% \() => (cShp ** \_ => Nothing) + ** \(() ** cPos ** ItIsNothing) => cPos) + + stateToCartClosureBw : CartesianClosure UnitCont c =%> c + stateToCartClosureBw = !% \l => (l.fwd () ** \cPos => + (() ** cPos ** maybeVoidIsNothing (l.bwd () cPos))) + + +||| For a overview of this interaction from the categorical perspective, see +||| the Poly book (https://arxiv.org/abs/2312.00990) (Section 6.3.4) +namespace CompositionTensorInteraction + ||| Interaction between composition and tensor product + ||| Swaps the operations, and middle two containers + ||| Not an isomorphism! + public export + duoidal : (c >@ d) >< (e >@ f) =%> (c >< e) >@ (d >< f) + duoidal = !% \((sc <| idxC), (se <| idxE)) => + ((sc, se) <| \(cp, ep) => (idxC cp, idxE ep) ** + \((cp, ep) ** (dp, fp)) => ((cp ** dp), (ep ** fp))) + + ||| Tensor product embeds into composition + ||| A special case of `duoidal` + public export + tensorToComp : c >< f =%> c >@ f + tensorToComp = (rightUnitInv >< leftUnitInv) + %>> duoidal {d=Scalar,e=Scalar} + %>> (rightUnit >@ leftUnit) + + ||| Going the other way is impossible without any constraints + ||| Two possibilities on constraints (this, and `compToTensor2`) + public export + compToTensor : IsNaperian d => + (c >@ d) =%> (c >< d) + compToTensor @{(MkIsNaperian dPos)} = !% \(cShp <| content) => + ((cShp,()) ** \(cPos, dPos) => (cPos ** dPos)) + + public export + compToTensor2 : IsFlat c => + (c >@ d) =%> (c >< d) + compToTensor2 @{(ItIsFlat cShp)} = !% \(cShp <| dShp) => + ((cShp, dShp ()) ** \((), dPos') => (() ** dPos')) + + ||| Specific distributive law we need + public export + distribute : (c >< e) =%> s -> + c >< (e >@ g) =%> s >@ g + distribute f = (rightUnitInv >< id {a=e >@ g}) + %>> duoidal {d = Scalar} + %>> (f >@ leftUnit) ||| Wraps a dependent lens `c =%> d` ||| into one of type `c >@ Scalar =%> d >@ Scalar` ||| Needed because `c >@ Scalar` isn't automatically reduced to `c` public export -wrapIntoVector : {c, d : Cont} -> - c =%> d -> +wrapIntoVector : c =%> d -> Tensor [c] =%> Tensor [d] -wrapIntoVector (!% f) = - !% \e => let (y ** ky) = f (shapeExt e) - in (y <| \_ => () ** \(cp ** ()) => (ky cp ** ())) +wrapIntoVector f = rightUnit %>> f %>> rightUnitInv + +public export +wrapIntoMatrix : (c >@ c') =%> (d >@ d') -> + Tensor [c, c'] =%> Tensor [d, d'] +wrapIntoMatrix f = (id >@ rightUnit) + %>> f + %>> (id >@ rightUnitInv) ||| Wraps a dependent lens `c =%> d` ||| into one of type `c >< Scalar =%> d >< Scalar` ||| Needed because `c >< Scalar` isn't automatically reduced to `c` public export -wrapIntoVectorHancock : {c, d : Cont} -> - c =%> d -> +wrapIntoVectorHancock : c =%> d -> HancockTensor [c] =%> HancockTensor [d] -wrapIntoVectorHancock f = !% \(x, ()) => - ((f.fwd x, ()) ** \(y', ()) => (f.bwd x y', ())) +wrapIntoVectorHancock f = rightUnit %>> f %>> rightUnitInv namespace CubicalHelpers ||| Helper function allowing `shape` in `cubicalShape` to have zero annotation @@ -246,18 +291,31 @@ reshape lo = flattenCubical lo namespace Transpose public export - transposeLens : IsNaperian c => IsNaperian d => (c >@ d) =%> (d >@ c) + transposeLens : IsNaperian c => IsNaperian d => c >@ d =%> d >@ c transposeLens @{MkIsNaperian _} @{MkIsNaperian _} = !% \(() <| _) => (() <| (\_ => ()) ** \(dInd ** cInd) => (cInd ** dInd)) - ||| This and the above function should be one and the same, up to rebracketing public export transpose : IsNaperian c => IsNaperian d => Tensor [c, d] =%> Tensor [d, c] - transpose @{MkIsNaperian _} @{MkIsNaperian _} = !% \(() <| _) => - (() <| (\_ => () <| (\_ => ())) ** \(dInd ** cInd ** ()) => - (cInd ** (dInd ** ()))) + transpose @{MkIsNaperian _} @{MkIsNaperian _} = wrapIntoMatrix transposeLens + + -- ||| experiment, does this work? + -- public export + -- transposeMiddle : IsNaperian c => IsNaperian e => + -- Tensor [c, e, d] =%> + + --||| Transpose a given element to the front of the shape + --public export + --transposeToFront : (shape : List Cont) -> + -- (c : Cont) -> + -- (elem : Elem c shape) => + -- All IsNaperian (dropAfterElem shape elem) => + -- Tensor shape =%> Tensor (c :: dropElem shape elem) + --transposeToFront (_ :: xs) c @{Here} @{allNap} = ?transposeToFront_rhs_0 + --transposeToFront (y :: xs) c @{(There x)} @{allNap} = ?transposeToFront_rhs_1 + ||| Functionality for transforming a tensor into a hancock tensor namespace TransformIntoHancockTensor public export @@ -303,10 +361,15 @@ namespace TransformIntoHancockTensor let (_ ** recBack) = (%!) transformToHancock (content p) in (p ** recBack $ replace {p = id} hancockTensorPosEq restPos)) - -public export -EmptyExtEq : {0 c : Cont} -> IsNaperian c => Ext c Unit = Unit -EmptyExtEq @{(MkIsNaperian pos)} = believe_me () -- what does wrong if we do this + public export + transformFromHancock : {shape : List Cont} -> + All IsNaperian shape => + HancockTensor shape =%> Tensor shape + transformFromHancock {shape = []} = id + transformFromHancock {shape = (Nap s :: ss)} @{((MkIsNaperian s) :: _)} + = !% \((), hShp) => + let (tShp ** recBack) = (%!) transformFromHancock hShp + in (() <| (\_ => tShp) ** \(p ** restPos) => (p, recBack restPos)) @@ -401,6 +464,20 @@ namespace BinTreeNode -- _ | Right FZ = ?whn -- _ | Right (FS g) = ?whr +namespace BinTreeLeaf + public export + inorderBackward : (b : BinTreeShape) -> + Fin (numLeaves b) -> + BinTreePosLeaf b + inorderBackward LeafS 0 = AtLeaf + inorderBackward (NodeS lt rt) i with (strengthenN {m=numLeaves lt} i) + _ | (Left indLeft) = GoLeft (inorderBackward lt indLeft) + _ | (Right indRight) = GoRight (inorderBackward rt indRight) + + public export + inorder : BinTreeLeaf =%> List + inorder = !% \b => (numLeaves b ** inorderBackward b) + -- public export -- traverseLeaf : (x : BinTreeShape) -> FinBinTreeLeaf x -> Fin (numLeaves x) -- traverseLeaf LeafS Done = FZ @@ -408,6 +485,10 @@ namespace BinTreeNode -- traverseLeaf (NodeS lt rt) (GoRight x) = shift (numLeaves lt) (traverseLeaf rt x) -- +public export +vectToList : {n : Nat} -> Vect n =%> List +vectToList = !% \() => (n ** id) + public export maybeToList : Maybe =%> List maybeToList = !% \b => case b of diff --git a/src/Data/Container/Base/Object/Definition.idr b/src/Data/Container/Base/Object/Definition.idr index 085f594..7b48d47 100644 --- a/src/Data/Container/Base/Object/Definition.idr +++ b/src/Data/Container/Base/Object/Definition.idr @@ -1,8 +1,5 @@ module Data.Container.Base.Object.Definition -import Data.Fin -import Data.Finite - ||| Containers capture the idea that datatypes consist of groups of memory ||| locations where data can be stored. Locations for a particular group are ||| referred to as 'positions' and a particular group is referred to as a @@ -17,90 +14,4 @@ record Cont where export typebind infixr 0 !> -%name Cont c, c', c'' - -||| Constant container, one where positions do not depend on shapes -public export -Const2 : Type -> Type -> Cont -Const2 a b = (_ : a) !> b - -||| Constant container, one where positions do not depend on shapes -public export -Const : Type -> Cont -Const a = Const2 a a - -||| Naperian container: a constant container with a single shape -public export -Nap : Type -> Cont -Nap b = Const2 Unit b - -||| Convenience datatype for storing the data that a container `c` has an -||| interface `i` on its positions -||| TODO does the argument of MkI need to be auto implicit? -public export -data InterfaceOnPositions : (c : Cont) -> (i : Type -> Type) -> Type where - ||| For every shape `s` the set of positions `c.Pos s` has that interface - MkI : (p : (s : c.Shp) -> i (c.Pos s)) => - InterfaceOnPositions c i - -||| A container is finite when for every shape the set of positions is finite. -||| Examples: vectors, lists, but also finite binary trees. -||| Note, provision of a finite instance for trees requires a choice of a tree -||| traversal. (All of these choices isomorphic, but are necessary to make) -public export -IsFinite : Cont -> Type -IsFinite c = InterfaceOnPositions c Finite - -||| A container is Naperian when the set of shapes is `Unit`, i.e. when it -||| contains only one set of positions. -||| Examples: Scalar, UnitCont, Pair, Vect n, Stream. -||| Notably, Naperian does not imply Finite, as the `Stream` example shows. -public export -data IsNaperian : Cont -> Type where - MkIsNaperian : (pos : Type) -> IsNaperian (Nap pos) - -public export -LogHelper : IsNaperian c => Type -LogHelper @{MkIsNaperian pos} = pos - -public export -Log : (0 c : Cont) -> IsNaperian c => Type -Log _ @{MkIsNaperian pos} = pos - --- ||| If we have a Naperian container, we can always produce (the unique) --- ||| inhabitant of it shape --- public export --- naperianShape : IsNaperian c => c.Shp --- naperianShape @{(MkIsNaperian pos)} = () --- -public export -naperianPosEq : IsNaperian c => {0 x, y : c.Shp} -> c.Pos x = c.Pos y -naperianPosEq @{MkIsNaperian _} = Refl - -||| A container is cubical whenever it is Finite and Naperian -||| Effectively, captures `Vect n` containers, up to isomorphism -||| Examples: for any `n : Nat`, `Vect n`. Those are all the examples, up to -||| isomorphism. Notably, this also includes a container whose unique set of -||| positions is the set of positions of a binary tree of a particular shape. -||| This is isomorphic to the `Vect k` container, for some `k`, assuming a -||| choice of tree traversal (though all of them yield the same `k`). Here -||| `k` corresponds to the number of positions in that binary tree -public export -data IsCubical : Cont -> Type where - MkIsCubical : (n : Nat) -> IsCubical (Nap (Fin n)) - -public export -dimHelper : IsCubical c -> Nat -dimHelper (MkIsCubical n) = n - -||| We call dimension the size of the set of positions of a finite container -public export -dim : (0 c : Cont) -> IsCubical c => Nat -dim _ @{ic} = dimHelper ic - - -||| Used in learning, where we want to know that the tangent space over a -||| particular parameter is equal to the parameter space itself -public export -data IsFlat : Cont -> Type where - MkIsFlat : (p : Type) -> IsFlat ((_ : p) !> p) +%name Cont c, c', c'' \ No newline at end of file diff --git a/src/Data/Container/Base/Object/Instances.idr b/src/Data/Container/Base/Object/Instances.idr index 6120627..61d1785 100644 --- a/src/Data/Container/Base/Object/Instances.idr +++ b/src/Data/Container/Base/Object/Instances.idr @@ -8,6 +8,42 @@ import Data.Container.Base.Product.Definitions import Data.Container.Base.TreeUtils import Control.Monad.Distribution + +{------------------------------------------------------------------------------- +This file defines a number of different containers +Some of them are possible to express in terms of each other, but we opt to define all of them directly +-------------------------------------------------------------------------------} + +||| Constant (non-dependent) container: positions do not depend on shapes +||| As a polynomial functor: F(X) = aX^b +public export +Const2 : Type -> Type -> Cont +Const2 a b = (_ : a) !> b + +||| Constant container whose shapes and positions coincide +||| As a polynomial functor: F(X) = aX^a +public export +Const : Type -> Cont +Const a = Const2 a a + +||| Naperian container: a constant container with a single shape +||| As a polynomial functor: F(X) = X^b +public export +Nap : Type -> Cont +Nap b = Const2 Unit b + +||| Flat container: a constant container with a single position +||| As a polynomial functor: F(X) = aX +public export +Flat : Type -> Cont +Flat a = Const2 a Unit + +||| Sharp container: a constant container without any positions +||| As a polynomial functor: F(X) = a +public export +Sharp : Type -> Cont +Sharp a = Const2 a Void + ||| Empty container, isomorphic to Void ||| As a polynomial functor: F(X) = 0 ||| Initial container @@ -70,6 +106,12 @@ public export Vect : List .Shp -> Cont Vect n = (_ : Unit) !> Fin n +||| Grid, container of things arranged along two axes +||| As a polynomial functor: F(X) = X^(hw) +public export +Grid : (List .Shp, List .Shp) -> Cont +Grid (h, w) = (Vect h) >< (Vect w) + ||| Container of an infinite number of things ||| As a polynomial functor: F(X) = X^Nat public export @@ -112,8 +154,14 @@ public export CoproductTensor : List Cont -> Cont CoproductTensor = foldr (>+<) Empty +||| Ignoring universe levels here +||| This should be the analogue of `Type : Type` +public export +ContUniverse : Cont +ContUniverse = (_ : (s : Type ** s -> Type)) !> Void -||| Can't believe this works? +||| Given a natural number `n`, this is a container whose shape represents a +||| distribution over `n` choices, and its position represents the choice made. public export Sample : Nat -> Cont Sample n = Const2 (Dist n) (Fin n) diff --git a/src/Data/Container/Base/Product/Definitions.idr b/src/Data/Container/Base/Product/Definitions.idr index 24f4eca..408e2e8 100644 --- a/src/Data/Container/Base/Product/Definitions.idr +++ b/src/Data/Container/Base/Product/Definitions.idr @@ -2,6 +2,7 @@ module Data.Container.Base.Product.Definitions import Data.DPair import Decidable.Equality +import Data.Either import Data.Vect import Data.List.Quantifiers import Data.Vect.Quantifiers @@ -9,6 +10,9 @@ import Data.Vect.Quantifiers import Data.Container.Base.Object.Definition import Data.Container.Base.Morphism.Definition import Data.Container.Base.Extension.Definition +import Data.Container.Base.Properties.Definitions + + import Data.Container.Base.Quantifiers import Control.Monad.Distribution @@ -24,6 +28,8 @@ public export infixr 3 <%> ||| Categorical product of containers ||| Monoid with UnitCont +||| It holds that `Ext (c1 >*< c2) a = (Ext c1) × (Ext c2)` where +||| `×` is the pointwise product of functors. namespace CategoricalProduct ||| Binary version of product public export @@ -42,10 +48,22 @@ namespace CategoricalProduct AllAny : Vect n Cont -> Cont AllAny xs = (shapes : All Shp xs) !> AnyPos shapes + ||| "Dependent categorical product": + ||| Dependent pair type for the categorical product of containers + ||| Given a container `s` and a family `p : s.Shp -> Cont`, + ||| form the container whose shapes are dependent pairs of shapes + ||| and a position is either a position of s or a position of p. + public export + DPairCart : (s : Cont) -> (p : s.Shp -> Cont) -> Cont + DPairCart s p = ((sShp ** pShp) : DPair s.Shp (Shp . p)) + !> Either (s.Pos sShp) ((p sShp).Pos pShp) + ||| Non-categorical product of containers, often also called ||| 'Hancock' (Scotland), 'Dirichlet' (Spivak), or 'Tensor product' (various) ||| Monoid with CUnit +||| It holds that `Ext (c1 >< c2) a = (Ext c1) ⊗ (Ext c2)` where +||| `⊗` is the day convolution product of functors. namespace HancockTensorProduct public export (><) : Cont -> Cont -> Cont @@ -70,19 +88,22 @@ namespace HancockTensorProduct (><) f g = !% \(c, d) => ((f.fwd c, g.fwd d) ** \(c', d') => (f.bwd c c', g.bwd d d')) - ||| Dependent Hancock (tensor) product of containers. - ||| This is the analogue of DPair for containers: - ||| Given a container `pc` and a family `qc : pc.Shp -> Cont`, + ||| "Dependent tensor product": + ||| Dependent pair type for the tensor product of containers + ||| Given a container `s` and a family `p : s.Shp -> Cont`, ||| form the container whose shapes are dependent pairs of shapes ||| and positions are pairs of positions. public export - DepHancockProduct : (pc : Cont) -> (qc : pc.Shp -> Cont) -> Cont - DepHancockProduct pc qc = - ((p ** q) : DPair pc.Shp (Shp . qc)) !> (pc.Pos p, (qc p).Pos q) - + DPairTensor : (s : Cont) -> (p : s.Shp -> Cont) -> Cont + DPairTensor s p = + ((sShp ** pShp) : DPair s.Shp (Shp . p)) !> (s.Pos sShp, (p sShp).Pos pShp) + +||| Coproduct of containers +||| Monoid with Empty +||| It holds that `Ext (c1 >+< c2) a = (Ext c1) + (Ext c2)` where +||| `+` is the pointwise product of functors. namespace CategoricalCoproduct - ||| Coproduct of containers - ||| Monoid with Empty + ||| Binary version of coproduct public export (>+<) : Cont -> Cont -> Cont c1 >+< c2 = (es : Either c1.Shp c2.Shp) !> either c1.Pos c2.Pos es @@ -106,13 +127,15 @@ namespace CompositionProduct public export (>@) : Cont -> Cont -> Cont c >@ d = (ex : Ext c d.Shp) !> - (cp : c.Pos (shapeExt ex) ** d.Pos (index ex cp)) + (DPair (c.Pos (shapeExt ex)) (d.Pos . index ex)) + -- (cp : c.Pos (shapeExt ex) ** d.Pos (index ex cp)) ||| Diagrammatic composition of containers, i.e. swapped order of composition public export (@>) : Cont -> Cont -> Cont c @> d = (ex : Ext d c.Shp) !> - (dp : d.Pos (shapeExt ex) ** c.Pos (index ex dp)) + (DPair (d.Pos (shapeExt ex)) (c.Pos . index ex)) + -- (dp : d.Pos (shapeExt ex) ** c.Pos (index ex dp)) namespace Morphism ||| Action on morphisms @@ -162,14 +185,55 @@ namespace MonoidalClosure uncurry f = !% \(x, y) => ((f.fwd x).fwd y ** \e' => (f.bwd x (y ** e'), (f.fwd x).bwd y e')) +||| If `f` is a monad, then `f -` is a comonad, and vice versa +public export +() : (f : Type -> Type) -> Cont -> Cont +() f c = (s : c.Shp) !> (f (c.Pos s)) + +public export infixr 9 + +namespace Morphism + public export + () : (f : Type -> Type) -> Functor f => + (c =%> d) -> + ((f c) =%> (f d)) + () f l = !% \x => (l.fwd x ** ((l.bwd x) <$>) ) + + public export infixr 9 + + ||| Closure with respect to the Cartesian product namespace CartesianClosure ||| From https://www.cs.ox.ac.uk/people/samuel.staton/papers/cie10.pdf public export CartesianClosure : Cont -> Cont -> Cont CartesianClosure c d - = (f : ((x : c.Shp) -> (y : d.Shp ** d.Pos y -> Maybe (c.Pos x)))) - !> (xx : c.Shp ** yy' : d.Pos (fst (f xx)) ** ?cartesianClosureImpl) + = (f : (Maybe c) =%> d) + !> (x : c.Shp ** y' : d.Pos (f.fwd x) ** IsNothing (f.bwd x y')) + + + public export + curry : c >*< d =%> e -> c =%> (CartesianClosure d e) + curry f = !% \x => (!% \y => (f.fwd (x, y) ** \z' => eitherToMaybe + (f.bwd (x, y) z')) ** bwPart) where + bwPart : {x : c.Shp} -> + (y : d.Shp ** z' : e.Pos (f.fwd (x, y)) ** IsNothing (eitherToMaybe (f.bwd (x, y) z'))) -> c.Pos x + bwPart (y ** z' ** isNothing) with (f.bwd (x, y) z') + bwPart (y ** z' ** ItIsNothing) | Left l = l + bwPart (y ** z' ** v) | Right r = absurd v + + public export + uncurry : c =%> (CartesianClosure d e) -> (c >*< d) =%> e + uncurry f = !% \(x, y) => ((f.fwd x).fwd y ** bwPart) where + bwPart : {x : c.Shp} -> {y : d.Shp} -> + e.Pos ((f.fwd x).fwd y) -> Either (c.Pos x) (d.Pos y) + bwPart z' with ((f.fwd x).bwd y z') proof p + bwPart z' | Nothing = Left $ f.bwd x (y ** z' ** rewrite p in ItIsNothing) + bwPart z' | Just r = Right r + + public export + apply : (CartesianClosure x y) >*< x =%> y + apply = uncurry {d=x} id -- Not exactly a product @@ -190,21 +254,6 @@ namespace Morphism -||| If `f` is a monad, then `f -` is a comonad, and vice versa -public export -() : (f : Type -> Type) -> Cont -> Cont -() f c = (s : c.Shp) !> (f (c.Pos s)) - -public export infixr 9 - -namespace Morphism - public export - () : (f : Type -> Type) -> Functor f => - (c =%> d) -> - ((f c) =%> (f d)) - () f l = !% \x => (l.fwd x ** ((l.bwd x) <$>) ) - - public export infixr 9 ||| BANG. List on positions, always has a monoid structure public export @@ -241,5 +290,5 @@ public export Deriv : (c : Cont) -> InterfaceOnPositions c DecEq => Cont -Deriv (shp !> pos) @{MkI} +Deriv (shp !> pos) @{MkI _} = ((s ** p) : DPair shp pos) !> (p' : pos s ** IsNo (decEq p p')) \ No newline at end of file diff --git a/src/Data/Container/Base/Product/InterfaceImplementations.idr b/src/Data/Container/Base/Product/InterfaceImplementations.idr index 5bbdd27..652830b 100644 --- a/src/Data/Container/Base/Product/InterfaceImplementations.idr +++ b/src/Data/Container/Base/Product/InterfaceImplementations.idr @@ -4,13 +4,13 @@ import Data.Container.Base.Object.Definition import Data.Container.Base.Extension.Definition import Data.Container.Base.Morphism.Definition import Data.Container.Base.Product.Definitions -import Data.Container.Base.Concrete.Definition +import Data.Container.Base.Properties.Definitions import Data.Container.Base.Object.Instances import Data.Container.Base.Extension.Instances import Data.Container.Base.Morphism.Instances import Data.Container.Base.Product.Interfaces -import Data.Container.Base.Concrete.Instances +import Data.Container.Base.Properties.Instances import Data.Container.Base.TreeUtils @@ -18,6 +18,7 @@ import Data.Fin.Split import Data.Layout import Data.Functor.Algebra +import Misc export TensorMonoid Maybe where @@ -26,8 +27,12 @@ TensorMonoid Maybe where True => ((), if b2 then bb else absurd bb) False => absurd bb) - --- TODO Either Applicative? +||| Chaining computation +||| Different from ordinary `Either` in that both variables are of the same type +export +TensorMonoid Either where + tensorN = toState True + tensorM = !% \(b1, b2) => (b1 && b2 ** \() => ((), ())) ||| Corresponds to the Applicative instance in `Prelude.Types` ||| It behaves like a cartesian product, but compared to `Prelude.Types` @@ -37,6 +42,10 @@ TensorMonoid List where tensorN = toState 1 tensorM = !% \(n, m) => (n * m ** splitFinProd DefaultLayoutOrder) +export +SeqMonoid List where + seqM = !% \(n <| contentM) => (sum contentM ** splitFinProdDep contentM) + {-- It is usually said that List has two applicative structures: one defined above, and another one defined by `zipList` (Section 3 of @@ -52,28 +61,51 @@ Applicative List where --} -||| Covers vectors, among others -||| For vecotrs produces a `zip` operation +||| Covers pairs, vectors, streams, grids, among others +||| For vectors produces a `zip` operation export IsNaperian c => TensorMonoid c where - tensorN @{(MkIsNaperian pos)} = toState () - tensorM @{(MkIsNaperian pos)} = !% \((), ()) => (() ** \i => (i, i)) + tensorN @{(MkIsNaperian _)} = toState () + tensorM @{(MkIsNaperian _)} = !% \((), ()) => (() ** \i => (i, i)) +||| When a container `c` is Naperian, then `c >< c` is isomorphic to `c >@ c` +||| Meaning this interface follows directly +||| Vectors also form a *graded* monad, which isn't implemented here export -IsCubical c => SeqMonoid c where - seqN @{MkIsCubical n} = toState () - seqM @{MkIsCubical n} = !% \(() <| _) => (() ** \i => (i ** i)) +IsNaperian c => SeqMonoid c where + seqM @{MkIsNaperian pos} = compToTensor {d=c} %>> tensorM + + +||| experiment, does this work? +diagonalAroundMiddle : IsNaperian c => + (f : c >@ c =%> c) -> + c >@ d >@ c =%> c >@ d + public export -diagonal : {c : Cont} -> - TensorMonoid c => - IsNaperian c => +join : SeqMonoid c => Tensor [c, c] =%> Tensor [c] -diagonal = transformToHancock {shape=[c, c]} - %>> (!% \(x, (y, ())) => let (z ** kz) = (%!) tensorM (x, y) - in (z ** \z' => let (x', y') = kz z' - in (x', y', ()))) - %>> !% \x => (x <| (\_ => ()) ** fst) +join = (id >@ rightUnit) + %>> seqM + %>> rightUnitInv + +public export +cojoin : SeqComonoid c => + Tensor [c] =%> Tensor [c, c] +cojoin = rightUnit + %>> seqComult + %>> (id >@ rightUnitInv) + +public export +diagonal : IsNaperian c => + Tensor [c, c] =%> Tensor [c] +diagonal = join + +public export +codiagonal : TensorMonoid c => + Tensor [c] =%> Tensor [c, c] +codiagonal = ?cojoinn + namespace BinTreeUtils public export diff --git a/src/Data/Container/Base/Product/Interfaces.idr b/src/Data/Container/Base/Product/Interfaces.idr index 63c6170..b585da6 100644 --- a/src/Data/Container/Base/Product/Interfaces.idr +++ b/src/Data/Container/Base/Product/Interfaces.idr @@ -4,31 +4,45 @@ import public Data.List.Quantifiers import Data.Container.Base.Object.Definition import Data.Container.Base.Morphism.Definition -import Data.Container.Base.Morphism.Instances import Data.Container.Base.Extension.Definition import Data.Container.Base.Product.Definitions import Data.Container.Base.Object.Instances -import Data.Container.Base.Product.Definitions +import Data.Container.Base.Morphism.Instances +||| Its extension is an applicative functor +||| All Naperian containers, BinTree, BinTreeLeaf, List, Maybe,... public export interface TensorMonoid (0 c : Cont) where tensorN : Scalar =%> c - tensorM : (c >< c) =%> c + tensorM : c >< c =%> c + +public export +interface TensorComonoid (0 c : Cont) where + tensorCounit : c =%> Scalar + tensorComult : c =%> c >< c + +||| Its extension is a monad +||| Just as Applicative => Monad, here TensorMonoid => SeqMonoid +public export +interface TensorMonoid c => SeqMonoid (0 c : Cont) where + seqM : c >@ c =%> c +||| These are directed containers, a.k.a. categories +||| Does this interface constraint follow analogously? public export -interface SeqMonoid (0 c : Cont) where - seqN : Scalar =%> c - seqM : (c >@ c) =%> c +interface TensorComonoid c => SeqComonoid (0 c : Cont) where + seqComult : c =%> c >@ c public export interface CoprodMonoid (0 c : Cont) where plusN : Empty =%> c - plusM : (c >+< c) =%> c + plusM : c >+< c =%> c +||| Its extension is an Alternative? public export interface ProdMonoid (0 c : Cont) where prodN : UnitCont =%> c - prodM : (c >*< c) =%> c + prodM : c >*< c =%> c public export pairExtensions : Ext c a -> Ext d b -> Ext (c >< d) (a, b) @@ -46,7 +60,7 @@ TensorMonoid c => Applicative (Ext c) where public export [FromSeq] SeqMonoid c => Applicative (Ext c) where - pure x = seqN.fwd () <| const x + pure x = tensorN.fwd () <| const x (f <| f') <*> (x <| x') = ?notAThing public export @@ -59,19 +73,8 @@ public export pure x = prodN.fwd () <| const x (<*>) = ?notAThing2 - public export ProdMonoid c => Alternative (Ext c) using FromProd where empty = let (p1 ** p2) = (%! prodN) () in p1 <| absurd . p2 (<|>) (a <| a') (b <| b') = let (m1 ** m2) = (%! prodM) (a, b) in m1 <| either a' b' . m2 - -||| The products `><` and `>@` coincide for Naperian containers -||| The lens below is one part of an isomorphism -napProductIdentity : {p, q : Type} -> - (Nap p >< Nap q) =%> (Nap p >@ Nap q) -napProductIdentity = !% \((), ()) => (() <| (\_ => ()) ** \(p ** q) => (p, q)) - - -compToTensor : {c, d : Cont} -> (c >@ d) =%> (c >< d) -compToTensor = !% \(cShp <| content) => ((cShp, ?dShp) ** ?compToTensor_rhs) diff --git a/src/Data/Container/Base/Properties/Definitions.idr b/src/Data/Container/Base/Properties/Definitions.idr new file mode 100644 index 0000000..49ce592 --- /dev/null +++ b/src/Data/Container/Base/Properties/Definitions.idr @@ -0,0 +1,151 @@ +module Data.Container.Base.Properties.Definitions + +import Data.Fin +import Data.Finite + +import Data.Container.Base.Object.Definition +import Data.Container.Base.Morphism.Definition +import Data.Container.Base.Extension.Definition + +import Misc + +{------------------------------------------------------------------------------- +States various properties a container can have +Some of these mirror aliases in `Object.Definitions`, they're purposefully +separated with imports, and don't refer to each other + +These are thought of as extensional declarations: we need not know anything +about concrete instances to define these? + +-------------------------------------------------------------------------------} + +||| Convenience datatype for storing the data that a container `c` has an +||| interface `i` on its positions +public export +data InterfaceOnPositions : (c : Cont) -> (i : Type -> Type) -> Type where + ||| For every shape `s` the set of positions `c.Pos s` has that interface + MkI : ((s : c.Shp) -> i (c.Pos s)) -> InterfaceOnPositions c i + +||| A container is finite when for every shape the set of positions is finite. +||| Examples: vectors, lists, but also finite binary trees. +||| Note, provision of a finite instance for trees requires a choice of a tree +||| traversal. (All of these choices isomorphic, but are necessary to make) +public export +IsFinite : Cont -> Type +IsFinite c = InterfaceOnPositions c Finite + +||| A container is non-dependent when positions do not depend on shapes +public export +data IsNonDep : Cont -> Type where + MkIsNonDep : (s, p : Type) -> IsNonDep ((_ : s) !> p) + +||| Used in learning, where we want to know that the tangent space over a +||| particular parameter is equal to the parameter space itself +public export +data IsConst : Cont -> Type where + ItIsConst : (p : Type) -> IsConst ((_ : p) !> p) + +||| Following the flat-sharp terminology +public export +data IsFlat : Cont -> Type where + ItIsFlat : (s : Type) -> IsFlat ((_ : s) !> Unit) + +public export +data IsSharp : Cont -> Type where + ItIsSharp : (s : Type) -> IsSharp ((_ : s) !> Void) + +namespace Naperian + ||| Local alias used solely to keep `MkIsNaperian`'s index out of raw + ||| record-constructor form. Without this, pattern matches like + ||| `(MkIsNaperian p :: as)` against `All IsNaperian shape` trip an + ||| Idris 2 coverage-checker incompleteness (see issue #3721). + ||| Definitionally equal to `Nap pos` from `Object.Instances`, + ||| but defined here to avoid an import cycle. + public export + NaperianCont : Type -> Cont + NaperianCont pos = (_ : Unit) !> pos + + ||| A container is Naperian when the set of shapes is `Unit`, i.e. when it + ||| contains only one set of positions. + ||| Examples: Scalar, UnitCont, Pair, Vect n, Stream. + ||| Notably, Naperian does not imply Finite, as the `Stream` example shows. + public export + data IsNaperian : Cont -> Type where + MkIsNaperian : (pos : Type) -> IsNaperian (NaperianCont pos) + + public export + LogHelper : IsNaperian c => Type + LogHelper @{MkIsNaperian pos} = pos + + public export + Log : (0 c : Cont) -> IsNaperian c => Type + Log _ @{MkIsNaperian pos} = pos + + public export + naperianPosEq : IsNaperian c => {0 x, y : c.Shp} -> c.Pos x = c.Pos y + naperianPosEq @{MkIsNaperian _} = Refl + +namespace Cubical + ||| Will be removed later, temp fix for now as otherwise the coverage + ||| checker complains + public export + CubicalCont : Nat -> Cont + CubicalCont n = (_ : Unit) !> Fin n + + ||| A container is cubical whenever it is Finite and Naperian + ||| Effectively, captures `Vect n` containers, up to isomorphism + ||| Examples: for any `n : Nat`, `Vect n`. Those are all the examples, up to + ||| isomorphism. Notably, this also includes a container whose unique set of + ||| positions is the set of positions of a binary tree of a particular shape. + ||| This is isomorphic to the `Vect k` container, for some `k`, assuming a + ||| choice of tree traversal (though all of them yield the same `k`). Here + ||| `k` corresponds to the number of positions in that binary tree + public export + data IsCubical : Cont -> Type where + MkIsCubical : (n : Nat) -> IsCubical (CubicalCont n) + + public export + dimHelper : IsCubical c -> Nat + dimHelper (MkIsCubical n) = n + + ||| We call dimension the size of the set of positions of a finite container + public export + dim : (0 c : Cont) -> IsCubical c => Nat + dim _ @{ic} = dimHelper ic + + ||| Every cubical container is `Nap (Fin n)` with `n = dim ic` (used for rewrites). + public export + isCubicalContEq : IsCubical d -> d = ((_ : Unit) !> (Fin (dim {c=d}))) + isCubicalContEq (MkIsCubical n) = Refl + + +namespace IsFoldable + ||| A container is foldable if `c ≃ List` + ||| That is, there ought to exist a dependent lens `c =%> List` and back + ||| Here we only encode one part of this + public export + interface IsFoldable (0 c : Cont) where + constructor MkIsFoldable + mapToList : c =%> ((n : Nat) !> Fin n) + + +namespace IsConcrete + ||| Many datatypes in the Idris standard library are already + ||| concrete representations of particular containers + public export + interface IsConcrete (0 c : Cont) where + constructor MkIsConcrete + concreteType : Type -> Type + concreteFunctor : Functor concreteType + fromConcreteTy : concreteType a -> Ext c a + toConcreteTy : Ext c a -> concreteType a + + public export prefix 0 >#, #> + + public export + (>#) : IsConcrete c => concreteType {c=c} a -> Ext c a + (>#) = fromConcreteTy + + public export + (#>) : IsConcrete c => Ext c a -> concreteType {c=c} a + (#>) = toConcreteTy \ No newline at end of file diff --git a/src/Data/Container/Base/Properties/Instances.idr b/src/Data/Container/Base/Properties/Instances.idr new file mode 100644 index 0000000..ca1e50c --- /dev/null +++ b/src/Data/Container/Base/Properties/Instances.idr @@ -0,0 +1,225 @@ +module Data.Container.Base.Properties.Instances + +import Data.Fin +import Data.Vect + +import Data.Container.Base.Object.Definition +import Data.Container.Base.Morphism.Definition +import Data.Container.Base.Extension.Definition +import Data.Container.Base.Properties.Definitions +import Data.Container.Base.Product.Definitions + +import Data.Container.Base.Object.Instances +import Data.Container.Base.Extension.Instances +import Data.Container.Base.Morphism.Instances + +import Data.Trees +import Data.Functor.Products +import Data.Container.Base.TreeUtils + +import Misc + +%hide Data.Vect.fromList + +public export +IsConcrete Scalar where + concreteType = id + concreteFunctor = MkFunctor id + fromConcreteTy = pure + toConcreteTy (() <| f) = f () + +public export +IsConcrete Maybe where + concreteType = Maybe + concreteFunctor = %search + + fromConcreteTy Nothing = False <| absurd + fromConcreteTy (Just x) = True <| \() => x + + toConcreteTy (False <| _) = Nothing + toConcreteTy (True <| f) = Just (f ()) + +public export +IsConcrete Pair where + concreteType = \a => Pair a a + concreteFunctor = MkFunctor $ \f, (x, y) => (f x, f y) + fromConcreteTy (x, y) = () <| \case False => x; True => y + toConcreteTy (() <| f) = (f False, f True) + +public export +(icc : IsConcrete c) => (icd : IsConcrete d) => IsConcrete (c >< d) where + concreteType = concreteType @{icc} >< concreteType @{icd} + concreteFunctor = ?concreteFunctorHancockProduct + fromConcreteTy = ?fromConcreteTyHancockProduct + toConcreteTy = ?toConcreteTyHancockProduct + +public export +(icc : IsConcrete c) => (icd : IsConcrete d) => IsConcrete (c >@ d) where + concreteType = concreteType @{icc} . concreteType @{icd} + concreteFunctor = MkFunctor $ \f => ?concreteFunctorCompositionProduct + fromConcreteTy = ?fromConcreteTyCompositionProduct + toConcreteTy = ?toConcreteTyCompositionProduct + + +||| For recursive types we need to extract out the conversion functions +namespace List + public export + fromList : List a -> List' a + fromList [] = (0 <| absurd) + fromList (x :: xs) = let (l <| c) = fromList xs + in (S l <| cons x c) + + public export + toList : List' a -> List a + toList (0 <| _) = [] + toList ((S k) <| ind) = head ind :: toList (k <| tail ind) + + public export + IsConcrete List where + concreteType = List + concreteFunctor = %search + fromConcreteTy = fromList + toConcreteTy = toList + +namespace Vect + public export + fromVect : Vect n a -> Vect' n a + fromVect v = () <| \i => index i v + + public export + toVect : {n : Nat} -> Vect' n a -> Vect n a + toVect (_ <| index) = Vect.Fin.tabulate index + + public export + {n : Nat} -> IsConcrete (Vect n) where + concreteType = Vect n + concreteFunctor = %search + fromConcreteTy = fromVect + toConcreteTy = toVect + +namespace BinTreeSame + public export + fromBinTreeSame : BinTreeSame a -> BinTree' a + fromBinTreeSame (Leaf x) = LeafS <| \_ => x + fromBinTreeSame (Node x lt rt) = + let (fblt, fbrt) = (fromBinTreeSame lt, fromBinTreeSame rt) + in NodeS (shapeExt fblt) (shapeExt fbrt) <| \case + AtNode => x + GoLeft posL => index fblt posL + GoRight posR => index fbrt posR + + public export + toBinTreeSame : BinTree' a -> BinTreeSame a + toBinTreeSame (LeafS <| index) = Leaf (index AtLeaf) + toBinTreeSame (NodeS lt rt <| index) = + Node (index AtNode) + (toBinTreeSame (lt <| index . GoLeft)) + (toBinTreeSame (rt <| index . GoRight)) + + public export + IsConcrete BinTree where + concreteType = BinTreeSame + concreteFunctor = %search + fromConcreteTy = fromBinTreeSame + toConcreteTy = toBinTreeSame + +namespace BinTreeNode + public export + fromTreeHelper : BinTreePosNode LeafS -> a + fromTreeHelper AtNode impossible + fromTreeHelper (GoLeft x) impossible + fromTreeHelper (GoRight x) impossible + + public export + fromBinTreeNode : BinTreeNode a -> BinTreeNode' a + fromBinTreeNode (Leaf ()) = LeafS <| fromTreeHelper + fromBinTreeNode (Node node leftTree rightTree) + = let (fblt, fbrt) = (fromBinTreeNode leftTree, fromBinTreeNode rightTree) + in (NodeS (shapeExt fblt) (shapeExt fbrt) <| \case + AtNode => node + GoLeft posL => index fblt posL + GoRight posR => index fbrt posR) + + public export + toBinTreeNode : BinTreeNode' a -> BinTreeNode a + toBinTreeNode (LeafS <| index) = Leaf () + toBinTreeNode (NodeS lt rt <| index) = + Node (index AtNode) + (toBinTreeNode (lt <| index . GoLeft)) + (toBinTreeNode (rt <| index . GoRight)) + + public export + IsConcrete BinTreeNode where + concreteType = BinTreeNode + concreteFunctor = %search + fromConcreteTy = fromBinTreeNode + toConcreteTy = toBinTreeNode + +namespace BinTreeLeaf + public export + fromBinTreeLeaf : BinTreeLeaf a -> BinTreeLeaf' a + fromBinTreeLeaf (Leaf leaf) = LeafS <| \_ => leaf + fromBinTreeLeaf (Node node lt rt) = + let (fblt, fbrt) = (fromBinTreeLeaf lt, fromBinTreeLeaf rt) + in NodeS (shapeExt fblt) (shapeExt fbrt) <| \case + GoLeft posL => index fblt posL + GoRight posR => index fbrt posR + + public export + toBinTreeLeaf : BinTreeLeaf' a -> BinTreeLeaf a + toBinTreeLeaf (LeafS <| content) = Leaf (content AtLeaf) + toBinTreeLeaf (NodeS l r <| content) = + Node' (toBinTreeLeaf (l <| content . GoLeft)) + (toBinTreeLeaf (r <| content . GoRight)) + + public export + IsConcrete BinTreeLeaf where + concreteType = BinTreeLeaf + concreteFunctor = %search + fromConcreteTy = fromBinTreeLeaf + toConcreteTy = toBinTreeLeaf + + +public export +foldList : (a -> b -> b) -> b -> List' a -> b +foldList f z (0 <| _) = z +foldList f z ((S k) <| content) + = f (head content) $ foldList f z (k <| tail content) + +public export +IsFoldable c => Foldable (Ext c) where + foldr @{(MkIsFoldable toL)} f z = foldList f z . extMap toL + +public export +IsFoldable List where + mapToList = id + +public export +{n : Nat} -> IsFoldable (Vect n) where + mapToList = vectToList + +||| Requires making a choice of traversal order +||| Is there a good reason to prefer a particular order? +public export +IsFoldable BinTreeLeaf where + mapToList = inorder + +public export +IsFoldable BinTreeNode where + mapToList = inorder + +public export +IsFoldable BinTree where + mapToList = inorder + +-- old +-- ||| Indexing an element of `xs` and then applying `f` to it is the same as +-- ||| mapping `f` over xs, and then indexing the result +-- public export +-- mapIndexPreserve : {0 f : a -> b} -> +-- (xs : List a) -> +-- (i : Fin (length (f <$> xs))) -> +-- f (index' xs (rewrite sym (lengthMap {f=f} xs) in i)) +-- = index' (f <$> xs) i +-- mapIndexPreserve (x :: xs) FZ = Refl +-- mapIndexPreserve (x :: xs) (FS j) = mapIndexPreserve xs j \ No newline at end of file diff --git a/src/Data/Container/Base/Quantifiers.idr b/src/Data/Container/Base/Quantifiers.idr index bf09641..b93cfe4 100644 --- a/src/Data/Container/Base/Quantifiers.idr +++ b/src/Data/Container/Base/Quantifiers.idr @@ -7,6 +7,7 @@ import Data.Vect.Quantifiers import Data.Container.Base.Object.Definition ||| Quantifiers for lists +||| The predicate for each container's shape is considered its positions ||| We can have All/Any on shapes, and All/Any on positions ||| We get 3 valid combinations, since AnyAll=AnyAny. ||| That overlap is called AnyShpPos below diff --git a/src/Data/Functor/Algebra.idr b/src/Data/Functor/Algebra.idr index 4bdf0fd..356837f 100644 --- a/src/Data/Functor/Algebra.idr +++ b/src/Data/Functor/Algebra.idr @@ -2,7 +2,7 @@ module Data.Functor.Algebra import Data.Vect -import Data.Tree +import Data.Trees import Misc {------------------------------------------------------------------------------- diff --git a/src/Data/Functor/Naperian.idr b/src/Data/Functor/Naperian.idr index 31fe9b3..b221f5f 100644 --- a/src/Data/Functor/Naperian.idr +++ b/src/Data/Functor/Naperian.idr @@ -4,6 +4,7 @@ import Data.Vect %hide Data.Vect.transpose +-- todo this isn't necessary anymore? -- Needed to define transposition, and diagonal elements {- Lists -> not Naperian! Their shape isn't uniform (they can be of different lengths) diff --git a/src/Data/Functor/Products.idr b/src/Data/Functor/Products.idr new file mode 100644 index 0000000..fd00d29 --- /dev/null +++ b/src/Data/Functor/Products.idr @@ -0,0 +1,18 @@ +module Data.Functor.Products + +public export infixr 3 >< -- Day convolution product +public export infixr 3 >*< -- Categorical product +public export infixr 3 >+< -- Categorical coproduct +-- Composition of functors is simply `.` + +public export +(>*<) : (Type -> Type) -> (Type -> Type) -> (Type -> Type) +(>*<) f g a = (f a, g a) + +public export +(>+<) : (Type -> Type) -> (Type -> Type) -> (Type -> Type) +(>+<) f g a = Either (f a) (g a) + +public export +(><) : (Type -> Type) -> (Type -> Type) -> (Type -> Type) +(><) f g a = ?dayConvolutionProduct diff --git a/src/Data/Layout.idr b/src/Data/Layout.idr index 1124f33..ddd2eea 100644 --- a/src/Data/Layout.idr +++ b/src/Data/Layout.idr @@ -4,6 +4,7 @@ import Data.Fin.Split import Data.List.Quantifiers import Language.Reflection import Derive.Prelude +import Misc %language ElabReflection @@ -64,4 +65,13 @@ indexFinProd : {m, n : Nat} -> Fin (m * n) indexFinProd RowMajor row col = indexProd row col indexFinProd ColumnMajor row col = - replace {p = Fin} (sym $ multCommutative m n) (indexProd {m=n} {n=m} col row) \ No newline at end of file + replace {p = Fin} (sym $ multCommutative m n) (indexProd {m=n} {n=m} col row) + +||| Like `splitFinProd`, but here the order is fixed for us by dependency +public export +splitFinProdDep : {n : Nat} -> (content : Fin n -> Nat) -> + Fin (sum content) -> (i : Fin n ** Fin (content i)) +splitFinProdDep {n = 0} content x = ?shouldBeImpossibleToReach +splitFinProdDep {n = (S k)} content x = case splitSum x of + Left y => (FZ ** y) + Right y => let (i ** j) = splitFinProdDep (content . FS) y in (FS i ** j) \ No newline at end of file diff --git a/src/Data/ScientificNotation.idr b/src/Data/ScientificNotation.idr new file mode 100644 index 0000000..c1db31b --- /dev/null +++ b/src/Data/ScientificNotation.idr @@ -0,0 +1,207 @@ +module Data.ScientificNotation + +import Data.List +import Data.Nat +import Data.String + +import Misc + +{------------------------------------------------------------------------------- +{------------------------------------------------------------------------------- +This file contains custom scientific formatting for numeric types. + +While Idris' `Show` for numeric primitives already exists, it: +* does not permit precise control over ranges the scientific notation is invoked +* is backend-dependent, meaning that Scheme formats differently than JS +* does not allow us to do a two-pass formatting such that global tensor + information dictates the render for a particular element. (I.e. if any number + is within scientific notation range, then all numbers get rendered in scientific notation) + +-------------------------------------------------------------------------------} +-------------------------------------------------------------------------------} + +||| Interface for displaying numeric types that dynamically switches between +||| standard and scientific notation based on magnitude +||| If `forceScientific` is `True`, then we render in scientific notation no +||| matter what the value is +public export +interface Num a => ScientificDisplay a where + -- ideally we'd use something this: {default False forceScientific : Bool} -> + showSci : a -> String + +||| Magnitude above which `Double` values switch to scientific notation. +public export +sciUpperM : Double +sciUpperM = 1.0e6 + +||| Magnitude below which `Double` values switch to scientific notation. +public export +sciLowerM : Double +sciLowerM = 1.0e-4 + +||| Default precision used by numeric primitives with scientific notation +||| This is maximum, trailing zeros are removed. +public export +defaultScientificPrecision : Nat +defaultScientificPrecision = 4 + +||| Symbol used to denote the exponent in scientific notation, i.e. `1.0e+03` +public export +sciSymbol : Char +sciSymbol = 'e' + +||| True when `d` is non-zero and outside the range +public export +needsScientific : Double -> Bool +needsScientific d = d /= 0.0 && (m < sciLowerM || m >= sciUpperM) + where m = abs d + +-------------------------------------------------------------------------------- +-- Formatting primitives +-------------------------------------------------------------------------------- + +||| Format an integer exponent as `e+XX` or `e-XX`. +||| `formatExp 5 == "e+05"` +||| `formatExp -123 == "e-123"` +public export +formatExp : Integer -> String +formatExp n = singleton sciSymbol ++ sign ++ applyWhen (length ds < 2) ("0" ++) ds + where sign : String + sign = if n < 0 then "-" else "+" + ds : String + ds = show (abs n) + + +||| Multiply a decimal number by `10^prec`, and round it to the nearest integer +||| `round 2 3.14159 == 314` +||| `round 3 3.14159 == 3142` +||| `round 3 3.14100 == 3141` +roundScaled : (prec : Nat) -> Double -> Integer +roundScaled prec d = cast (floor (abs d * pow 10.0 (cast prec) + 0.5)) + +||| Format a non-negative integer as a decimal string. +||| `prec` controls how many digits appear after the decimal point +||| Input is taken to already be multiplied by `10^prec` +||| Zero-padding is added on the left when there are not enough digits. +||| ``` +||| formatDigits 5 314159 == "3.14159" +||| formatDigits 3 314159 == "314.159" +||| formatDigits 2 5 == "0.05" +||| formatDigits 0 42 == "42" +||| ``` +public export +formatDigits : (prec : Nat) -> (digits : Integer) -> String +formatDigits 0 n = show n +formatDigits prec n = substr 0 nDig padded ++ "." ++ substr nDig prec padded + where len : Nat + len = length (show n) + padded : String + padded = applyWhen (len <= prec) + (pack (replicate (S prec `minus` len) '0') ++) + (show n) + nDig : Nat -- number of digits to the left of the decimal point + nDig = length padded `minus` prec + +||| Format a Double in standard notation with a fixed number of decimal places. +||| Rounds the last decimal to the nearest digit, and possibly pads with zeros, +||| if precision is greater than the number of digits after the decimal point. +||| ``` +||| showDoublePrecision 3 3.14159 == "3.142" +||| showDoublePrecision 3 3.14100 == "3.141" +||| showDoublePrecision 0 3.14159 == "3" +||| showDoublePrecision 4 100.0 == "100.0000" +||| ``` +public export +showDoublePrecision : (precision : Nat) -> Double -> String +showDoublePrecision prec d = applyWhen (d < 0) ("-" ++) $ + formatDigits prec (roundScaled prec d) + +||| Decompose `|d|` into `(mantissa, exponent)` with `1 ≤ mantissa < 10`, +||| correcting the "one-decade" drift that `floor . log10` can produce at +||| decade boundaries (e.g. `log10 0.999... = -1.4e-16`, not 0). +||| Pre: `d ≠ 0`. +decimalDecompose : Double -> (Double, Integer) +decimalDecompose d = + let m = abs d + e = cast (floor (log m / log 10.0)) + m0 = m / pow 10.0 (cast e) + in if m0 >= 10.0 then (m0 / 10.0, e + 1) + else if m0 < 1.0 then (m0 * 10.0, e - 1) + else (m0, e) + +||| Format a Double in scientific notation with a fixed number of +||| decimal places in the mantissa. +||| +||| ``` +||| showDoubleScientific 5 3.14159 == "3.14159e+00" +||| showDoubleScientific 5 (-0.005) == "-5.00000e-03" +||| showDoubleScientific 5 100.0 == "1.00000e+02" +||| showDoubleScientific 5 0.0000001 == "1.00000e-07" +||| ``` +public export +showDoubleScientific : (precision : Nat) -> Double -> String +showDoubleScientific prec 0.0 = formatDigits prec 0 ++ formatExp 0 +showDoubleScientific prec d = + applyWhen (d < 0) ("-" ++) $ formatDigits prec mantInt ++ formatExp expFinal + where + decomp : (Double, Integer) + decomp = decimalDecompose d + + -- Round mantissa to a scaled integer with `prec` digits of precision. + rounded : Integer + rounded = roundScaled prec (fst decomp) + + -- If rounding pushed the mantissa to ≥ 10, carry one decade. + overflow : Bool + overflow = rounded >= cast (pow 10.0 (cast (S prec))) + + mantInt : Integer + mantInt = applyWhen overflow (`div` 10) rounded + + expFinal : Integer + expFinal = applyWhen overflow (+ 1) (snd decomp) + +||| Remove trailing zeros after the decimal point, keeping at least one. +||| Handles scientific notation, i.e.`"1.0000e-07"` becomes `"1.0e-07"` +public export +trimTrailingZeros : String -> String +trimTrailingZeros s = case break (== sciSymbol) (unpack s) of + (mant, expPart) => case break (== '.') mant of + (_, []) => s + (whole, _ :: frac) => + let trimmed : String + trimmed = case reverse (dropWhile (== '0') (reverse frac)) of + [] => "0" + xs => pack xs + in pack whole ++ "." ++ trimmed ++ pack expPart + +-------------------------------------------------------------------------------- +-- ScientificDisplay instances +-------------------------------------------------------------------------------- + +||| Render a value using `showDoubleScientific` after casting to `Double`. +||| Used by integer-like instances when the magnitude warrants sci notation. +showAsScientific : Cast a Double => a -> String +showAsScientific n = trimTrailingZeros $ + showDoubleScientific defaultScientificPrecision (cast n) + +||| Format a Double for display: fixed notation for values in the range +||| `[scientificLowerMagnitude, scientificUpperMagnitude)`, scientific +||| notation outside that range, with redundant trailing zeros removed. +public export +ScientificDisplay Double where + showSci d = trimTrailingZeros $ case needsScientific d of + True => showDoubleScientific defaultScientificPrecision d + False => showDoublePrecision defaultScientificPrecision d + +public export +ScientificDisplay Integer where + showSci n = case cast (abs n) < sciUpperM of + True => show n -- builtin `Show` never uses scientific notation + False => showAsScientific n + +public export +ScientificDisplay Nat where + showSci n = case cast n < sciUpperM of + True => show n -- builtin `Show` never uses scientific notation + False => showAsScientific n diff --git a/src/Data/Tensor/Axis.idr b/src/Data/Tensor/Axis.idr deleted file mode 100644 index cc1b2ad..0000000 --- a/src/Data/Tensor/Axis.idr +++ /dev/null @@ -1,280 +0,0 @@ -module Data.Tensor.Axis - -import public Decidable.Equality -import Data.Vect.Elem -import Data.Vect.Quantifiers - -import Data.Container.Base -import Data.Unique.Vect -import Misc - -{------------------------------------------------------------------------------- -{------------------------------------------------------------------------------- - -~~~~~~~~~~~~~~~ -Design choices: -~~~~~~~~~~~~~~~ - -1) Persistent axis names. - -Instead of transient axis names (bound within a function using the tensor, erased with the completion of the said function), axis names persist with the lifetime of the tensor. - -2) Axis declarations persist globally, but are only checked for consistency at call sites. - -This means that axis names are checked for consistency at each call site, rather than at declaration sites. In a proper programming language we'd track names at declaration sites and raise errors if inconsistencies/duplicates are detected, here we opt for a more pragmatic approach. - -3) Duplicate axis names within a tensor are allowed, as long as they refer to the same container. - -Otherwise it would not be clear how to take the diagonal/trace of a matrix while referring only to the axes: they'd have to have different names. - -4) Does tensor contraction allow duplicate axis names? - - - -Does tensor contraction allow duplicate axis names - * in the input (yes, this is what Einsum also allows) - * in the output (no, because otherwise its not clear what should happen) - * this means that we can't write `einsum("i,i->ii")` -3) How does contraction work? - 3.1) Given `t : Tensor [BatchSize, BatchSize] Double`, what is `dotGeneral t`? - -Need to figure out how `reduce name t` acts when: -1) `name="BatchSize"` and `t : Tensor [BatchSize, BatchSize] Double` - - Should sum up the diagonal? -2) `name="BatchSize"` and `t : Tensor [BatchSize] Double` - - Should sum up the vector? -3) `name="BatchSize"` and `t : Tensor [BatchSize, SeqLen, BatchSize] Double` - - Should sum up the diagonal slices of SeqLen - -I suppose this is about iterators -iterating through - - - - ---- Consistency checking: ----------------- -We check consistency at each call site. -Alternatively if we were building a programming languge we'd check consistency with each declaration. That is, writing something like: -```idris -BatchSize1 : Axis -BatchSize1 = "batchSize" ~> Vect 128 - -BatchSize2 : Axis -BatchSize2 = "batchSize" ~> Vect 129 -``` -would throw an error on the line `BatchSize2 = ...` because we're redeclaring "batchSize" which already exists. - -------------------------------------------- - -Similar projects/ideas: -* XArray: https://docs.xarray.dev/en/stable/ (persistent axis names) -* Haliax: https://github.com/marin-community/haliax - --------------------------------------------------------------------------------} --------------------------------------------------------------------------------} - -||| The name for an axis is an arbitrary string -public export -AxisName : Type -AxisName = String - -public export infixr 0 ~> -- Constructor for container-based axes -public export infixr 0 ~~> -- 'Constructor' for cubical axes - -||| An axis is a container (the "size" of the axis) together with its name -public export -record Axis where - constructor (~>) - name : AxisName - cont : Cont - -public export -rename : Axis -> AxisName -> Axis -rename a str = str ~> a.cont - - -||| In some cases we TensorType might need to assign a default name to an axis, -||| one which is internal and will not be exposed to the user. -||| This is the default name for such cases -public export -TTInternalName : AxisName -TTInternalName = "__tensortype_tempaxis__" - - -namespace Cubical - ||| A "constructor" for cubical axes - public export - (~~>) : AxisName -> Nat -> Axis - (~~>) axisName n = axisName ~> Vect n - - ||| Follows the pattern of `IsCubical` from `Data.Container.Object.Instances` - public export - data IsCubical : Axis -> Type where - MkIsCubical : (name : AxisName) -> (n : Nat) -> IsCubical (name ~~> n) - - public export - dimHelper : {0 a : Axis} -> IsCubical a -> Nat - dimHelper (MkIsCubical _ n) = n - - public export - dim : (0 a : Axis) -> IsCubical a => Nat - dim _ @{ic} = dimHelper ic - - public export - data IsNaperian : Axis -> Type where - MkIsNaperian : (name : AxisName) -> (pos : Type) -> - IsNaperian (name ~> Nap pos) - - public export - LogHelper : {0 a : Axis} -> IsNaperian a => Type - LogHelper @{MkIsNaperian _ pos} = pos - - public export - Log : (0 a : Axis) -> IsNaperian a => Type - Log a @{inn} = LogHelper @{inn} - - public export - toContNaperian : {0 a : Axis} -> IsNaperian a -> IsNaperian a.cont - toContNaperian (MkIsNaperian name pos) = MkIsNaperian pos - - public export - cubicalShapeHelper : {0 shape : Vect r Axis} -> - All IsCubical shape -> List Nat - cubicalShapeHelper [] = [] - cubicalShapeHelper (ic :: ns) = dimHelper ic :: cubicalShapeHelper ns - - ||| Given a list of cubical axes, return the list of their dimensions - public export - cubicalShape : (0 shape : Vect r Axis) -> All IsCubical shape => List Nat - cubicalShape _ @{ac} = cubicalShapeHelper ac - - ||| Size of a cubical tensor, i.e. its number of elements - public export - size : (0 shape : Vect r Axis) -> (ac : All IsCubical shape) => Nat - size ss = prod (cubicalShape ss) - - -namespace TensorShape - mutual - public export - data TensorShape : (rank : Nat) ->Type where - Nil : TensorShape 0 - (::) : (a : Axis) -> (as : TensorShape k) -> - (ac : NewAxisConsistent a as) => - TensorShape (S k) - - public export - toVect : TensorShape k -> Vect k Axis - toVect [] = [] - toVect (a :: as) = a :: toVect as - - public export - data NewAxisConsistent : Axis -> TensorShape k -> Type where - NewAxis : {0 a : Axis} -> {0 as : TensorShape k} -> - NotElem a.name (Axis.name <$> toVect as) -> - NewAxisConsistent a as - ExistingAxis : {0 a : Axis} -> {0 as : TensorShape k} -> - (e : Elem a.name (Axis.name <$> toVect as)) -> - (index (elemToFin e) (toVect as)).cont = a.cont -> - NewAxisConsistent a as - - public export - toList : TensorShape k -> List Axis - toList [] = [] - toList (a :: as) = a :: toList as - - ||| Convenience function, turns it also into a list - ||| Because `Data.Container` uses lists with tensors - public export - conts : TensorShape k -> List Cont - conts ts = cont <$> toList ts - - ||| Names of the axes in a tensor shape - public export - axisNames : TensorShape k -> Vect k AxisName - axisNames ts = name <$> toVect ts - - ||| Sizes of the axes in a tensor shape - public export - axisSizes : TensorShape k -> Vect k Cont - axisSizes ts = cont <$> toVect ts - - ||| Size of a tensor shape, i.e. its number of elements - public export - size : (shape : TensorShape k) -> All IsCubical (conts shape) => Nat - size shape = size (conts shape) - - test1 : TensorShape 2 - test1 = ["batchSize" ~> Vect 128, "seqLen" ~> List] - - test2 : TensorShape 3 - test2 = ["batchSize" ~> Vect 128, "seqLen" ~> List, "batchSize" ~> Vect 128] - - failing - test3 : TensorShape 2 - test3 = ["batchSize" ~> Vect 128, "batchSize" ~> Vect 13] - - -- ||| If an axis `i` can be added into a singleton list `[j]`, then - -- ||| the axis `j` can be added into a singleton list `[i]` - -- public export - -- axisConsistentSym : {i, j : Axis} -> - -- NewAxisConsistent i [j] -> NewAxisConsistent j [i] - -- axisConsistentSym (NewAxis ne) = NewAxis (notElemSym ne) - -- -- For some reason we can't pattern match on `Here`? The proof should still - -- -- be fine... - -- axisConsistentSym (ExistingAxis (There Here) _) impossible - -- axisConsistentSym (ExistingAxis (There (There later)) _) impossible - - ||| Proof that an axis name appears in a tensor shape n times - ||| The proof indirectly carries data of the exact indices where it appears - ||| Notably, can appear zero times, this case is needed for recursion - public export - data InShape : AxisName -> TensorShape k -> Nat -> Type where - Here : {as : TensorShape k} -> InShape axisName as n => - NewAxisConsistent (axisName ~> a) as => - InShape axisName ((axisName ~> a) :: as) (S n) - There : {as : TensorShape k} -> InShape axisName as n => - NewAxisConsistent a as => - InShape axisName (a :: as) n - - - ||| Recovers the axis from a shape given its name, and a prof that it is there - ||| Recovers the first occurence - public export - (.getByName) : (shape : TensorShape k) -> - (axisName : AxisName) -> - (inShape : InShape axisName shape n) -> - IsSucc n => - Axis - (.getByName) ((axisName ~> a) :: as) axisName Here = axisName ~> a - (.getByName) (a :: as) axisName (There @{is}) = as.getByName axisName is - - public export - removeAllOccurrences : {k, rank : Nat} ->(shape : TensorShape rank) -> - (toDelete : AxisName) -> - (inShape : InShape toDelete shape k) => - (m : Nat ** TensorShape m) - removeAllOccurrences {k=0} shape toDelete = (rank ** shape) - removeAllOccurrences ((toDelete ~> a) :: ss) toDelete @{Here @{is}} - = removeAllOccurrences ss toDelete @{is} - removeAllOccurrences (s :: ss) toDelete @{There @{is}} - = let (m ** ss') = removeAllOccurrences ss toDelete @{is} - in (S m ** (::) {ac=(believe_me ())} s ss') -- should write this later - - - ||| TODO rethink this function? - ||| In a tensor shape removes all but the first occurence of an axis - ||| removeDuplicates ["x" ~> 1, "y" ~> 3, "x" ~> 1] "x" = ["x" ~> 1, "y" ~> 1] - public export - removeDuplicates : {k, rank : Nat} -> (shape : TensorShape rank) -> - (axisName : AxisName) -> - (inShape : InShape axisName shape k) => - IsSucc k => - (m : Nat ** TensorShape m) - removeDuplicates shape axisName {inShape} {k = 1} - = (rank ** shape) - removeDuplicates ((_ ~> a) :: as) axisName {inShape = Here @{is}} {k = (S (S k))} - = removeDuplicates as axisName {inShape=is} - removeDuplicates (s :: as) axisName {inShape = There @{is}} {k = (S (S k))} - = let (m ** as') = removeDuplicates as axisName {inShape=is} - in (S m ** (::) {ac=(believe_me ())} s as') \ No newline at end of file diff --git a/src/Data/Tensor/Einsum/Elab.idr b/src/Data/Tensor/Einsum/Elab.idr deleted file mode 100644 index c57d52b..0000000 --- a/src/Data/Tensor/Einsum/Elab.idr +++ /dev/null @@ -1,321 +0,0 @@ -module Data.Tensor.Einsum.Elab - -import Data.DPair -import Data.List.Quantifiers -import Decidable.Equality -import Language.Reflection - -import Data.Tensor -import Data.Tensor.Einsum.Expr -import Data.Tensor.Einsum.ElabUtils -import Misc - -%language ElabReflection - ------------------------------------------------------------- ------ Elaborator Reflection for Einsum Function Generation ------------------------------------------------------------- - -||| Inductive representation of a heterogeneous list of variable-shape tensors -||| It exposes the shape information the type, and allows us to use the usual -||| list syntax like [t1, t2] with tensors of different shapes -||| This means that shape parameter can be inferred by the typechecker, instead -||| of needing to be manually supplied -||| For instance, f : TensorList shapes -> ... can consume f [m, n, k] where these are tensors of all different shaeps -public export -data TensorAList : List (List Nat) -> Type -> Type where - Nil : TensorAList [] a - (::) : Tensor sh a -> TensorAList shapes a -> TensorAList (sh :: shapes) a - - -public export -toHList : TensorAList shapes a -> HList (shapes <&> (\sh => Tensor sh a)) -toHList [] = [] -toHList (t :: ts) = t :: toHList ts - - ------------------------------------------------------------------ - - -||| Find index of the first occurence of a character in a list of lists -public export -findCharPosition : Char -> List (List Char) -> Maybe (Nat, Nat) -findCharPosition c [] = Nothing -findCharPosition c (xs :: xss) = - case findIndex (== c) xs of - Just innerIdx => Just (0, finToNat innerIdx) - Nothing => case findCharPosition c xss of - Just (outerIdx, innerIdx) => Just (S outerIdx, innerIdx) - Nothing => Nothing - -||| Compute dimension of a tensor at given position -public export -getTensorADimSize : {shapes : List (List Nat)} -> - (tensorIdx : Nat) -> - (dimIdx : Nat) -> - TensorAList shapes a -> - Maybe Nat -getTensorADimSize {shapes = []} _ _ [] = Nothing -getTensorADimSize {shapes = (sh :: shs)} Z dimIdx (t :: ts) = - case inBounds dimIdx sh of - Yes prf => Just (index dimIdx sh) - No _ => Nothing -getTensorADimSize {shapes = (sh :: shs)} (S k) dimIdx (t :: ts) = - getTensorADimSize k dimIdx ts - -||| Given a name of an axis, a list of axis names and the corresponding tensors, produce the size of that axis -public export -getSize : {shapes : List (List Nat)} -> - (outputName : Char) -> - (inputNames : List (List Char)) -> - (inputTensorAs : TensorAList shapes a) -> - Maybe Nat -getSize outputName inputNames inputTensorAs = do - (tensorIdx, dimIdx) <- findCharPosition outputName inputNames - getTensorADimSize tensorIdx dimIdx inputTensorAs - -||| Given an input string and a list of tensors, compute the output tensor shape -||| Notably doesn't throw errors related to binding inconsistencies w.r.t. the tensor list -||| We rely on elab reflection in next step to take care of that -public export -einsumComputeOutputType : {a : Type} -> {shapes : List (List Nat)} -> - String -> TensorAList shapes a -> Either EinsumParseError Type -einsumComputeOutputType exprStr ts = case parseEinsumString exprStr of - Left err => Left err - Right expr => let outputChars : List Char := toList (outputTyProj expr) - maybeNats : List (Maybe Nat) = (\c => getSize c (inputTyProj expr) ts) <$> outputChars - result : Maybe (List Nat) = sequence maybeNats - in case result of - Nothing => Left BindingInconsistency - Just listOfNats => Right (Tensor listOfNats a) - - -partial -isRight : Either a b -> b -isRight (Right x) = x - -{- -I've been playing around with elaborator reflection recently and I'm wondering whether it's possible to provide an interface by which it's possible to interact with functions that use elaborator reflection without using %runElab at call sites. - -} - - --- exUsage : (b : Bool) -> exType b --- exUsage b = %runElab (exVal b) - -einsumTestO : String -> (n : Nat) -> Either Unit Type -einsumTestO "a" n = Right (Vect n Int -> Int) -einsumTestO "b" _ = Right (Double -> List Double) -einsumTestO _ _ = Left () - -partial -einsumTest : (str : String) -> (n : Nat) -> Elab (isRight (einsumTestO str n)) -einsumTest str n = case str of - "a" => pure (\xs => foldr (+) 0 xs) - "b" => pure (\d => [d, d, d]) - - --- einsumTestImpl : (str : String) -> (n : Nat) -> isRight (einsumTestO str n) --- einsumTestImpl str n = %runElab (einsumTest str n) - -partial -esVal : List Double -esVal = let a : Elab (Vect 7 Int -> Int) := einsumTest "a" 7 - c : Elab (Double -> List Double) := einsumTest "b" 7 - ae : Vect 7 Int -> Int := %runElab (einsumTest "a" 7) - ce : Double -> List Double := %runElab (einsumTest "b" 7) - in (%runElab (einsumTest "b" 7)) 3.7 - --- Macro that provides NumPy-like einsum("ij,jk->ik", m, n) syntax --- This automatically generates einsum functions on-demand with dummy implementation -public export -partial -einsum : {a : Type} -> {shapes : List (List Nat)} -> - (exprStr : String) -> - (args : TensorAList shapes a) -> - Elab (isRight (einsumComputeOutputType exprStr args)) -einsum exprStr args = case parseEinsumString exprStr of - Left err => fail "Parse error in Einsum string: \{show err}" - Right expr@(MkEinsumExpr inputTy outputTy) => do - let inpLength : Nat := length inputTy - when (inpLength /= length shapes) $ - fail "Argument count mismatch: \{toString expr} defines \{show inpLength} inputs, but we got \{show (length shapes)} arguments" - - let uniqueVars : List Char := toList (uniqueJoin inputTy) - fnName : String := generateFunctionName exprStr - fnType : TTImp := buildEinsumFunctionType uniqueVars inputTy (toList outputTy) - -- Generate the function declaration - claimData = MkIClaimData MW Public [] (MkTy EmptyFC (NoFC (UN (Basic fnName))) fnType) - tyDecl = IClaim (MkFCVal EmptyFC claimData) - - -- Build lambda parameters for each input tensor - paramNames = [UN (Basic ("x" ++ show i)) | i <- [0..length inputTy `minus` 1]] - lambdaParams = zip paramNames inputTy - - -- Create the implementation body that matches the output tensor shape - -- Generate the output shape as a vector literal from the output type - outputShape = generateShapeVect (toList outputTy) - -- Create zeros' with the correct output shape and generic type 'a' - implBody = `(zeros {shape = ~outputShape} {a = dtype}) - - -- Build the full lambda expression - fullImpl = foldr (\(paramName, shape), body => - ILam EmptyFC MW ExplicitArg (Just paramName) (generateTensorAType shape) body) - implBody lambdaParams - -- --- - -- Create the definition using the correct IDef pattern - clause = PatClause EmptyFC (IVar EmptyFC (UN (Basic fnName))) fullImpl - funDef = IDef EmptyFC (UN (Basic fnName)) [clause] - - declare [tyDecl, funDef] - - -- pure (zeros' {a = a}) - -- fn' <- check (IVar EmptyFC (UN (Basic fnName))) - -- Now call the generated function directly with the actual arguments - case args of - [] => fail "No arguments provided" - [x] => do - fn <- check (IVar EmptyFC (UN (Basic fnName))) - pure (fn x) - [x, y] => do - fn <- check (IVar EmptyFC (UN (Basic fnName))) - pure (fn x y) - [x, y, z] => do - fn <- check (IVar EmptyFC (UN (Basic fnName))) - pure (fn x y z) - _ => fail "More than 3 arguments not yet supported" - -{- -Impossible to get rid of %runElab macro at callsites, very annoying! The code below won't compile - -} -public export -partial -einsumImpl : {a : Type} -> Num a => {shapes : List (List Nat)} -> - (exprStr : String) -> (args : TensorAList shapes a) -> - isRight (einsumComputeOutputType exprStr args) -einsumImpl exprStr args = - let t = einsum exprStr args - in ?lall -- %runElab t - -{- -runElab : Elab a -> a -`(_) : ? -> TTImp -quote : (0 val : Type) -> Elaboration m - => (0 _ : val) -> m TTImp -check : {0 expected : Type} -> Elaboration m => - TTImp -> m expected - -} - -gg : Elab Int -gg = pure 3 - -gh : Int -gh = %runElab gg - -ghQuote : Int -ghQuote = %runElab check `(3) - -m : Tensor [2, 3] Double -m = ># [[1, 2, 3], [4, 5, 6]] - -n : Tensor [3, 4] Double -n = ># [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] - --- Test the fixed einsum macro with a unique pattern -testNewPattern : Tensor [3, 2] Double -testNewPattern = %runElab einsum "ab->ba" [m] - -einsumImplementation : {a : Type} -> Num a => - {is : List Nat} -> -- free indices - {ls : List Nat} -> -- all indices - {inputShs : List (List Nat)} -> - {outputSh : List Nat} -> - {auto prf : (fromList outputSh) `IsFrom` is} -> - List1 (Exists (\sh => (Tensor sh a, (fromList sh) `IsFrom` ls))) -> - Tensor outputSh a -einsumImplementation xs = ?hmma --- -- According to the blog post, einsum works as nested for loops --- -- 1. Initialize output tensor to zeros --- let outputTensorA : Tensor outputSh a := zeros' --- -- 2. For each combination of free indices (outer loops) --- -- 3. For each combination of summation indices (inner loops) --- -- 4. Compute product of all input tensors at appropriate indices --- -- 5. Add this product to output tensor at current free index position --- in --- -- This is a simplified implementation that needs to be expanded --- -- based on the actual index manipulation and tensor operations --- -- The core idea is to iterate through all valid index combinations --- -- and perform the sum of products as described in the blog post --- case xs of --- (x ::: xs) => --- -- For now, we return the zero tensor as a placeholder --- -- The full implementation would need to: --- -- 1. Extract the tensors from the existential types --- -- 2. Create index iterators for free and summation indices --- -- 3. Implement the nested loops as described in the blog post --- -- 4. Perform the products and sums according to Einstein notation --- outputTensorA - - -{- -TODO interesting cases of above: -a) one output index, repeated in input: M[i] += M[i, i] -This means that the einsum string determines where to apply it. -Though, notably we've already *created the variables via Elab reflection*, so we should be able to apply them? -I.e. we should be able to 'find all occurences of i' in context, and apply the current value to it? - -Consider -einsum "ii->i" m -vs -einsum "ij->i" m - -If we have a matrix m : Tensor [3, 3] a - -in both cases we'd need to look at the free index i, and then realise that depending on the einsum string we'd need to .. - -In both we'd need to look at the free index i, and then realise that depending on the string we'd need to - ---- - -einsum "isj->ij" t - - - --} - -parameters {a : Type} {auto _ : Num a} {b, i, j, k : Nat} - - matMul : Tensor [i, j] a -> Tensor [j, k] a -> Tensor [i, k] a - matMul m n = %runElab einsum "ij,jk->ik" [m, n] - - batchedMatMul : Tensor [b, i, j] a -> Tensor [b, j, k] a -> Tensor [b, i, k] a - batchedMatMul m n = %runElab einsum "bij,bjk->bik" [m, n] - - -- does not require applicative - outer : Tensor [i] a -> Tensor [j] a -> Tensor [i, j] a - outer v w = %runElab einsum "i,j->ij" [v, w] - - inner : Tensor [i] a -> Tensor [i] a -> Tensor [] a - inner v w = %runElab einsum "i,i->" [v, w] - - elementwise : Tensor [i] a -> Tensor [i] a -> Tensor [i] a - elementwise v w = %runElab einsum "i,i->i" [v, w] - - -- requires Naperian - transpose : Tensor [i, j] a -> Tensor [j, i] a - transpose m = %runElab einsum "ij->ji" [m] - - trace : Tensor [i, i] a -> Tensor [] a - trace m = %runElab einsum "ii->" [m] - - diag : Tensor [i, i] a -> Tensor [i] a - diag m = %runElab einsum "ii->i" [m] - - sum : Tensor [i] a -> Tensor [] a - sum v = %runElab einsum "i->" [v] - - matrixVectorProduct : Tensor [i, j] a -> Tensor [j] a -> Tensor [i] a - matrixVectorProduct m v = %runElab einsum "ij,j->i" [m, v] - - vectorMatrixProduct : Tensor [i] a -> Tensor [i, j] a -> Tensor [j] a - vectorMatrixProduct v m = %runElab einsum "i,ij->j" [v, m] \ No newline at end of file diff --git a/src/Data/Tensor/Einsum/ElabUtilTest.idr b/src/Data/Tensor/Einsum/ElabUtilTest.idr deleted file mode 100644 index bae32aa..0000000 --- a/src/Data/Tensor/Einsum/ElabUtilTest.idr +++ /dev/null @@ -1,231 +0,0 @@ -module Data.Tensor.Einsum.ElabUtilTest - -import Language.Reflection - -%language ElabReflection - --- First, let's see what a quoted expression looks like -exampleQuote : TTImp -exampleQuote = `(x + y) - --- Function to show the structure of a quoted expression -showStructure : TTImp -> String -showStructure tm = show tm - --- Test function to explore elaborator reflection -export -testExpr : IO () -testExpr = putStrLn $ show `(x * y + 3) - -testExpr2 : IO () -testExpr2 = putStrLn $ show `(let t = x * y in t + 3) - ----------------------------------------- ------ Creating variables with chosen names using %runElab ----------------------------------------- - --- Function to create an integer variable with a chosen name and value -createIntVar : String -> Int -> Elab () -createIntVar name value = do - let varName : Name := UN (Basic name) - let varType : TTImp := `(Int) - let varValue : TTImp := IPrimVal EmptyFC (I value) - - -- Create type declaration using correct constructors - let tyDecl : Decl - tyDecl = IClaim (NoFC (MkIClaimData MW Public [] (MkTy EmptyFC (NoFC varName) varType))) - - -- Create function definition with the constant value - let funDef : Decl := IDef EmptyFC varName [PatClause EmptyFC (IVar EmptyFC varName) varValue] - - -- Use declare with a list of declarations - declare [tyDecl, funDef] - -%runElab createIntVar "myVar" 42 -%runElab createIntVar "ieva" 100000000 - -gg : Int -gg = myVar + 40 - --- Test that our generated variables work -export -testGeneratedVars : IO () -testGeneratedVars = do - putStrLn $ "myGeneratedVar = " ++ show myVar - putStrLn $ "anotherVar = " ++ show gg - ----------------------------------------- ------ %runElab on the Right-Hand Side - YES it has a type! ----------------------------------------- - --- %runElab CAN be used on the right-hand side of = and DOES have a type! --- It evaluates to a value at compile time that matches the expected type. - --- More complex example: Use %runElab with check to create a typed value --- This is the pattern from the tutorial you shared -createTypedInt : Int -createTypedInt = %runElab check `(42) - - -createTypedInt2 : Int -createTypedInt2 = %runElab check `(42 + 13) - - ---- Non-dependent variant - -fnn : Int -> Int -fnn = (+ 15) - -createAddFunction : Int -> Int -createAddFunction = %runElab check `(fnn) - -export -testRHSElabFunction : Int -testRHSElabFunction = createAddFunction 10 -- Should be 25 - --- Dependent variant 2 - -public export -einsumZeroDimensional : String -> Type -einsumZeroDimensional "int" = Int -> Int -einsumZeroDimensional "double" = Double -> Double -einsumZeroDimensional _ = Void - -partial -public export -einsumImpl : (xs : String) -> einsumZeroDimensional xs -einsumImpl "int" = %runElab check `(\x => x + 2) -einsumImpl "double" = %runElab check `(\x => x * 17) - - -einsumTest : (xs : String) -> einsumZeroDimensional xs -einsumTest xs = assert_total $ einsumImpl xs - -ggh : Double -ggh = einsumTest "double" 3 - --- Test version that returns Elab type (to test the hypothesis) -partial -einsumImplElab : (xs : String) -> Elab (einsumZeroDimensional xs) -einsumImplElab "->" = check `(\x => einsumElab [x]) -einsumImplElab ",->" = check `(\x, y => einsumElab [x, y]) -einsumImplElab ",,->" = check `(\x, y, z => einsumElab [x, y, z]) - --- rrt : Double --- rrt = einsumTest "asdf" 3 - ---- Dependent variant 3 - -ifn : String -> Type -ifn s = case s of - "ij->ji" => Bool - _ => Int - --- Maybe-based approach: Explicitly handle partiality -dfnMaybe : (s : String) -> ifn s ---dfnMaybe s with (s) --- dfnMaybe s | "ij->ji" = True --- dfnMaybe s | "ii->" = 4 --- dfnMaybe s | _ = ?uiiiiiiiiiiiiiiiiiiii - --- dfnMaybe "ij->ji" = True -- We know ifn "ij->ji" = Bool, so True : Bool --- dfnMaybe "ii->" = 4 -- We know ifn "ii->" = Int, so 4 : Int --- dfnMaybe _ = ?uii -- For unknown strings, we return Nothing - - -{- -testMaybe1 : Maybe Int -testMaybe1 = %runElab check `(dfnMaybe "->") - - --- Another example: create a value by running elaboration that builds an expression -createComputedValue : Int -createComputedValue = %runElab do - let expr = `(5 * 8 + 2) - check expr - --- Example showing %runElab with Elab monad computations on RHS -computeStringLength : Nat -computeStringLength = %runElab do - let str = "Hello, Idris!" - check `(cast {to=Nat} ~(IPrimVal EmptyFC (I (cast (String.length str))))) - --- Test that these work -export -testRHSExamples : IO () -testRHSExamples = do - putStrLn "=== Testing %runElab on Right-Hand Side ===" - putStrLn $ "simpleExample structure: " ++ show simpleExample - putStrLn $ "createTypedInt: " ++ show createTypedInt - putStrLn $ "createAddFunction 10 20: " ++ show testRHSElabFunction - putStrLn $ "createComputedValue: " ++ show createComputedValue - putStrLn $ "computeStringLength: " ++ show computeStringLength - -{- These functions don't work because declareType and defineFunction don't exist --- Function to create a String variable with a chosen name and value -createStringVar : String -> String -> Elab () -createStringVar name value = do - let varName = UN (Basic name) - let varType = `(String) - let varValue = IPrimVal EmptyFC (Str value) - - -- Declare the type - declareType $ ToCubicalTensory EmptyFC EmptyFC varName varType - - -- Define the function with a single clause - let clause = PatClause EmptyFC (IVar EmptyFC varName) varValue - defineFunction $ MkFunDef EmptyFC varName [clause] - --- Create string variables -%runElab createStringVar "greeting" "Hello from generated code!" -%runElab createStringVar "farewell" "Goodbye from Idris!" - --- Test the generated string variables -export -testGeneratedStrings : IO () -testGeneratedStrings = do - putStrLn greeting - putStrLn farewell - --- More flexible function that can create variables of any type (as long as we can quote the value) -createVar : String -> TTImp -> TTImp -> Elab () -createVar name varType varValue = do - let varName = UN (Basic name) - - -- Declare the type - declareType $ ToCubicalTensory EmptyFC EmptyFC varName varType - - -- Define the function with a single clause - let clause = PatClause EmptyFC (IVar EmptyFC varName) varValue - defineFunction $ MkFunDef EmptyFC varName [clause] - --- Create a Bool variable using the flexible function -%runElab createVar "isReady" `(Bool) `(True) - --- Create a List variable -%runElab createVar "myList" `(List Int) `([1, 2, 3, 4, 5]) - --- Test the flexibly generated variables -export -testFlexibleVars : IO () -testFlexibleVars = do - putStrLn $ "isReady = " ++ show isReady - putStrLn $ "myList = " ++ show myList - putStrLn $ "length myList = " ++ show (length myList) --} - - ---- Dependent variant 0 - -rr : List Char -> Type -rr [] = Bool -rr (_ :: _) = Int - -rrd : (xs : List Char) -> rr xs -rrd [] = True -rrd ['a', 'b', 'c'] = 4 -rrd (_ :: _) = 5 -- Match non-empty lists explicitly instead of catch-all - -testList : (xs : List Char) -> rr xs -testList = %runElab check `(rrd) - diff --git a/src/Data/Tensor/Einsum/ElabUtils.idr b/src/Data/Tensor/Einsum/ElabUtils.idr deleted file mode 100644 index f48356d..0000000 --- a/src/Data/Tensor/Einsum/ElabUtils.idr +++ /dev/null @@ -1,100 +0,0 @@ -module Data.Tensor.Einsum.ElabUtils - -import Data.DPair -import Data.List.Quantifiers -import Decidable.Equality -import Language.Reflection - -import Data.Tensor.Einsum.Expr -import Data.Unique -import Misc - -||| Helper function to convert Char to variable name -public export -charToVarName : Char -> Name -charToVarName c = UN (Basic (singleton c)) - -||| Generate [i, j, k] from ['i', 'j', 'k'] -public export -generateShapeVect : List Char -> TTImp -generateShapeVect [] = `([]) -generateShapeVect (x :: xs) = - `(~(IVar EmptyFC (charToVarName x)) :: ~(generateShapeVect xs)) - -||| Generate Tensor [i, j] dtype from shape ['i', 'j'] -public export -generateTensorAType : List Char -> TTImp -generateTensorAType shape = - let shapeVect = generateShapeVect shape - in `(Tensor ~(shapeVect) dtype) - - -||| ['i', 'j', 'k'] -> {dtype : Type} -> Num a => {i, j, k : Nat} -> TensorA [i, j, k] dtype -public export -generateOutputType : List Char -> TTImp -generateOutputType cs = - let outputTensorAType : TTImp := generateTensorAType cs - -- Add implicit {i, j, k : Nat} parameters - withNatParams : TTImp := foldr (\var, acc => - IPi EmptyFC MW ImplicitArg (Just (charToVarName var)) `(Nat) acc) outputTensorAType cs - - -- Add Num a constraint - withNumConstraint : TTImp := IPi EmptyFC MW AutoImplicit Nothing `(Num dtype) withNatParams - - fullType : TTImp := IPi EmptyFC MW ImplicitArg (Just (UN (Basic "dtype"))) `(Type) withNumConstraint - in fullType - - -||| Build einsum function type with tensor shapes and implicit Nat parameters -public export -buildEinsumFunctionType : List Char -> List (List Char) -> List Char -> TTImp -buildEinsumFunctionType uniqueVars inputShapes outputShape = - let - inputTensorATypes : List TTImp := generateTensorAType <$> inputShapes - outputTensorAType : TTImp := generateTensorAType outputShape - - -- Build the main function type: input1 -> input2 -> ... -> output - mainFunctionType : TTImp := foldr (\inputType, acc => - IPi EmptyFC MW ExplicitArg Nothing inputType acc) - outputTensorAType inputTensorATypes - - -- Add implicit {i, j, k : Nat} parameters - withNatParams : TTImp := foldr (\var, acc => - IPi EmptyFC MW ImplicitArg (Just (charToVarName var)) `(Nat) acc) mainFunctionType uniqueVars - - -- Add Num a constraint - withNumConstraint : TTImp := IPi EmptyFC MW AutoImplicit Nothing `(Num dtype) withNatParams - - -- Add implicit {a : Type} parameter FIRST (before everything else) - fullType : TTImp := IPi EmptyFC MW ImplicitArg (Just (UN (Basic "dtype"))) `(Type) withNumConstraint - - in fullType - -||| Generate a function name from the einsum expression -public export -generateFunctionName : String -> String -generateFunctionName einsumStr = "einsum_" ++ withUnderscores where - withUnderscores = replaceString "->" "__" (replaceString "," "_" einsumStr) - - -||| Main function to generate Einsum function type from string -export -partial -generateEinsumType : String -> Elab () -generateEinsumType einsumStr = case parseEinsumString einsumStr of - Left err => fail "Parse error in Einsum string: \{show err}" - Right (MkEinsumExpr inputTy outputTy) => do - let uniqueVars = toList (uniqueJoin inputTy) - fnName = generateFunctionName einsumStr - fnType = buildEinsumFunctionType uniqueVars inputTy (toList outputTy) - - -- Create the type declaration - claimData = MkIClaimData MW Public [] (MkTy EmptyFC (NoFC (UN (Basic fnName))) fnType) - tyDecl = IClaim (MkFCVal EmptyFC claimData) - - declare [tyDecl] - - - - - diff --git a/src/Data/Tensor/Einsum/Expr.idr b/src/Data/Tensor/Einsum/Expr.idr deleted file mode 100644 index 8db8ec1..0000000 --- a/src/Data/Tensor/Einsum/Expr.idr +++ /dev/null @@ -1,283 +0,0 @@ -module Data.Tensor.Einsum.Expr - -import public Data.Vect -import public Data.List -import Data.List.Quantifiers -import Data.DPair -import Data.HashMap -import Decidable.Equality -import Data.String -import Language.Reflection - -import Data.Unique.Vect -import Data.Unique.List -import Data.Tensor --- import Data.Functor.Naperian -import Misc - -%language ElabReflection - --- TODO should axes be ordered? --- For cubical tensors (or generally Naperian) order is generally irrelevant, but for non-cubical ones order matters! CTensor [BinTree, List] a is very different than CTensor [List, BinTree] a? - -||| Correct by construction Einsum expression whose inputs are lists of labels -||| Ensures that -||| a) each output label appears only once -||| b) each output label has appeared in the input -public export -data EinsumExpr : (a : Type) -> DecEq a => Type where - MkEinsumExpr : {a : Type} -> DecEq a => - (inputTy : List (List a)) -> - (outputTy : UniqueList a) -> - {auto prf : outputTy `IsFrom` (toList (uniqueJoin inputTy))} -> - EinsumExpr a - -||| Indices used in the output type -public export -freeIndices : {a : Type} -> DecEq a => EinsumExpr a -> UniqueList a -freeIndices (MkEinsumExpr _ outputTy) = outputTy - -||| Indices used in the input that *do not* appear in the output type -public export -summationIndices : {a : Type} -> DecEq a => EinsumExpr a -> UniqueList a -summationIndices (MkEinsumExpr inputTy outputTy) = fromList $ - complement (uniqueJoin inputTy) outputTy - --- Note that freeIndices + summationIndices = all starting indices - -||| Machinery for pretty-printing Einsum expressions -namespace EinsumToString - ||| If a=Char, we write it as a string - ||| Anything else we add commas between elements and brackets around - public export - labelToString : {a : Type} -> DecEq a => Show a => List a -> String - labelToString {a=Char} xs = pack xs - labelToString xs - = let inter = case a of - String => xs -- necessary so extra quotes aren't added - _ => show <$> xs - in "[" ++ concat (intersperse "," inter) ++ "]" - - public export - inputToString : {a : Type} -> DecEq a => Show a => - (inputTy : List (List a)) -> String - inputToString inputTy = concat $ intersperse "," (labelToString <$> inputTy) - - - public export - outputToString : {a : Type} -> DecEq a => Show a => - UniqueList a -> String - outputToString = labelToString . toList - - public export - toString : DecEq a => Show a => EinsumExpr a -> String - toString (MkEinsumExpr inputTy outputTy) - = (inputToString inputTy) ++ "->" ++ (outputToString outputTy) - - public export - DecEq a => Show a => Show (EinsumExpr a) where - show = toString - - public export - oo : EinsumExpr String - oo = MkEinsumExpr [["i", "j"], ["j", "k"]] ["i", "k"] - - public export - ooChar : EinsumExpr Char - ooChar = MkEinsumExpr [['i', 'j'], ['j', 'k']] ['i', 'k'] - -||| Machinery for parsing a string into an EinsumExpr -||| We focus on parsing into EinsumExpr Char ("ij,jk->ik") -||| Other options are possible, i.e. "[bt,inp],[inp,out]->[bt,out]" -||| But we do not explore them here -namespace Parsing - public export - data EinsumParseError : Type where - EmptyInput : EinsumParseError - MissingArrow : EinsumParseError - ContentBothSidesArrow : EinsumParseError - DuplicateOutputAxis : EinsumParseError - OutputAxisNotInInput : EinsumParseError - MultipleArrows : EinsumParseError - NonAlphaAxis : EinsumParseError - BindingInconsistency : EinsumParseError - - public export - Show EinsumParseError where - show EmptyInput = "Empty input string." - show MissingArrow = "Missing '->' arrow." - show ContentBothSidesArrow = "Content missing on one side of arrow." - show DuplicateOutputAxis = "Duplicate axis in output." - show OutputAxisNotInInput = "Output axis not found in input." - show MultipleArrows = "Multiple '->' arrows found." - show NonAlphaAxis = "Non-alphabetic character found in axis labels. Only [A-Z][a-z] are allowed." - show BindingInconsistency = "Binding inconsistency in axis labels and tensors." -- should this be here? - - public export - parseEinsumString : String -> Either EinsumParseError (EinsumExpr Char) - parseEinsumString str = case str of - "" => Left EmptyInput - _ => case splitString str "->" of - (0 ** _) => Left MissingArrow -- Technically impossible - (1 ** _) => Left ContentBothSidesArrow - (2 ** [left, right]) => - let xs : Vect _ String := snd (splitString left ",") - inputLabels : List (List Char) := unpack <$> (toList xs) - outputLabels : List Char := unpack right - in case all (all isAlpha) inputLabels of - False => Left NonAlphaAxis - True => case all isAlphaNum outputLabels of - False => Left NonAlphaAxis - True => case fromListMaybe outputLabels of - Nothing => Left DuplicateOutputAxis - Just outputTy => - case checkAllInInput outputTy (uniqueJoin inputLabels) of - Nothing => Left OutputAxisNotInInput - Just prf => Right (MkEinsumExpr inputLabels outputTy {prf = (IndeedItIs {prf=prf})}) - (_ ** _) => Left MultipleArrows - where - -- Check if all output labels appear in input labels and provide proof - checkAllInInput : (outputTy : UniqueList Char) -> - (inputChars : UniqueList Char) -> - Maybe (All (\x => Elem x (toList inputChars)) outputTy) - checkAllInInput [] inputChars = Just [] - checkAllInInput (x :: xs) inputChars = - case isElem x (toList inputChars) of - Yes prf => case checkAllInInput xs inputChars of - Just restPrf => Just (prf :: restPrf) - Nothing => Nothing - No _ => Nothing - - -||| Correct by construction Einsum expression whose input is a string -||| together with a proof that it correctly parses into a valid expression -public export -data EinsumStrExpr : Type where - EinsumChar : (einsumExprString : String) -> - {einsumExpr : EinsumExpr Char} -> - {auto prf : parseEinsumString einsumExprString = Right einsumExpr} -> - EinsumStrExpr - -||| Project out the input of an einsum expression -public export -inputTyProj : EinsumStrExpr -> List (List Char) -inputTyProj (EinsumChar _ {einsumExpr = (MkEinsumExpr inputTy _)}) = inputTy - -||| Project out the output of an einsum expression -public export -outputTyProj : EinsumStrExpr -> UniqueList Char -outputTyProj (EinsumChar _ {einsumExpr = (MkEinsumExpr _ outputTy)}) = outputTy - -||| Number of input tensors -public export -numInputs : EinsumStrExpr -> Nat -numInputs (EinsumChar _ {einsumExpr = (MkEinsumExpr inputTy _)}) = length inputTy - - -esTest : EinsumStrExpr -esTest = EinsumChar "ij,jk->ik" - -esTest2 : EinsumStrExpr -esTest2 = EinsumChar "ij,jk->ij" - -failing - esFail : EinsumStrExpr - esFail = EinsumChar "ij->xx" - --- ||| Inductive representation of a heterogeneous list of variable-shape tensors --- ||| It exposes the shape information the type, and allows us to use the usual --- ||| list syntax like [t1, t2] with tensors of different shapes --- ||| This means that shape parameter can be inferred by the typechecker, instead --- ||| of needing to be manually supplied --- ||| For instance, f : TensorList shapes -> ... can consume f [m, n, k] where these are tensors of all different shaeps --- public export --- data CTensorList : List (List Cont) -> Type -> Type where --- Nil : CTensorList [] a --- (::) : CTensor sh a -> CTensorList shapes a -> CTensorList (sh :: shapes) a - -public export -inputType : EinsumStrExpr -> Type -inputType s = Vect (numInputs s) ?inputType_rhs - -public export -outputType : (s : EinsumStrExpr) -> (inputType : inputType s) -> Type - -public export -einsum : (s : EinsumStrExpr) -> (it : inputType s) -> outputType s it - --- TODO something also about checking for transposes? --- "ij->ji" counts as a transpose, but what about "i,j->ji"? Or "ij,ji->ij"? Transpose in which variable, that's important? - -{- -We need -1. A map einsumType : EinsumStrExpr -> Type which produces a list of input tensors of appropriate types - -"ij,jk->ik" -> Tensor [i, j] Double, Tensor [j, k] Double - -But technically, all we need at first is to produce a Vect numInputs (List Nat)? -Then from these we can produce the tensors - -then from -einsum : (s : EinsumStrExpr) -> (inputType : einsumType s) -> outputType s inputTensors - - - - --} - - - -{- -Elab cons: can't get rid of %runElab at call sites (maybe possible with %macro?) -Elab con2: can't use it for gradient computation since it has to be a constant at compile time - -Manual: can't generate implicit variables? Not sure how to do it at all? - -} - - - - - -t1 : Tensor [2, 3] ["i", "j"] Double -t1 = ># [ [1, 2, 3], [4, 5, 6] ] - -t2 : Tensor [3, 4] ["i", "j"] Double -t2 = ># [ [1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12] ] - -||| The name for an axis is an arbitrary string, usually a single character -AxisName : Type -AxisName = Char - -AxisBinding : Type -AxisBinding = HashMap AxisName Nat - --- simpleSum : {i : Nat} -> Tensor [i] Double -> Tensor [] Double --- simpleSum x = ToCubicalTensor $ TZ $ foldr (+) 0 x - - --- simpleTrace : {i : Nat} -> Tensor [i, i] Double -> Tensor [] Double --- simpleTrace x = ToCubicalTensor $ TZ $ foldr (+) 0 x --- --- simpleDiagonal : {i : Nat} -> Tensor [i, i] Double -> Tensor [i] Double --- simpleDiagonal x = ToCubicalTensor $ TS $ tabulate (\k => TZ $ x @@@ [k, k]) - --- nestedFold : {i, j : Nat} -> Tensor [i, j] Double -> Tensor [] Double --- nestedFold x = ToCubicalTensor $ TZ $ foldr (+) 0 x - - --- public export --- uniqueJoinVect : {nInputs : Nat} -> Vect nInputs String -> UniqueList Char --- uniqueJoinVect xs = uniqueJoin $ (unpack <$>) (toList xs) --- --- data EinsumStrExpr' : Type where --- EinsumChar' : (einsumExpr : String) -> --- {left, right : String} -> --- {auto prf : splitString einsumExpr "->" = (2 ** [left, right])} -> --- {nInputs : Nat} -> --- {xs : Vect nInputs String} -> --- {auto prf_left : splitString left "," = (nInputs ** xs)} -> --- {outputTy : UniqueList Char} -> --- {auto prf_unique : fromListMaybe (unpack right) = Just outputTy} -> --- {auto prf_from_input : All (\x => Elem x (toList (uniqueJoinVect xs))) outputTy} -> --- EinsumStrExpr' - diff --git a/src/Data/Tensor/Einsum/SimpleTest.idr b/src/Data/Tensor/Einsum/SimpleTest.idr deleted file mode 100644 index 3fcf596..0000000 --- a/src/Data/Tensor/Einsum/SimpleTest.idr +++ /dev/null @@ -1,65 +0,0 @@ -module Data.Tensor.Einsum.SimpleTest - -import Data.Tensor.Einsum.VariableFind -import Language.Reflection - -%default total -%language ElabReflection - --- Simple test function that shows how to use VariableFind --- This function has variables in scope that we can test with -export -testVariableFind : Int -> String -> Bool -> () -testVariableFind x y z = %runElab do - -- Now we have variables x, y, z in scope - -- Test finding variables by name - findVariableByName "x" - findVariableByName "y" - findVariableByName "z" - - -- Test getting type information as strings - typeX <- getVariableTypeString "x" - typeY <- getVariableTypeString "y" - typeZ <- getVariableTypeString "z" - - logMsg "info" 0 $ "Type of x: \{typeX}" - logMsg "info" 0 $ "Type of y: \{typeY}" - logMsg "info" 0 $ "Type of z: \{typeZ}" - --- Test function for variable references -export -testVariableReferences : Int -> String -> Bool -> () -testVariableReferences x y z = %runElab do - -- Get all local variables - localVars <- localVars - case localVars of - [] => logMsg "info" 0 "No local variables found" - (firstVar :: _) => do - -- Create a variable reference - let varRef = varRef firstVar - - -- Test processing the variable reference - result <- processVarRef varRef - logMsg "info" 0 $ "Processed variable: \{result}" - - -- Test getting the type of the variable reference - varType <- getVarRefType varRef - logMsg "info" 0 $ "Variable type: \{show varType}" - --- Simple test function with just one variable -export -simpleTest : Int -> () -simpleTest x = %runElab do - -- Test finding the variable x - findVariableByName "x" - - -- Test getting its type - typeStr <- getVariableTypeString "x" - logMsg "info" 0 $ "x has type: \{typeStr}" - --- Test function that tries to find a non-existent variable -export -testNonExistent : Int -> () -testNonExistent x = %runElab do - -- This should fail - findVariableByName "nonExistent" diff --git a/src/Data/Tensor/Einsum/TestVariableFind.idr b/src/Data/Tensor/Einsum/TestVariableFind.idr deleted file mode 100644 index 841de49..0000000 --- a/src/Data/Tensor/Einsum/TestVariableFind.idr +++ /dev/null @@ -1,64 +0,0 @@ -module Data.Tensor.Einsum.TestVariableFind - -import Data.Tensor.Einsum.VariableFind -import Language.Reflection - -%language ElabReflection - --- Test function that demonstrates how to use VariableFind --- This function has variables in scope that we can test with -export -testVariableFind : Int -> String -> Bool -> Elab () -testVariableFind x y z = do - -- Now we have variables x, y, z in scope - -- Test finding variables by name - findVariableByName "x" - findVariableByName "y" - findVariableByName "z" - - -- Test getting type information as strings - typeX <- getVariableTypeString "x" - typeY <- getVariableTypeString "y" - typeZ <- getVariableTypeString "z" - - logMsg "info" 0 $ "Type of x: \{typeX}" - logMsg "info" 0 $ "Type of y: \{typeY}" - logMsg "info" 0 $ "Type of z: \{typeZ}" - --- Test function for variable references -export -testVariableReferences : Int -> String -> Bool -> Elab () -testVariableReferences x y z = do - -- Get all local variables - localVars <- localVars - case localVars of - [] => logMsg "info" 0 "No local variables found" - (firstVar :: _) => do - -- Create a variable reference - let varRef = varRef firstVar - - -- Test processing the variable reference - result <- processVarRef varRef - logMsg "info" 0 $ "Processed variable: \{result}" - - -- Test getting the type of the variable reference - varType <- getVarRefType varRef - logMsg "info" 0 $ "Variable type: \{show varType}" - --- Simple test function with just one variable -export -simpleTest : Int -> Elab () -simpleTest x = do - -- Test finding the variable x - findVariableByName "x" - - -- Test getting its type - typeStr <- getVariableTypeString "x" - logMsg "info" 0 $ "x has type: \{typeStr}" - --- Test function that tries to find a non-existent variable -export -testNonExistent : Int -> Elab () -testNonExistent x = do - -- This should fail - findVariableByName "nonExistent" diff --git a/src/Data/Tensor/Einsum/VariableFind.idr b/src/Data/Tensor/Einsum/VariableFind.idr deleted file mode 100644 index 5352d17..0000000 --- a/src/Data/Tensor/Einsum/VariableFind.idr +++ /dev/null @@ -1,112 +0,0 @@ -module Data.Tensor.Einsum.VariableFind - -import Language.Reflection -import Data.List -import Data.String - -import Data.Tensor - -%language ElabReflection - --- Find variable by name and get its type as a string -export -getVariableTypeString : String -> Elab String -getVariableTypeString varName = do - localVars <- localVars - case filter (\name => show name == varName) localVars of - (name :: _) => do - -- Found in local scope - varType <- getLocalType name - pure $ show varType - [] => do - -- Not found locally, try global scope - let globalName = UN (Basic varName) - globalTypes <- getType globalName - case globalTypes of - [] => fail $ "Variable '\{varName}' not found in local or global scope" - ((name, varType) :: _) => do - pure $ show varType - - --- Helper function to extract names from a shape list -extractFromShapeList : TTImp -> List String -extractFromShapeList (IVar _ name) = [show name] -extractFromShapeList (IApp _ (IVar _ name) rest) = show name :: extractFromShapeList rest -extractFromShapeList (IApp _ f x) = extractFromShapeList f ++ extractFromShapeList x -extractFromShapeList _ = [] - --- Helper function to recursively search for CTensor in the type -findCTensorInType : TTImp -> List String -findCTensorInType (IApp _ (IApp _ (IVar _ name) shapeList) _) = - if show name == "CTensor" then - extractFromShapeList shapeList - else [] -findCTensorInType (IApp _ f x) = - findCTensorInType f ++ findCTensorInType x -findCTensorInType _ = [] - --- Helper function to extract container names from a TTImp type -extractContainerNamesFromType : TTImp -> List String -extractContainerNamesFromType (IApp _ (IApp _ (IVar _ name) shapeList) _) = - if show name == "CTensor" then extractFromShapeList shapeList else [] -extractContainerNamesFromType (IApp _ f x) = - extractContainerNamesFromType f ++ extractContainerNamesFromType x -extractContainerNamesFromType (IVar _ name) = - if show name == "CTensor" then [] else [show name] -extractContainerNamesFromType _ = [] - -||| Get variable references of all containers making up a tensor -||| I.e. given c1 : Cont and c2 : Cont -||| getTensorContainerRefs [c1, c2] will return [c1, c2] -getTensorContainerRefs : CTensor shape a -> Elab (List String) -getTensorContainerRefs tensor = do - -- Get the type of the tensor using quote - tensorType <- quote tensor - - -- Debug: log the actual type structure - logMsg "info" 0 $ "Tensor type: \{show tensorType}" - - -- Try different parsing approaches - let result1 = findCTensorInType tensorType - let result2 = extractContainerNamesFromType tensorType - let result3 = extractAllNames tensorType - - logMsg "info" 0 $ "Method 1 (findCTensorInType): \{show result1}" - logMsg "info" 0 $ "Method 2 (extractContainerNamesFromType): \{show result2}" - logMsg "info" 0 $ "Method 3 (extractAllNames): \{show result3}" - - -- Return the first non-empty result - case result1 of - [] => case result2 of - [] => pure result3 - xs => pure xs - xs => pure xs - --- Extract all names from any TTImp structure -extractAllNames : TTImp -> List String -extractAllNames (IVar _ name) = [show name] -extractAllNames (IApp _ f x) = extractAllNames f ++ extractAllNames x -extractAllNames _ = [] - - -myVar2 : Int -myVar2 = 7 - -testt : String -testt = let myVar = 3 in %runElab getVariableTypeString "myVar" - -testt2 : String -testt2 = %runElab getVariableTypeString "myVar2" - - -t1 : CTensor [BinTree, Vect 2] Double -t1 = ># Node [1, 2] (Leaf [3,4]) (Leaf [4,5]) - -ns : List String -ns = %runElab getTensorContainerRefs t1 - --- Let's also test with a simpler case to see the structure -simpleTest : List String -simpleTest = %runElab do - let simpleTensor : CTensor [BinTree] Int = ?hole - getTensorContainerRefs simpleTensor \ No newline at end of file diff --git a/src/Data/Tensor/Einsum/einsumnotes.txt b/src/Data/Tensor/Einsum/einsumnotes.txt deleted file mode 100644 index 60e8977..0000000 --- a/src/Data/Tensor/Einsum/einsumnotes.txt +++ /dev/null @@ -1,123 +0,0 @@ -Einstein Summation in Numpy: https://obilaniu6266h16.wordpress.com/2016/02/04/einstein-summation-in-numpy/ -as_strided and sum is all you need: https://jott.live/markdown/as_strided -Basic guide to einsum: https://ajcr.net/Basic-guide-to-einsum/ -Named Tensor Notation: https://arxiv.org/abs/2102.13196 -Functional Einsum: https://www.cambridge.org/core/journals/journal-of-functional-programming/article/domainspecific-tensor-languages/19B95794B66C66E117DFCFC7A21E22D5 -Optimal tensor contraction: https://docs.pytorch.org/docs/stable/generated/torch.einsum.html -Einsum in Depth: https://einsum.joelburget.com/ -The Tensor Cookbook: https://tensorcookbook.com/ -Syntax and semantics of Einsum: https://arxiv.org/pdf/2509.20020v1 - -Einsum for general function broadcasting? - -I.e. if I have a function Tensor [ieva] a -> Tensor [ieva] a, what's the most convenient way to broadcast it to -Tensor [batch, ieva] a -> Tensor [batch, ieva] a? - ----------------------------------------- ------ Einsum examples: ----------------------------------------- - -Are all of these coverred by the idea of "parallel product + hiding"? - -Transpose -> einsum("ij->ji") -Sum -> einsum("ij->") (equals to a fold!) -Column sum -> einsum("ij->j") -Row sum -> einsum("ij->i") -Matrix-vector product -> einsum("ij,j->i") -Matrix-matrix product -> einsum("ij,jk->ik") -Dot product (Vector) -> einsum("i,i->") -Dot product (Matrix) -> einsum("ij,ij->") -Outer product -> einsum("i,j->ij") - -x : TensorA [3, 3, 3] Double -Einsum "iii->i" x = view main diagonal -Einsum "iii->" x = trace (sum elements along diagonal) -Einsum "ijk->" x = sum all elements -Einsum "ijk->kji" x = transpose first and last axis - -y : TensorA [3, 4] Double -Einsum "ii->" x = Invalid -> x is not of the right type -Einsum "ii->ii", x = Invalid -> output subscript included multiple times - -Errors: -* Output subscript can't be included multiple times! -* Every output subscript has to appear in the input - - ----------------------------------------- ------ Interface requirements ----------------------------------------- - Einsum seems to be formed out of a few operations: - - Transpose -> Covered by Naperian - - Sum -> Covered by Algebra - - Dot product -> Covered by Applicative - - Outer product -> No requirement? - - anything else? - ----------------------------------------- ------ Misc thoughts: ----------------------------------------- - -Monad comprehensions? - -TODO perhaps we need Traversables to define Einsum for loops?? -Traversable connection to Applicative: -https://x.com/khoiiiind/status/1925526689339379832 - -Product distributes over sum -Traversable distributes over Applicative - - -Is einsum abount binding? - einsum("ij,jk->ik", M, N) - - Here we bind the tensor M to ij, and N to jk - -Q: SCOPING: Why should scoping of Einsum names be local? -Should it perhaps be global instead? - -Maybe it doesn't matter that we have Einsum "ii" (TensorA [3, 4] a), -perhaps if we want to contract, 3 and 4 should...what? be the same variable? - - -Should einsum work for generalised tensors? ----------------------------------------- ------ Einsum algorithm: ----------------------------------------- -In this example, fix: -shapeX = [100, 4, 5] -shapeY = [100, 5, 6] -x : TensorA shapeX Double -y : TensorA shapeY Double -Einsum "bij,bjk->ik" x y - -Step 1: Parsing, variable binding, and error checking -We want to first collect all the unique axis names 'b', 'i', 'j', 'l' and store tham as a axisNames : UniqueVect m AxisName - -So we want -"b" -> shapeX[0], shapeY[0] -"i" -> shapeX[1] -"j" -> shapeX[2], shapeY[1] -"k" -> shapeY[2] - -AxisName -> List (xs : Vect n Nat, Fin n) - -This is the part where we also check for errors, and inconsistent axis naming - -Step 2) When we have this, there are many tensors we can get out, depending on what the output string and output tensor is - - - - -We onlOnly look at the input string and shapes, i.e. "bij,bjk" shapeX shapeY -and use it to do parsing/error checking, and performing of variable binding. - - -TODO What do we do about ellipsis? - -Ellipsis can either be -a) on the left side of each term -b) on the right side of each term -c) in the middle, in the case of a trace (einsum("i...i->...i", x))? (todo think about this) - -To enable and control broadcasting, use an ellipsis. Default NumPy-style broadcasting is done by adding an ellipsis to the left of each term, like np.einsum('...ii->...i', a). np.einsum('...i->...', a) is like np.sum(a, axis=-1) for array a of any shape. To take the trace along the first and last axes, you can do np.einsum('i...i', a), or to do a matrix-matrix product with the left-most indices instead of rightmost, one can do np.einsum('ij...,jk...->ik...', a, b). \ No newline at end of file diff --git a/src/Data/Tensor/Shape/Axis.idr b/src/Data/Tensor/Shape/Axis.idr new file mode 100644 index 0000000..5f72024 --- /dev/null +++ b/src/Data/Tensor/Shape/Axis.idr @@ -0,0 +1,107 @@ +module Data.Tensor.Shape.Axis + +import Data.Vect.Quantifiers + +import Data.Container.Base +import Misc + +||| The name for an axis is an arbitrary string +public export +AxisName : Type +AxisName = String + +||| An axis is a container together with its name +public export +record Axis where + constructor (~>) + name : AxisName + cont : Cont + +public export infixr 0 ~> -- Constructor for container-based axes +public export infixr 0 ~~> -- 'Constructor' for cubical axes + +public export +rename : Axis -> AxisName -> Axis +rename a str = str ~> a.cont + +||| In some cases we TensorType might need to assign a default name to an axis, +||| one which is internal and will not be exposed to the user. +||| This is the default name for such cases +public export +TTInternalName : AxisName +TTInternalName = "__tensortype_tempaxis__" + +namespace Cubical + ||| A "constructor" for cubical axes + public export + (~~>) : AxisName -> Nat -> Axis + (~~>) axisName n = axisName ~> Vect n + + ||| Follows the pattern of `IsCubical` from `Base.Object.Definition` + public export + data IsCubical : Axis -> Type where + MkIsCubical : (name : AxisName) -> (n : Nat) -> IsCubical (name ~~> n) + + ||| Evidence of axis cubicality -> evidence of underlying container cubicality + public export + toContCubical : {0 a : Axis} -> IsCubical a -> IsCubical a.cont + toContCubical (MkIsCubical _ n) = MkIsCubical n + + ||| Extract the dimension from IsCubical with axis implicit + public export + dimHelper : {0 a : Axis} -> IsCubical a -> Nat + dimHelper (MkIsCubical _ n) = n + + ||| Extract the dimension from an axis which we know is cubical + public export + dim : (0 a : Axis) -> IsCubical a => Nat + dim _ @{ic} = dimHelper ic + + ||| Extract the dimensions of cubical axes, with shape implicit + public export + dimsHelper : {0 shape : Vect r Axis} -> + All IsCubical shape -> List Nat + dimsHelper [] = [] + dimsHelper (ic :: ns) = dimHelper ic :: dimsHelper ns + + ||| Extract the dimensions of cubical axes, with shape explicit + public export + dims : (0 shape : Vect r Axis) -> All IsCubical shape => List Nat + dims _ @{ac} = dimsHelper ac + + ||| Product of all the dimensions of a cubical tensors, i.e. its size + public export + size : (0 shape : Vect r Axis) -> (ac : All IsCubical shape) => Nat + size shape = prod (dims shape) + +namespace Naperian + ||| Follows the pattern of `IsNaperian` from `Base.Object.Definition` + public export + data IsNaperian : Axis -> Type where + MkIsNaperian : (name : AxisName) -> (pos : Type) -> + IsNaperian (name ~> Nap pos) + + ||| Evidence of axis being Naperian -> evidence of container being Naperian + %hint + public export + toContNaperian : {0 a : Axis} -> IsNaperian a -> IsNaperian a.cont + toContNaperian (MkIsNaperian _ pos) = MkIsNaperian pos + + ||| Extract the position type from IsNaperian with axis implicit + public export + LogHelper : {0 a : Axis} -> IsNaperian a => Type + LogHelper @{MkIsNaperian _ pos} = pos + + ||| Extract the position type from an axis which we know is Naperian + public export + Log : (0 a : Axis) -> IsNaperian a => Type + Log a @{inn} = LogHelper @{inn} + +namespace IsConcrete + public export + data IsConcrete : Axis -> Type where + MkIsConcrete : (name : AxisName) -> IsConcrete c -> IsConcrete (name ~> c) + + public export + toContConcrete : {0 a : Axis} -> IsConcrete a -> IsConcrete a.cont + toContConcrete (MkIsConcrete _ ic) = ic \ No newline at end of file diff --git a/src/Data/Tensor/Shape/Shape.idr b/src/Data/Tensor/Shape/Shape.idr new file mode 100644 index 0000000..6359148 --- /dev/null +++ b/src/Data/Tensor/Shape/Shape.idr @@ -0,0 +1,503 @@ +module Data.Tensor.Shape.Shape + +import public Decidable.Equality +import Data.Vect.Elem +import Data.Vect.Quantifiers + +import Data.Container.Base +import Data.Tensor.Shape.Axis +import Misc + +{------------------------------------------------------------------------------- +{------------------------------------------------------------------------------- + +~~~~~~~~~~~~~~~ +Design choices: +~~~~~~~~~~~~~~~ + +1) Persistent axis names. + +Instead of transient axis names - where the names are bound within a function such as `np.einsum("ij->ji", m)` and erased after the said function is evaluated - axis names are a part of the tensor shape, and persist with the lifetime of the said tensor. + +2) Axis declarations persist globally, but are only checked for consistency at call sites. + +In a proper tensor programming language we'd prevent declaration of inconsistent/duplicate axis names to begin with. +Here we opt for a more pragmatic approach: checking consistency locally at each call site, rather than at declaration sites. + +3) Duplicate axis names within a tensor are allowed, as long as the names are consistent. + +Otherwise it would not be clear how to take the diagonal/trace of a matrix while referring only to the axes: they'd have to have different names. + +4) Tensor contractions allow duplicate names both in the input and in the output, again, as long as they're consistent. Names in output which haven't appeared in the input are also allowed. + +Duplicate input means zipping/reduction, duplicate output means diagonalisation. +This is different from what standard `einsum` allows: it does not permit duplicate names in the output. It also does not not allow names to appear in output which haven't appeared in the input, because their size wouldn't be known. Becuase for us `Axis` refers to both name and size, this works. + +~~~~~~~~~~~~~~~ +Design choices: +~~~~~~~~~~~~~~~ + +Two axes are consistent if they either have different names, or same name and same underlying container. Conversely, they're inconsistent if they're called the same, but refer to the different underlying container. + +It is possible to have inconsistent axes bound declared within the same scope. +Consistency is checked only at call sites. + +Alternatively if we were building a programming languge we'd check consistency with each declaration. That is, writing something like: +```idris +BatchSize1 : Axis +BatchSize1 = "batchSize" ~> Vect 128 + +BatchSize2 : Axis +BatchSize2 = "batchSize" ~> Vect 129 +``` +would throw an error on the line `BatchSize2 = ...` because we're redeclaring "batchSize" which already exists. + +------------------------------------------- + +Similar projects/ideas: +* XArray: https://docs.xarray.dev/en/stable/ (persistent axis names) +* Haliax: https://github.com/marin-community/haliax + +Useful documentation: +* https://obilaniu6266h16.wordpress.com/2016/02/04/einstein-summation-in-numpy/ +* https://nlp.seas.harvard.edu/NamedTensor +* https://einsum.joelburget.com/ + + +-------------------------------------------------------------------------------} +-------------------------------------------------------------------------------} + + +mutual + ||| Datatype defining the shape of a tensor + ||| It is a vector of consistently named axes + public export + data TensorShape : (rank : Nat) ->Type where + Nil : TensorShape 0 + (::) : (a : Axis) -> (as : TensorShape k) -> + (axisConsistent : a `ConsistentWith` as) => + TensorShape (S k) + + ||| An axis is named consistently with respect to an existing tensor shape + ||| if its name does not appear in the shape, or it if it appears and + ||| references the same container + public export + data ConsistentWith : Axis -> TensorShape k -> Type where + NewAxis : {0 a : Axis} -> {0 as : TensorShape k} -> + NotElem a.name as -> + a `ConsistentWith` as + ExistingAxis : {0 a : Axis} -> {0 as : TensorShape k} -> + (e : Elem a.name as) -> + index as a.name = a.cont -> + a `ConsistentWith` as + + ||| A proof that an axis name `a` is found in a tensor shape `as` + ||| Notably this could have been implemented to check whether the + ||| *entire axis* is in vector, and not just the name. + ||| But since a tensor shape is more akin to a dictionary, keeping this form + public export + data Elem : (axisName : AxisName) -> (as : TensorShape rank) -> Type where + Here : {0 as : TensorShape rank} -> + a `ConsistentWith` as => + axisName = a.name => + Elem axisName (a :: as) + There : {0 a : Axis} -> {0 as : TensorShape rank} -> + a `ConsistentWith` as => + (elem : Elem axisName as) => + Elem axisName (a :: as) + + ||| A proof that an axis name `a` is not found in a tensor shape `as` + public export + data NotElem : (axisElem : AxisName) -> (as : TensorShape rank) -> Type where + NotInEmpty : NotElem axisElem [] + NotInNonEmpty : {0 axisElem : AxisName} -> {0 a : Axis} -> + {0 as : TensorShape rank} -> + (neq : IsNo (decEq axisElem a.name)) -> + (notElem : NotElem axisElem as) => + (a `ConsistentWith` as) -> + NotElem axisElem (a :: as) + + + ||| Indexing into a tensor shape. + ||| Could be many occurences - recovers the one provided by `isElem` + public export + index : (shape : TensorShape rank) -> + (axisName : AxisName) -> + (isElem : Elem axisName shape) => + Cont + index (a :: _) axisName @{Here} = a.cont + index (_ :: as) axisName @{There} = index as axisName + +||| to get rid of believe_me this might need to be put in a mutual block too +public export +rename : (shape : TensorShape rank) -> + (axisName : AxisName) -> + (newAxisName : AxisName) -> + TensorShape rank +rename [] _ _ = [] +rename (a :: as) axisName newAxisName + = (::) (applyWhen (axisName == a.name) (flip rename newAxisName) a) (rename as axisName newAxisName) @{believe_me "consistentAfterRenaming"} + +namespace RenameByIndex + ||| to get rid of believe_me this might need to be put in a mutual block too + public export + rename : (shape : TensorShape rank) -> + (axisIndex : Fin rank) -> + (newAxisName : AxisName) -> + TensorShape rank + rename (a :: as) FZ newAxisName + = (::) (newAxisName ~> a.cont) as @{believe_me "consistentAfterRenamingByIndex"} + rename (a :: as) (FS axisIndex) newAxisName + = (::) a (rename as axisIndex newAxisName) @{believe_me "consistentAfterRenamingByIndex"} + +||| These are quantifiers for the axes +||| Sometimes we need to explicitly quantify over something involving a name +namespace Quantifiers + public export + data All : (p : Axis -> Type) -> TensorShape k -> Type where + Nil : All p [] + (::) : {0 a : Axis} -> {0 as : TensorShape k} -> + p a -> All p as -> + a `ConsistentWith` as => + All p (a :: as) + + public export + data Any : (p : Axis -> Type) -> TensorShape k -> Type where + Here : {0 a : Axis} -> {0 as : TensorShape k} -> + a `ConsistentWith` as => + p a -> Any p (a :: as) + There : {0 a : Axis} -> {0 as : TensorShape k} -> + a `ConsistentWith` as => + Any p as -> Any p (a :: as) + + ||| These quantifiers are specifically for underlying containers, not axes + ||| In theory we should be able to overload, but what do we do with + ||| purposefully overloaded instances `IsCubical`, `IsNaperian` whose + ||| existence will prevent elaboration? + namespace QuantifierOnContainers + public export + data AllC : (p : Cont -> Type) -> TensorShape k -> Type where + Nil : AllC p [] + (::) : {0 a : Axis} -> {0 as : TensorShape k} -> + p a.cont -> AllC p as -> + a `ConsistentWith` as => + AllC p (a :: as) + + public export + data AnyC : (p : Cont -> Type) -> TensorShape k -> Type where + Here : {0 a : Axis} -> {0 as : TensorShape k} -> + a `ConsistentWith` as => + p a.cont -> AnyC p (a :: as) + There : {0 a : Axis} -> {0 as : TensorShape k} -> + a `ConsistentWith` as => + AnyC p as -> AnyC p (a :: as) + +public export +tensorShapesConsistent : TensorShape k -> TensorShape k' -> Type +tensorShapesConsistent s1 s2 = All (\a => a `ConsistentWith` s2) s1 + +||| (::) here pattern matches on the proof `axisConsistent` and discards it +public export +toVect : TensorShape k -> Vect k Axis +toVect [] = [] +toVect (a :: as) = a :: toVect as + +public export +toList : TensorShape k -> List Axis +toList [] = [] +toList (a :: as) = a :: toList as + +||| Convenience function, turns it also into a list +||| Because `Tensor` from `Data.Container` uses lists with tensors +public export +conts : TensorShape k -> List Cont +conts ts = cont <$> toList ts + +||| Renaming the shape preserves the underlying data +public export +renamePreservesConts : (shape : TensorShape rank) -> + (axisName : AxisName) -> + (newAxisName : AxisName) -> + conts (rename shape axisName newAxisName) = conts shape +renamePreservesConts [] _ _ = Refl +renamePreservesConts (a :: as) axisName newAxisName with (axisName == a.name) + _ | True = cong (a.cont ::) (renamePreservesConts as axisName newAxisName) + _ | False = cong (a.cont ::) (renamePreservesConts as axisName newAxisName) + +namespace RenameByIndex + ||| Renaming a shape at a specific index preserves the underlying containers. + public export + renamePreservesConts : (shape : TensorShape rank) -> + (axisIndex : Fin rank) -> + (newAxisName : AxisName) -> + conts (rename shape axisIndex newAxisName) = conts shape + renamePreservesConts (a :: as) FZ newAxisName = Refl + renamePreservesConts (a :: as) (FS axisIndex) newAxisName + = cong (a.cont ::) (renamePreservesConts as axisIndex newAxisName) + +||| Names of the axes in a tensor shape +public export +axisNames : TensorShape k -> Vect k AxisName +axisNames ts = name <$> toVect ts + +||| Sizes of the axes in a tensor shape +public export +axisSizes : TensorShape k -> Vect k Cont +axisSizes ts = cont <$> toVect ts + +||| Size of a tensor shape, i.e. its number of elements +public export +size : (shape : TensorShape k) -> All IsCubical (conts shape) => Nat +size shape = size (conts shape) + +||| Cubicality evidence for a tensor shape, using `Either` so that +||| auto-search tries `Left prf` (positive evidence) before `Right ()` +||| (fallback). Standard `Maybe` can't be used because `Nothing` is +||| listed first in its definition and auto-search would always pick it. +public export +0 TensorCubEvidence : TensorShape k -> Type +TensorCubEvidence shape = Either (All IsCubical shape) () + + + + + +namespace Unique + ||| A proof that an axis name only appears once in the tensor shape + public export + data UniqueElem : AxisName -> TensorShape rank -> Type where + Here : {0 as : TensorShape rank} -> + axisName = ax.name => + NotElem axisName as => + ax `ConsistentWith` as => + UniqueElem axisName (ax :: as) + There : {0 ax : Axis} -> {0 as : TensorShape rank} -> + IsSucc rank => + (uniqueElem : UniqueElem axisName as) => + (neq : IsNo (decEq axisName ax.name)) => + ax `ConsistentWith` as => + UniqueElem axisName (ax :: as) + + ||| Forgets that the axis name only appears once in the tensor shape + public export + forgetUnique : {as : TensorShape rank} -> + UniqueElem axisName as -> + Elem axisName as + forgetUnique {as = (a :: as)} Here = Here + forgetUnique {as = (a :: as)} (There {uniqueElem=elem}) + = There {elem=forgetUnique elem} + + public export + index : (shape : TensorShape rank) -> + (axisName : AxisName) -> + (uniqueElem : UniqueElem axisName shape) => + Cont + index (a :: _) axisName @{Here} = a.cont + index (_ :: as) axisName @{There} = index as axisName + + mutual + public export + removeAxis : {rank : Nat} -> + (toRemove : AxisName) -> + (shape : TensorShape (S rank)) -> + (is : UniqueElem toRemove shape) => + TensorShape rank + removeAxis toRemove (_ :: as) @{Here} = as + removeAxis toRemove (a :: as) @{There @{ItIsSucc}} + = let cProof = consistentAfterRemoving a as toRemove + in a :: removeAxis toRemove as + + public export + consistentAfterRemoving : {rank : Nat} -> + (a : Axis) -> (as : TensorShape (S rank)) -> + a `ConsistentWith` as => + (toRemove : AxisName) -> + (uElem : UniqueElem toRemove as) => + a `ConsistentWith` (removeAxis toRemove as) + consistentAfterRemoving = believe_me "consistentAfterRemoving" + +notElemExample1 : NotElem "i" ["g" ~> List, "j" ~> BinTree] +notElemExample1 = %search + +tensorShapeTest1 : TensorShape 2 +tensorShapeTest1 = ["batchSize" ~> Vect 128, "seqLen" ~> List] + +tensorShapeTest2 : TensorShape 3 +tensorShapeTest2 + = ["batchSize" ~> Vect 128, "seqLen" ~> List, "batchSize" ~> Vect 128] + +failing + tensorShapeTest3 : TensorShape 2 + tensorShapeTest3 = ["batchSize" ~> Vect 128, "batchSize" ~> Vect 13] + +uniqueElemExample1 : UniqueElem "j" ["i" ~> List, "j" ~> BinTree, "i" ~> List] +uniqueElemExample1 = %search + +failing + uniqueElemExampleFail : UniqueElem "x" ["i" ~> List, "j" ~> BinTree] + uniqueElemExampleFail = %search + + uniqueElemExampleFail2 : UniqueElem "i" [ "i" ~> List + , "j" ~> BinTree + , "i" ~> List] + uniqueElemExampleFail2 = %search + +public export +TensorTest1 : TensorShape 3 +TensorTest1 = ["batchSize" ~> Vect 128, "seqLen" ~> List, "feat" ~> Vect 64] + +-- Here proof search does not work (via keycommand), but `%search` does +public export +TensorTest2 : (i : Axis) -> ConsistentWith i [i] +TensorTest2 i = %search + +failing + TensorElemTest2 : Elem "asdf" TensorTest1 + TensorElemTest2 = %search + +{- +public export +countOccurence : AxisName -> TensorShape rank -> Nat +countOccurence str [] = 0 +countOccurence str (a :: as) = case str == a.name of + True => 1 + countOccurence str as + False => countOccurence str as + +public export +countAfterRemoving : AxisName -> TensorShape rank -> Nat +countAfterRemoving str [] = 0 +countAfterRemoving str (a :: as) with (decEq str a.name) + _ | (Yes _) = countAfterRemoving str as + _ | (No _) = 1 + countAfterRemoving str as + +public export +countLessThanRank : (axisN : AxisName) -> + (shape : TensorShape rank) -> + LTE (countOccurence axisN shape) rank +countLessThanRank axisN [] = LTEZero +countLessThanRank axisN ((name ~> _) :: as) with (axisN == name) + _ | True = LTESucc (countLessThanRank axisN as) + _ | False = lteSuccRight (countLessThanRank axisN as) + +-- mutual +-- public export +-- removeAxis : (toRemove : AxisName) -> +-- (shape : TensorShape rank) -> +-- TensorShape (countAfterRemoving toRemove shape) +-- removeAxis toRemove [] = [] +-- removeAxis toRemove (a :: as) with (decEq toRemove a.name) +-- _ | (Yes prf) = removeAxis toRemove as +-- _ | (No contra) = let cProof = consistentAfterRemoving a as toRemove +-- in a :: removeAxis toRemove as +-- +-- ||| Proof that, if `a` is consistent with the shape `as`, then removing any +-- ||| axis from `ss` will still keep `s` consistent with the result +-- public export +-- consistentAfterRemoving : (a : Axis) -> (as : TensorShape rank) -> +-- a `ConsistentWith` as => +-- (toRemove : AxisName) -> +-- a `ConsistentWith` (removeAxis toRemove as) +-- consistentAfterRemoving _ [] _ = NewAxis NotInEmpty +-- consistentAfterRemoving a (ax :: as) @{(NewAxis (NotInNonEmpty neq axisConsistent))} toRemove = ?eifhh_2 +-- consistentAfterRemoving a (ax :: as) @{(ExistingAxis e prf)} toRemove = ?eifhh_1 +-} + + + + +{- +||| Proof that an axis name appears in a tensor shape +||| The proof indirectly carries data of the first index of the occurence +public export +data InShape : AxisName -> TensorShape rank -> Type where + Here : {as : TensorShape rank} -> + (axisName ~> c) `ConsistentWith` as => -- todo add maybe inshape? + -- isno? + -- implement Elem and NotElem, and use them here? + InShape axisName ((axisName ~> c) :: as) + There : {as : TensorShape rank} -> (is : InShape axisName as) => + a `ConsistentWith` as => + InShape axisName (a :: as) + +-- ||| TODO rethink this function? +-- ||| In a tensor shape removes all but the first occurence of an axis +-- ||| removeDuplicates ["x" ~> 1, "y" ~> 3, "x" ~> 1] "x" = ["x" ~> 1, "y" ~> 1] +-- public export +-- removeDuplicates : {n, rank : Nat} -> LTE n rank => +-- (shape : TensorShape rank) -> +-- (axisName : AxisName) -> +-- (inShape : InShape axisName shape n) => +-- IsSucc n => +-- (m : Nat ** TensorShape m) +-- removeDuplicates shape axisName {inShape} {n = 1} +-- = (rank ** shape) +-- removeDuplicates ((_ ~> a) :: as) axisName {inShape = Here @{is}} {n = (S (S k))} +-- = removeDuplicates as axisName {inShape=is} +-- removeDuplicates (s :: as) axisName {inShape = There @{is}} {n = (S (S k))} +-- = let (m ** as') = removeDuplicates as axisName {inShape=is} +-- in (S m ** (::) {axisConsistent=(believe_me ())} s as') + +-- Does tensor contraction allow duplicate axis names +-- * in the input (yes, this is what Einsum also allows) +-- * in the output (no, because otherwise its not clear what should happen) +-- * this means that we can't write `einsum("i,i->ii")` +-- 3) How does contraction work? +-- 3.1) Given `t : Tensor [BatchSize, BatchSize] Double`, what is `dotGeneral t`? +-- +-- Need to figure out how `reduce name t` acts when: +-- 1) `name="BatchSize"` and `t : Tensor [BatchSize, BatchSize] Double` +-- - Should sum up the diagonal? +-- 2) `name="BatchSize"` and `t : Tensor [BatchSize] Double` +-- - Should sum up the vector? +-- 3) `name="BatchSize"` and `t : Tensor [BatchSize, SeqLen, BatchSize] Double` +-- - Should sum up the diagonal slices of SeqLen +-- +-- I suppose this is about iterators +-- iterating through + + + -- ||| If an axis `i` can be added into a singleton list `[j]`, then + -- ||| the axis `j` can be added into a singleton list `[i]` + -- public export + -- axisConsistentSym : {i, j : Axis} -> + -- ConsistentWith i [j] -> ConsistentWith j [i] + -- axisConsistentSym (NewAxis ne) = NewAxis (notElemSym ne) + -- -- For some reason we can't pattern match on `Here`? The proof should still + -- -- be fine... + -- axisConsistentSym (ExistingAxis (There Here) _) impossible + -- axisConsistentSym (ExistingAxis (There (There later)) _) impossible + + + +-- public export +-- data InShape : AxisName -> TensorShape rank -> (n : Nat) -> +-- (ltee : LTE n rank) => Type where +-- Here : {0 n, rank : Nat} -> (lte : LTE n rank) => +-- {as : TensorShape rank} -> InShape axisName as n => +-- (axisName ~> c) `ConsistentWith` as => +-- InShape {rank=S rank} axisName ((axisName ~> c) :: as) (S n) @{LTESucc lte} +-- There : {0 n, rank : Nat} -> (lte : LTE n rank) => +-- {as : TensorShape rank} -> (is : InShape axisName as n) => +-- a `ConsistentWith` as => +-- InShape {rank=S rank} axisName (a :: as) n @{lteSuccRight lte} +-- +-- +-- tensorShapeTest11 : TensorShape 2 +-- tensorShapeTest11 = ["batchSize" ~> Vect 128, "seqLen" ~> List] +-- +-- -- ttt : (n : Nat ** InShape "batchSize" tensorShapeTest11 n @{LTEZero}) +-- +-- mutual +-- public export +-- removeAxis : {n, rank : Nat} -> (lte : LTE n rank) => +-- (shape : TensorShape rank) -> +-- (toRemove : AxisName) -> +-- (inShape : InShape {rank=rank} toRemove shape n @{lte}) => +-- TensorShape (rank `minus` n) +-- removeAxis {rank = 0} shape _ = shape +-- removeAxis {n=S _, rank = S _} ((toRemove ~> c) :: as) toRemove {inShape=Here @{lte} @{is}} +-- = removeAxis as toRemove +-- removeAxis {rank = (S rank')} (a :: as) toRemove {inShape=There @{lte} @{is}} +-- = rewrite minusSuccLTE lte in +-- (let consistencyProof = consistentAfterRemoving a as toRemove {is=is} +-- in a :: removeAxis as toRemove) \ No newline at end of file diff --git a/src/Data/Tensor/Tensor.idr b/src/Data/Tensor/Tensor.idr index daad991..aeeca35 100644 --- a/src/Data/Tensor/Tensor.idr +++ b/src/Data/Tensor/Tensor.idr @@ -14,19 +14,28 @@ import public Data.Container.Base.Object.Instances as Cont import public Data.Num import public Data.Layout -import public Data.Tensor.Axis +import public Data.Tensor.Shape.Axis +import public Data.Tensor.Shape.Shape + import public Misc +import Data.Container.Base.Display2D.CharacterMap +import Data.List.Quantifiers %hide Syntax.WithProof.prefix.(@@) -- used here for indexing {------------------------------------------------------------------------------- {------------------------------------------------------------------------------- -This file defines the main datatype of this repository: `Tensor`, and -utilities and instances for working with it. `Tensor` implements and generalies +This file defines the main datatype of this repository: `Tensor`, and utilities +for working with it. + +`Tensor` implements and generalies 1) `np.array` from NumPy 2) `torch.Tensor` from PyTorch 3) `tf.Tensor` from TensorFlow -to name a few. In this file `Tensor` is simply a wrapper around the extension of an eponymous container: `Cont.Tensor` which itself is simply a composition of containers. +to name a few. + +In this file `Tensor` is a wrapper around the extension of an eponymous container (`Cont.Tensor`) which also provides functionality for working with +axis names. Provided instances for `Tensor` include: Functor, Applicative, Foldable, Naperian, Algebra, Eq, Show, Num, Neg, Abs, @@ -72,6 +81,39 @@ public export (0 t : Tensor shape a) -> Vect rank Cont (.sizes) _ = axisSizes shape +public export +(.indexAxis) : {shape : TensorShape rank} -> + (0 t : Tensor shape a) -> + (axisName : AxisName) -> + (isElem : Elem axisName shape) => + Cont +(.indexAxis) _ axisName = index shape axisName + +public export +(.renameAxis) : {shape : TensorShape rank} -> + (t : Tensor shape a) -> + (axisName : AxisName) -> + (newAxisName : AxisName) -> + Elem axisName shape => + Tensor (rename shape axisName newAxisName) a +(.renameAxis) (MkT t) axisName newAxisName + = MkT $ replace + {p = \cs => Ext (Cont.Tensor cs) a} + (sym $ renamePreservesConts shape axisName newAxisName) + t + +namespace RenameByIndex + public export + (.rename) : {shape : TensorShape rank} -> + (t : Tensor shape a) -> + (axisIndex : Fin rank) -> + (newAxisName : AxisName) -> + Tensor (Data.Tensor.Shape.Shape.RenameByIndex.rename shape axisIndex newAxisName) a + (.rename) (MkT t) axisIndex newAxisName = MkT $ replace + {p = \cs => Ext (Cont.Tensor cs) a} + (sym $ RenameByIndex.renamePreservesConts shape axisIndex newAxisName) + t + namespace SomeTesting public export BatchSize : Axis @@ -112,11 +154,11 @@ Functor (Tensor shape) where namespace NestedTensorUtils public export extract : Tensor [] a -> a - extract (MkT t) = extract t + extract (MkT t) = #> t public export embed : a -> Tensor [] a - embed a = MkT (toScalar a) + embed a = MkT (># a) ||| With the added data of the wrapper around (Ext (Tensor shape) a), this ||| effectively states a list version of the following isomorphism @@ -141,15 +183,15 @@ namespace NestedTensorUtils ||| but it requires non-erased `c` and `cs` public export extractTopExt : {0 cs : TensorShape rank} -> - NewAxisConsistent c cs => - Tensor (c :: cs) a -> Ext (c.cont) (Tensor cs a) + ConsistentWith c cs => + Tensor (c :: cs) a -> Ext c.cont (Tensor cs a) extractTopExt (MkT (sh <| ind)) = shapeExt sh <| \p => MkT $ index sh p <| \p' => ind (p ** p') public export embedTopExt : {0 cs : TensorShape rank} -> - NewAxisConsistent c cs => - Ext (c.cont) (Tensor cs a) -> Tensor (c :: cs) a + ConsistentWith c cs => + Ext c.cont (Tensor cs a) -> Tensor (c :: cs) a embedTopExt e = let tp = GetT . index e in MkT $ (shapeExt e <| shapeExt . tp) <| \(p ** p') => index (tp p) p' @@ -157,22 +199,22 @@ namespace NestedTensorUtils ||| This is useful because container composition adds non-trivial data to the ||| vector type (i.e. `c >@ Scalar` is not equal to `c`) public export - extToVector : Ext (c.cont) a -> Tensor [c] a + extToVector : Ext c.cont a -> Tensor [c] a extToVector e = MkT $ (shapeExt e <| \_ => ()) <| \(cp ** ()) => index e cp public export - vectorToExt : Tensor [c] a -> Ext (c.cont) a + vectorToExt : Tensor [c] a -> Ext c.cont a vectorToExt (MkT t) = shapeExt (shapeExt t) <| \cp => index t (cp ** ()) public export toNestedTensor : {0 cs : TensorShape rank} -> - NewAxisConsistent c cs => + ConsistentWith c cs => Tensor (c :: cs) a -> Tensor [c] (Tensor cs a) toNestedTensor = extToVector . extractTopExt public export fromNestedTensor : {0 cs : TensorShape rank} -> - NewAxisConsistent c cs => + ConsistentWith c cs => Tensor [c] (Tensor cs a) -> Tensor (c :: cs) a fromNestedTensor = embedTopExt . vectorToExt @@ -180,7 +222,7 @@ namespace NestedTensorUtils public export tensorMapFirstAxis : {0 c : Axis} -> {0 cs : TensorShape k} -> {0 ds : TensorShape k'} -> - NewAxisConsistent c cs => NewAxisConsistent c ds => + (ccs : c `ConsistentWith` cs) => (cds : c `ConsistentWith` ds) => (f : Tensor cs a -> Tensor ds a) -> Tensor (c :: cs) a -> Tensor (c :: ds) a tensorMapFirstAxis f = fromNestedTensor . map f . toNestedTensor @@ -191,7 +233,7 @@ namespace NestedTensorUtils public export (<-$>) : {c : Axis} -> {0 cs : TensorShape k} -> {0 ds : TensorShape k'} -> - NewAxisConsistent c cs => NewAxisConsistent c ds => + ConsistentWith c cs => ConsistentWith c ds => (f : Tensor cs a -> Tensor ds a) -> Tensor (c :: cs) a -> Tensor (c :: ds) a (<-$>) = tensorMapFirstAxis @@ -200,52 +242,51 @@ namespace NestedTensorUtils namespace TensorFromConcrete public export concreteTypeTensor : (shape : TensorShape rank) -> - (allConcrete : AllConcrete (conts shape)) => + AllC IsConcrete shape => Type -> Type - concreteTypeTensor [] {allConcrete = []} = concreteType {cont=Scalar} - concreteTypeTensor (c :: cs) {allConcrete = Cons @{fc}} - = (concreteType @{fc}) . (concreteTypeTensor cs) + concreteTypeTensor [] @{[]} = concreteType {c=Scalar} + concreteTypeTensor (a :: as) @{ic :: _} + = concreteType @{ic} . (concreteTypeTensor as) public export concreteTypeFunctor : {shape : TensorShape rank} -> - (allConcrete : AllConcrete (conts shape)) => + AllC IsConcrete shape => Functor (concreteTypeTensor shape) - concreteTypeFunctor {shape = []} {allConcrete = []} - = concreteFunctor {cont=Scalar} - concreteTypeFunctor {shape = (c :: cs)} {allConcrete = Cons @{fc}} - = Functor.Compose @{concreteFunctor @{fc} } @{concreteTypeFunctor} + concreteTypeFunctor {shape = []} @{[]} = concreteFunctor {c=Scalar} + concreteTypeFunctor {shape = (c :: cs)} @{ic :: _} + = Functor.Compose @{concreteFunctor @{ic} } @{concreteTypeFunctor} public export concreteToExtensions : {shape : TensorShape rank} -> - (allConcrete : AllConcrete (conts shape)) => + AllC IsConcrete shape => concreteTypeTensor shape a -> composeExtensions (conts shape) a - concreteToExtensions {shape = []} {allConcrete = []} ct = fromConcreteTy ct - concreteToExtensions {shape = (_ :: _)} {allConcrete = Cons} ct = - concreteToExtensions <$> fromConcreteTy ct + concreteToExtensions {shape = []} @{[]} ct = fromConcreteTy ct + concreteToExtensions {shape = (_ :: _)} @{(ic :: _)} ct = + concreteToExtensions <$> (fromConcreteTy @{ic} ct) public export extensionsToConcreteType : {shape : TensorShape rank} -> - (allConcrete : AllConcrete (conts shape)) => + AllC IsConcrete shape => composeExtensions (conts shape) a -> concreteTypeTensor shape a - extensionsToConcreteType {shape = []} {allConcrete = []} ct = toConcreteTy ct - extensionsToConcreteType {shape = (_ :: _)} {allConcrete = Cons @{fc}} ct - = (map @{concreteFunctor @{fc}} extensionsToConcreteType) (toConcreteTy ct) + extensionsToConcreteType {shape = []} @{[]} ct = toConcreteTy ct + extensionsToConcreteType {shape = (_ :: _)} @{ic :: _} ct + = (map @{concreteFunctor @{ic}} extensionsToConcreteType) (toConcreteTy @{ic} ct) public export toTensor : {shape : TensorShape rank} -> - (allConcrete : AllConcrete (conts shape)) => + AllC IsConcrete shape => concreteTypeTensor shape a -> Tensor shape a toTensor = fromExtensionComposition . concreteToExtensions public export fromTensor : {shape : TensorShape rank} -> - (allConcrete : AllConcrete (conts shape)) => + AllC IsConcrete shape => Tensor shape a -> concreteTypeTensor shape a fromTensor = extensionsToConcreteType . toExtensionComposition ||| Many containers have a `FromConcrete` instance, allowing them to easily ||| be converted to and from a (usually familiar) Idris type - ||| This works with tensors defined as a fold over contianers, but it requires + ||| This works with tensors defined as a fold over containers, but it requires ||| burdensome shape annotations everywhere ||| The decision was made to wrap that fold in `Tensor` as above, and then ||| (as this isn't a container anymore) provide equally named functions like @@ -253,13 +294,13 @@ namespace TensorFromConcrete ||| detect which one needs to be used at call sites public export fromConcreteTy : {shape : TensorShape rank} -> - (allConcrete : AllConcrete (conts shape)) => + AllC IsConcrete shape => concreteTypeTensor shape a -> Tensor shape a fromConcreteTy = toTensor public export toConcreteTy : {shape : TensorShape rank} -> - (allConcrete : AllConcrete (conts shape)) => + AllC IsConcrete shape => Tensor shape a -> concreteTypeTensor shape a toConcreteTy = fromTensor @@ -269,7 +310,7 @@ namespace TensorFromConcrete ||| We read it as a map `>` going into the tensor `#` public export (>#) : {shape : TensorShape rank} -> - (allConcrete : AllConcrete (conts shape)) => + AllC IsConcrete shape => concreteTypeTensor shape a -> Tensor shape a (>#) = fromConcreteTy @@ -277,7 +318,7 @@ namespace TensorFromConcrete ||| We read it as a map `>` going out of the tensor `#` public export (#>) : {shape : TensorShape rank} -> - (allConcrete : AllConcrete (conts shape)) => + AllC IsConcrete shape => Tensor shape a -> concreteTypeTensor shape a (#>) = toConcreteTy @@ -287,8 +328,8 @@ namespace TensorFromConcrete (>#>) : {rankOld, rankNew : Nat} -> {shapeOld : TensorShape rankOld} -> {shapeNew : TensorShape rankNew} -> - (allConcreteOld : AllConcrete (conts shapeOld)) => - (allConcreteNew : AllConcrete (conts shapeNew)) => + AllC IsConcrete shapeOld => + AllC IsConcrete shapeNew => (Tensor shapeOld a -> Tensor shapeNew b) -> concreteTypeTensor shapeOld a -> concreteTypeTensor shapeNew b (>#>) f ct = #> (f (># ct)) @@ -297,8 +338,8 @@ namespace TensorFromConcrete (#>#) : {rankOld, rankNew : Nat} -> {shapeOld : TensorShape rankOld} -> {shapeNew : TensorShape rankNew} -> - (allConcreteOld : AllConcrete (conts shapeOld)) => - (allConcreteNew : AllConcrete (conts shapeNew)) => + AllC IsConcrete shapeOld => + AllC IsConcrete shapeNew => (concreteTypeTensor shapeOld a -> concreteTypeTensor shapeNew b) -> Tensor shapeOld a -> Tensor shapeNew b (#>#) f t = ># (f (#> t)) @@ -310,6 +351,7 @@ namespace Reshape public export restructure : {cs : TensorShape oldRank} -> {ds : TensorShape newRank} -> Cont.Tensor (conts cs) =%> Cont.Tensor (conts ds) -> + tensorShapesConsistent cs ds => Tensor cs a -> Tensor ds a restructure f = MkT . extMap f . GetT @@ -323,9 +365,10 @@ namespace Reshape public export reshape : {oldShape : TensorShape oldRank} -> {newShape : TensorShape newRank} -> - (oldCub : All IsCubical (conts oldShape)) => (newCub : All IsCubical (conts newShape)) => + All IsCubical (conts oldShape) => All IsCubical (conts newShape) => Tensor oldShape a -> - {auto prf : size (conts oldShape) = size (conts newShape)} -> + (size (conts oldShape) = size (conts newShape)) => + tensorShapesConsistent oldShape newShape => Tensor newShape a reshape t = restructure (reshape DefaultLayoutOrder) t @@ -334,24 +377,20 @@ namespace Reshape -- treeExample1 = ># Node 60 (Node 7 (Leaf (-42)) (Leaf 46)) (Leaf 2) ||| Performs an in-order traversal of a binary tree tensor into a list tensor - public export - traversalExample : Tensor ["binTree" ~> BinTree] Double -> - Tensor ["list" ~> List] Double - traversalExample = restructure (wrapIntoVector inorder) - - -- ||| Down the line, we'll also want to adjust how we perform this - -- ||| transformation depending on the device we perform the computation on. - + -- public export + -- traversalExample2 : Tensor ["binTree" ~> BinTree] Double -> + -- Tensor ["list" ~> List] Double + -- traversalExample2 = restructure (wrapIntoVector inorder) namespace TensorInstances namespace ApplicativeInstance public export tensorReplicate : {shape : TensorShape rank} -> - (allAppl : All TensorMonoid (conts shape)) => + (allAppl : AllC TensorMonoid shape) => (x : a) -> Tensor shape a tensorReplicate {shape = []} = embed - tensorReplicate {shape = (_ :: _), allAppl = (::) _ _} + tensorReplicate {shape = (_ :: _), allAppl = _ :: _} = fromExtensionComposition . pure . toExtensionComposition @@ -359,17 +398,17 @@ namespace TensorInstances public export liftA2Tensor : {shape : TensorShape rank} -> - (allAppl : All TensorMonoid (conts shape)) => + (allAppl : AllC TensorMonoid shape) => Tensor shape a -> Tensor shape b -> Tensor shape (a, b) liftA2Tensor {shape = [], allAppl=[]} (MkT t) (MkT t') = embed (index t (), index t' ()) - liftA2Tensor {shape = (s :: ss), allAppl = (::) _ _} t t' + liftA2Tensor {shape = (s :: ss), allAppl = _ :: _} t t' = embedTopExt $ uncurry liftA2Tensor <$> liftA2 (extractTopExt t) (extractTopExt t') public export {shape : TensorShape rank} -> - (allAppl : All TensorMonoid (conts shape)) => + (allAppl : AllC TensorMonoid shape) => Applicative (Tensor shape) where pure = tensorReplicate fs <*> xs = uncurry ($) <$> liftA2Tensor fs xs @@ -381,7 +420,7 @@ namespace TensorInstances Nil : Eq a => AllEq [] a Cons : {c : Axis} -> {cs : TensorShape k} -> (eq : Eq (c.cont `fullOf` Tensor cs a)) => -- hmm, can be simplified? this would cause unification regarding AllConsistent to become much simpler? - (ne : NewAxisConsistent c cs) => + (ne : c `ConsistentWith` cs) => AllEq (c :: cs) a public export @@ -401,7 +440,7 @@ namespace TensorInstances %hint public export interfacePosEq : {n : Nat} -> InterfaceOnPositions (Cont.Tensor [Vect n]) Eq - interfacePosEq = MkI -- follows from Data.DPair L57 + interfacePosEq = MkI %search -- follows from Data.DPair L57 -- public export -- vectInterfacePos : {n : Nat} -> InterfaceOnPositions (Vect n) DecEq @@ -410,7 +449,7 @@ namespace TensorInstances namespace NumericInstances public export {shape : TensorShape rank} -> - Num a => All TensorMonoid (conts shape) => + Num a => AllC TensorMonoid shape => Num (Tensor shape a) where fromInteger = tensorReplicate . fromInteger t + t' = uncurry (+) <$> liftA2Tensor t t' @@ -418,7 +457,7 @@ namespace TensorInstances public export {shape : TensorShape rank} -> - Neg a => All TensorMonoid (conts shape) => + Neg a => AllC TensorMonoid shape => Neg (Tensor shape a) where negate = (negate <$>) xs - ys = (uncurry (-)) <$> liftA2 xs ys @@ -430,20 +469,20 @@ namespace TensorInstances public export {shape : TensorShape rank} -> - Abs a => All TensorMonoid (conts shape) => + Abs a => AllC TensorMonoid shape => Abs (Tensor shape a) where abs = (abs <$>) public export {shape : TensorShape rank} -> - Fractional a => All TensorMonoid (conts shape) => + Fractional a => AllC TensorMonoid shape => Fractional (Tensor shape a) where t / v = (uncurry (/)) <$> liftA2 t v public export {shape : TensorShape rank} -> Exp a => - All TensorMonoid (conts shape) => + AllC TensorMonoid shape => Exp (Tensor shape a) where exp = (exp <$>) log = (log <$>) @@ -452,23 +491,53 @@ namespace TensorInstances public export {shape : TensorShape rank} -> FromDouble a => - All TensorMonoid (conts shape) => + AllC TensorMonoid shape => FromDouble (Tensor shape a) where fromDouble x = tensorReplicate (fromDouble x) namespace DiagonalAxis + ||| Captures both "diagonal" operation for vector (Naperian containers) and + ||| "concat"-like operation for lists public export - diagonal : {i : Axis} -> + join : {i : Axis} -> SeqMonoid i.cont => (t : Tensor [i, i] a) -> - IsNaperian i.cont => TensorMonoid (i.cont) => Tensor [i] a - diagonal t = restructure diagonal t + join = restructure join + ||| Alias for `join` for Naperian containers public export - tDiag : Tensor ["i" ~~> 2, "i" ~~> 2] Double - tDiag = ># [ [100, 0] - , [0, 47] ] + diagonal : {i : Axis} -> IsNaperian i.cont => + (t : Tensor [i, i] a) -> + Tensor [i] a + diagonal = join + + -- public export + -- diagonalise : {toDiagonalise : AxisName} -> {shape : TensorShape rank} -> + -- (t : Tensor shape a) -> + -- All IsNaperian shape => + -- (inShape : InShape toDiagonalise shape) => + -- Tensor (snd $ removeDuplicates shape toDiagonalise {inShape=inShape}) a + -- diagonalise t = ?diagonalise_rhs + + + + + -- for codiagonal we need zeros + public export + codiagonal : Num a => {i : Axis} -> + Tensor [i] a -> Tensor [i, i] a + codiagonal t = ?codiagonal_rhs + + public export + tDiagVect : Tensor ["i" ~~> 2, "j" ~~> 2] Double + tDiagVect = ># [ [1, 2] + , [3, 4] ] + public export + tDiagList : Tensor ["i" ~> List, "j" ~> List] Double + tDiagList = ># [ [10, 20, 30] + , [3, 4] ] + --public export --diagonal : {k : Nat} -> {shape : TensorShape rank} -> -- (t : Tensor shape a) -> @@ -500,7 +569,7 @@ namespace TensorInstances Cons : {c : Axis} -> {cs : TensorShape k} -> (alg : Algebra (Ext c.cont) (Tensor cs a)) => (rest : AllAlgebra cs a) => - NewAxisConsistent c cs => + ConsistentWith c cs => AllAlgebra (c :: cs) a {- @@ -526,6 +595,13 @@ namespace TensorInstances reduceTensor {allAlg = []} = extract reduceTensor {allAlg = Cons} = reduceTensor . reduce . extractTopExt + public export + reduceFirstAxis : {shape : TensorShape rank} -> + (alg : Algebra (Ext s.cont) (Tensor shape a)) => + ConsistentWith s shape => + Tensor (s :: shape) a -> Tensor shape a + reduceFirstAxis = reduce . extractTopExt + public export {shape : TensorShape rank} -> @@ -542,23 +618,46 @@ namespace TensorInstances -- ||| Since we have non-unique axis labels, this likely needs to be -- ||| implemented after `dot` namespace ReduceAxis - ||| Takes in a tensor `t` and an axis name which we want to reduce along. - ||| Returns a new tensor with all occurences of this axis summed over, - ||| correctly zipping if this axis appears multiple times. - ||| We also ensure that this function can only be called if the axis truly - ||| appears in the tensor, and if its underlying container is finite. - ||| From finite we can get algebra - public export - reduceAxis : {shape : TensorShape rank} -> + + ||| Reduces a tensor along an axis appearing only once in the shape + ||| In the presence of multiple axes (at least in the Naperian case) we + ||| first have to transpose them to the front, and then diagonalise. + ||| Using `IsFinite` instead of `Algebra` because `IsFinite` allows us to + ||| refer to only the container, instead of the underlying type `a` + public export + reduceSingleAxis : {rank : Nat} -> + {shape : TensorShape (S rank)} -> + (toReduce : AxisName) -> + (uElem : UniqueElem toReduce shape) => (t : Tensor shape a) -> - (atm : All TensorMonoid (conts shape)) => - (nap : All IsNaperian (conts shape)) => - {k : Nat} -> - (toReduce : AxisName) -> (inShape : InShape toReduce shape k) => - IsSucc k => + AllC TensorMonoid shape => Num a => - (isFinite : IsFinite (shape.getByName toReduce inShape).cont) => - Tensor (snd $ removeAllOccurrences shape toReduce {inShape=inShape}) a + (isFinite : IsFinite (index shape toReduce)) => + Tensor (Unique.removeAxis toReduce shape) a + reduceSingleAxis {shape = (ax :: as)} toReduce t @{Here} @{_ :: tms} + = let tAlg = algebraFinite ax.cont @{isFinite} (Tensor as a) + in reduce (extractTopExt t) + reduceSingleAxis {shape = (ax :: as)} toReduce t @{There @{ItIsSucc}} @{_ :: _} {isFinite=isFinite} + = tensorMapFirstAxis {cds=consistentAfterRemoving ax as toReduce} + (\t' => reduceSingleAxis toReduce t' {isFinite=isFinite}) t + + + -- ||| Takes in a tensor `t` and an axis name which we want to reduce along. + -- ||| Returns a new tensor with all occurences of this axis summed over, + -- ||| correctly zipping if this axis appears multiple times. + -- ||| We also ensure that this function can only be called if the axis truly + -- ||| appears in the tensor, and if its underlying container is finite. + -- ||| From finite we can get algebra + -- ||| Todo what is the best way to think about this, via algebra or finite? + -- public export + -- reduceAxis : {shape : TensorShape rank} -> + -- (t : Tensor shape a) -> + -- (atm : All TensorMonoid (conts shape)) => + -- (nap : All IsNaperian (conts shape)) => + -- (toReduce : AxisName) -> (inShape : Elem toReduce shape) => + -- Num a => + -- (isFinite : IsFinite (index shape toReduce)) => + -- Tensor (removeAxis toReduce shape) a -- reduceAxis {shape = ((toReduce ~> c) :: as)} {atm=(tm :: tms)} t toReduce {inShape = Here} {k=S k'} {nap} with (k') -- _ | 0 = reduce @{algebraFinite c {isFinite=isFinite} (Tensor as a)} (extractTopExt t) -- this is the last axis to reduce -- reduceAxis {shape = ((toReduce ~> (Nap pos)) :: as)} {atm=(tm :: tms)} t toReduce {inShape = Here} {k=S k'} {nap = (MkIsNaperian pos)} | (S k'') = ?redddd_2_S_0 -- there is at least one more axis after this @@ -705,49 +804,44 @@ namespace TensorInstances -- {shape : List Cont} -> -- Algebra (CTensor shape) a => Algebra (CTensor shape) (CTensor [] a) where -- reduce t = embed $ reduce $ extract <$> t - -t1 : Tensor [4] ["features"] Double -t1 = reduce t0 "batch" - -} namespace FoldableInstance - public export - data AllFoldable : (shape : TensorShape rank) -> Type where - Nil : AllFoldable [] - Cons : {0 cs : TensorShape k} -> - Foldable (Ext c.cont) => - AllFoldable cs => - NewAxisConsistent c cs => - AllFoldable (c :: cs) - public export tensorFoldr : {0 shape : TensorShape rank} -> - (allFoldable : AllFoldable shape) => + (allFoldable : AllC IsFoldable shape) => (a -> acc -> acc) -> acc -> Tensor shape a -> acc tensorFoldr {allFoldable = []} f val t = f (extract t) val - tensorFoldr {allFoldable = Cons} f val t = foldr + tensorFoldr {allFoldable = _ :: _} f val t = foldr (\ct, acc => tensorFoldr f acc ct) val (extractTopExt t) - + public export {shape : TensorShape rank} -> - (allFoldable : AllFoldable shape) => + (allFoldable : AllC IsFoldable shape) => Foldable (Tensor shape) where foldr = tensorFoldr + + -- We need this hint for the example `thisNowWorks` to work + %hint + public export + allFoldableFromAllCubical : {shape : TensorShape k} -> + All IsCubical shape -> AllC IsFoldable shape + allFoldableFromAllCubical {shape=[]} [] = [] + allFoldableFromAllCubical {shape = (_ ~~> n) :: _} (MkIsCubical _ n :: ics) + = %search :: allFoldableFromAllCubical ics + + concreteWorks : Tensor ["a" ~~> 7, "b" ~~> 2] Integer -> Integer + concreteWorks t = foldr (+) 0 t - -- concreteWorks : Tensor [7, 2] ["a", "b"] Integer -> Integer - -- concreteWorks t = foldr (+) 0 t - - -- parametricCTensorWorks : {shape : Vect rank Cont} -> - -- {names : Vect rank String} -> - -- (ac : AllConsistent names shape) => - -- AllFoldable shape => - -- CTensor shape names Integer -> Integer - -- parametricCTensorWorks t = foldr (+) 0 t + parametricCTensorWorks : {shape : TensorShape rank} -> + AllC IsFoldable shape => + Tensor shape Integer -> Integer + parametricCTensorWorks t = foldr (+) 0 t - -- parametricDoesNotWork : {shape : List Nat} -> - -- Tensor shape Integer -> Integer - -- parametricDoesNotWork t = foldr (+) 0 t + thisNowWorks : {shape : TensorShape rank} -> + All IsCubical shape => + Tensor shape Integer -> Integer + thisNowWorks t = foldr (+) 0 t namespace TraversableInstance public export @@ -756,7 +850,7 @@ t1 = reduce t0 "batch" Cons : {0 cs : TensorShape k} -> Traversable (Ext c.cont) => AllTraversable cs => - NewAxisConsistent c cs => + c `ConsistentWith` cs => AllTraversable (c :: cs) public export @@ -771,19 +865,19 @@ t1 = reduce t0 "batch" public export {shape : TensorShape rank} -> (allTraversable : AllTraversable shape) => - (allFoldable : AllFoldable shape) => + (allFoldable : AllC IsFoldable shape) => Traversable (Tensor shape) where traverse = tensorTraverse namespace NaperianInstance public export - transposeMatrix : {i, j : Axis} -> - NewAxisConsistent i [j] => NewAxisConsistent j [i] => + transposeMatrix : {0 i, j : Axis} -> + i `ConsistentWith` [j] => j `ConsistentWith` [i] => (ni : IsNaperian i) => (nj : IsNaperian j) => Tensor [i, j] a -> Tensor [j, i] a - transposeMatrix {ni=(MkIsNaperian _ _)} {nj=(MkIsNaperian _ _)} + transposeMatrix {ni=(MkIsNaperian _ _), nj=(MkIsNaperian _ _)} = restructure transpose - + transposeTest : Tensor ["i" ~~> 2, "j" ~~> 3] a -> Tensor ["j" ~~> 3, "i" ~~> 2] a transposeTest = transposeMatrix @@ -795,60 +889,153 @@ t1 = reduce t0 "batch" {sh : c.cont.Shp} -> Tensor [c] (c.cont.Pos sh) positions = extToVector (positionsCont {sh=sh}) -namespace ShowInstance - public export - data AllShow : (shape : TensorShape rank) -> - (a : Type) -> Type where - Nil : Show a => AllShow [] a - -- for type below, we should be able to define what shExt is without referencing CTensor cs a? - Cons : {0 cs : TensorShape k} -> - Show (c.cont `fullOf` Tensor cs a) => - NewAxisConsistent c cs => - AllShow (c :: cs) a - - public export - show' : {0 rank : Nat} -> - {shape : TensorShape rank} -> - (allShow : AllShow shape a) => - Tensor shape a -> String - show' {allShow = Nil} t = show (extract t) - show' {allShow = Cons @{sh}} t = show (extractTopExt t) - - public export - {shape : TensorShape rank} -> - (allShow : AllShow shape a) => - Show (Tensor shape a) where - show t = show' {allShow = allShow} t - - -- %hint - -- public export - -- allShowCubical : {shape : Vect rank Axis} -> - -- (ac : AxesConsistent shape) => - -- Show a => - -- AllShow shape a - -- allShowCubical {shape=[], ac = []} = Nil - -- allShowCubical {shape=(c :: cs), ac = a::as} - -- = ?allShowCubical_rhs -- Cons @{?oibim} - - -- public export - -- {shape : Vect rank Axis} -> - -- (ac : AllConsistent names (Vect <$> shape)) => - -- Show a => - -- Show (Tensor shape names a) where - -- show t = show' {allShow=allShowCubical} t - -- -- show {shape=(c :: cs)} t = show' {allShow = Cons @{?oiim}} t + namespace ShowInstance + ||| Tensor-context rendering of container extensions. + ||| A separate interface from `Display2D (Ext c a)` because layered + ||| containers need to make separate choices (vertical bracket stacking for + ||| List/Vect, structural tree layout for trees) + public export + interface DisplayNestedCont (0 c : Cont) where + ||| Action a parent of `c` needs to apply to `c` + ||| For instance, trees get boxed if they're multi-line, cubical + ||| containers do not apply any action. Defaults to boxing + boxedForSiblings : Grid Char -> Grid Char + boxedForSiblings = wrapNonEmpty @{DoubleLineBox} + + ||| Given the content of `c` rendered as grids individually, render + ||| the layout given information of how many axes are below, and + ||| information on whether to box children, to keep visual separation + ||| Container can in principle decide whether to use `childBox` or not + displayNestedCont : (axesBelow : Nat) -> + (childBox : Grid Char -> Grid Char) -> + Ext c (Grid Char) -> Grid Char + + public export + DisplayNestedCont List where + boxedForSiblings = id + displayNestedCont 0 _ t = display2D t + displayNestedCont (S k) childBox (_ <| content) = + wrapListBrackets @{AsciiListSyntax} k (childBox <$> toList content) + + public export + {n : Nat} -> DisplayNestedCont (Vect n) where + boxedForSiblings = id + displayNestedCont axesBelow childBox (() <| content) = + displayNestedCont {c = List} axesBelow childBox (n <| content) + + public export + DisplayNestedCont BinTree where + displayNestedCont _ _ t = display2D t + + public export + DisplayNestedCont BinTreeLeaf where + displayNestedCont _ _ t = display2D t + + public export + DisplayNestedCont BinTreeNode where + displayNestedCont _ _ t = display2D t + + public export + DisplayNestedCont Scalar where + displayNestedCont _ _ t = display2D t + + public export + DisplayNestedCont Pair where + boxedForSiblings = id + displayNestedCont _ _ t = display2D t + + public export + data AllDisplay2D : (shape : TensorShape rank) -> (a : Type) -> Type where + Nil : Display2D a => AllDisplay2D [] a + (::) : {k : Nat} -> {0 cs : TensorShape k} -> + DisplayNestedCont c.cont -> -- should this be `Display2D`? + (adTail : AllDisplay2D cs a) -> + ConsistentWith c cs => + (ce : TensorCubEvidence cs) => + AllDisplay2D (c :: cs) a + + ||| Recover the scalar-element `Display2D a` instance from a shape + public export + scalarDisplay : AllDisplay2D shape a => Display2D a + scalarDisplay @{Nil} = %search + scalarDisplay @{(_ :: adTail)} = scalarDisplay @{adTail} - -- showCubical : {shape : List Nat} -> Show a => Tensor shape a -> String - -- showCubical {shape=[]} t = show' {allShow = Nil} t - -- showCubical {shape=(c :: cs)} t = show' {allShow = Cons @{?oiim}} t + ||| Fold through the content of the tensor, and return the maximum + ||| width of the displayed scalar element + public export + maxWidthCubical : {shape : TensorShape rank} -> + (allD : AllDisplay2D shape a) => + AllC IsFoldable shape => + Tensor shape a -> Nat + maxWidthCubical = foldr + (max . gridWidth . display2D @{scalarDisplay @{allD}}) 0 + + ||| Render a cubical tensor given + ||| 1) a specific width of content for each cell + ||| 2) the number of outer bracket levels that need to surround the tensor + public export + renderCubicalWithWidth : {shape : TensorShape rank} -> + (allD : AllDisplay2D shape a) => (allCub : All IsCubical shape) => + (outerWrap, cellWidth : Nat) -> + Tensor shape a -> + Grid Char + renderCubicalWithWidth {allD = Nil} _ cellWidth t = + padGridLeft cellWidth (display2D (extract t)) + renderCubicalWithWidth {allD = (::) {c = _ ~~> n} {cs} _ adTail} + {allCub = MkIsCubical _ _ :: ics} outerWrap cellWidth t = + case extractTopExt t of + (_ <| content) => + let children : List (Grid Char) + children = toList content <&> + renderCubicalWithWidth (S outerWrap) cellWidth + in case cs of + [] => wrappedInnerRow @{AsciiListSyntax} + (defaultLineWidth `minus` 2 * outerWrap) 1 children + _ :: _ => wrapListBrackets @{AsciiListSyntax} 0 + [aboveAllSep padCharacter (pred (length (toList cs))) children] + + ||| Render a tensor we know is cubical: first compute the maximum width of + ||| a scalar element that appears, then render. + public export + display2DTensorCubical : {shape : TensorShape rank} -> + AllDisplay2D shape a => AllC IsFoldable shape => All IsCubical shape => + Tensor shape a -> Grid Char + display2DTensorCubical t = renderCubicalWithWidth 0 (maxWidthCubical t) t + ||| Dispatch between cubical and non-cubical rendering. + ||| Is there a better way than with `TensorCubEvidence`? + public export + dispatchTensorDisplay : {shape : TensorShape rank} -> + AllDisplay2D shape a => + (ce : TensorCubEvidence shape) => + Tensor shape a -> Grid Char + dispatchTensorDisplay @{Nil} t = display2D (extract t) + dispatchTensorDisplay @{allD@(_ :: _)} @{Left prf} t = + display2DTensorCubical @{allD} @{allFoldableFromAllCubical prf} t + dispatchTensorDisplay @{(::) {k} tcd adTail {ce = ceTail}} @{Right _} t = + displayNestedCont k (siblingBox ceTail adTail) + (dispatchTensorDisplay {ce = ceTail} <$> extractTopExt t) + where + ||| Action to apply to a child grid before layering it + siblingBox : TensorCubEvidence ds -> AllDisplay2D ds a -> + (Grid Char -> Grid Char) + siblingBox _ Nil = id + siblingBox (Left _) _ = id + siblingBox (Right _) (tcd :: _) = boxedForSiblings @{tcd} - sst : {shape : TensorShape rank} -> - AllShow shape a => Tensor shape a -> String - sst t = show t + public export + {shape : TensorShape rank} -> + AllDisplay2D shape a => + (ce : TensorCubEvidence shape) => + Display2D (Tensor shape a) where + display2D = dispatchTensorDisplay {ce=ce} - -- sstc : {shape : List Nat} -> Show a => Tensor shape a -> String - -- sstc t = show t + public export + {shape : TensorShape rank} -> + AllDisplay2D shape a => + (ce : TensorCubEvidence shape) => + Show (Tensor shape a) where + show t = assert_total $ showGrid (dispatchTensorDisplay {ce=ce} t) tEx0 : Tensor ["batch" ~~> 3, "features" ~~> 4] Double tEx0 = ># [ [0, 1, 2, 3] @@ -860,10 +1047,11 @@ tEx1 : Tensor ["i" ~~> 2, "j" ~~> 3, "i" ~~> 2] Double tEx1 = ># [ [[0, 1], [2, 3], [4, 5]] , [[6, 7], [8, 9], [10, 11]] ] + namespace TensorContractions public export dotWith : {shape : TensorShape rank} -> - Algebra (Tensor shape) c => All TensorMonoid (conts shape) => + Algebra (Tensor shape) c => AllC TensorMonoid shape => (a -> b -> c) -> Tensor shape a -> Tensor shape b -> Tensor [] c dotWith f xs ys = embed $ reduce $ uncurry f <$> liftA2Tensor xs ys @@ -871,7 +1059,7 @@ namespace TensorContractions public export dot : {shape : TensorShape rank} -> Num a => - Algebra (Tensor shape) a => All TensorMonoid (conts shape) => + Algebra (Tensor shape) a => AllC TensorMonoid shape => Tensor shape a -> Tensor shape a -> Tensor [] a dot xs ys = dotWith (*) xs ys @@ -996,7 +1184,7 @@ namespace TensorContractions public export outerWith : {i, j : Axis} -> TensorMonoid i.cont => TensorMonoid j.cont => - (ac : NewAxisConsistent i [j]) => + (ac : i `ConsistentWith` [j]) => (a -> b -> c) -> Tensor [i] a -> Tensor [j] b -> Tensor [i, j] c outerWith f t t' = @@ -1006,7 +1194,7 @@ namespace TensorContractions public export outer : {i, j : Axis} -> TensorMonoid i.cont => TensorMonoid j.cont => - (ac : NewAxisConsistent i [j]) => + (ac : i `ConsistentWith` [j]) => Num a => Tensor [i] a -> Tensor [j] a -> Tensor [i, j] a outer = outerWith (*) @@ -1015,14 +1203,14 @@ namespace TensorContractions matrixVectorProduct : Num a => {i, j : Axis} -> TensorMonoid j.cont => AllAlgebra [j] a => - (ac : NewAxisConsistent i [j]) => + (ac : i `ConsistentWith` [j]) => Tensor [i, j] a -> Tensor [j] a -> Tensor [i] a matrixVectorProduct m v = dot v <-$> m public export vectorMatrixProduct : Num a => {i, j : Axis} -> TensorMonoid i.cont => - (ac : NewAxisConsistent i [j]) => + (ac : i `ConsistentWith` [j]) => Algebra (Ext i.cont) (Tensor [j] a) => Tensor [i] a -> Tensor [i, j] a -> Tensor [j] a vectorMatrixProduct v m = @@ -1035,9 +1223,9 @@ namespace TensorContractions public export matMul : Num a => {i, j, k : Axis} -> TensorMonoid j.cont => - (ac1 : NewAxisConsistent i [j]) => - (ac2 : NewAxisConsistent j [k]) => - (ac3 : NewAxisConsistent i [k]) => + (ac1 : i `ConsistentWith` [j]) => + (ac2 : j `ConsistentWith` [k]) => + (ac3 : i `ConsistentWith` [k]) => Algebra (Ext j.cont) (Tensor [k] a) => Tensor [i, j] a -> Tensor [j, k] a -> Tensor [i, k] a matMul m1 m2 = fromNestedTensor $ @@ -1046,46 +1234,15 @@ namespace TensorContractions -- "ij,kj->ki" public export matrixMatrixProduct : {i, j, k : Axis} -> - (ac1 : NewAxisConsistent i [j]) => - (ac2 : NewAxisConsistent k [j]) => - (ac3 : NewAxisConsistent k [i]) => + (ac1 : i `ConsistentWith` [j]) => + (ac2 : k `ConsistentWith` [j]) => + (ac3 : k `ConsistentWith` [i]) => Num a => TensorMonoid j.cont => (allAlg : AllAlgebra [j] a) => Tensor [i, j] a -> Tensor [k, j] a -> Tensor [k, i] a matrixMatrixProduct m1 = tensorMapFirstAxis (matrixVectorProduct m1) --- tt0 : CTensor [] [] Integer --- tt0 = pure 13 --- --- fg : CTensor [Vect 7] ["i"] Integer --- fg = pure 5 --- --- fgh : CTensor [Vect 7, Vect 7] ["i", "j"] Integer --- fgh = pure 13 --- --- sht0 : String --- sht0 = show tt0 --- --- fsh0 : Show (Vect 8 `fullOf` (CTensor [] [] Integer)) --- fsh0 = %search --- --- fsh : String --- fsh = show fg --- --- fshh : String --- fshh = show fgh --- --- ll : List' Integer --- ll = fromConcreteTy [1,2,3,4,5] --- --- bt : BinTree' Integer --- bt = fromConcreteTy $ Node 1 (Node 2 (Leaf 3) (Leaf 4)) (Leaf 5) --- --- rt : RoseTree' Char --- rt = fromConcreteTy (Node 'c' [Leaf 'c', Leaf 'd']) - - public export tEx : Tensor ["i" ~~> 3, "j" ~~> 4] Integer tEx = ># [ [1, 2, 3, 4] @@ -1100,39 +1257,59 @@ public export Ex3 : Tensor ["i" ~~> 2, "j" ~~> 6] Integer Ex3 = reshape Ex2 +||| At the moment, only works when the axis name apperas uniquely in the shape namespace IndexingByAxisNames - --public export - --posTypeOfAxisName : {shape : TensorShape rank} -> - -- (indexAxis : AxisName) -> - -- (InShape : indexAxis `InShape` shape) => - -- Type - public export - indexName : {shape : TensorShape rank} -> + data IndexTo : {shape : TensorShape rank} -> (t : Tensor shape a) -> (indexAxis : AxisName) -> - IsSucc k => - (InShape : InShape indexAxis shape k) => - Type - + (inShape : UniqueElem indexAxis shape) => Type where + Nil : {ax : Axis} -> {as : TensorShape rank} -> + ax `ConsistentWith` as => + NotElem ax.name as => + {t : Tensor (ax :: as) a} -> + IndexTo t ax.name + (::) : IsSucc rank => {as : TensorShape rank} -> + {ax : Axis} -> + ax `ConsistentWith` as => + UniqueElem indexAxis as => + IsNo (decEq indexAxis ax.name) => + {t : Tensor (ax :: as) a} -> + (p : ax.cont.Pos (shapeExt (extractTopExt t))) -> + IndexTo {shape=as} (index (extractTopExt t) p) indexAxis -> + IndexTo {shape=ax :: as} t indexAxis + + %name IndexTo ind + + ||| Here "axis shape" here meant in the container sense + public export + indexShapeFw : {shape : TensorShape rank} -> + (t : Tensor shape a) -> + (indexAxis : AxisName) -> + (inShape : UniqueElem indexAxis shape) => + IndexTo t indexAxis -> + (index shape indexAxis).Shp + indexShapeFw t (ax .name) @{Here} Nil = shapeExt (extractTopExt t) + indexShapeFw t indexAxis @{There} (p :: ind) + = indexShapeFw (index (extractTopExt t) p) indexAxis ind + namespace SetterGetter - ||| Machinery for indexing into a Tensor based on absolute positions - ||| It depends on shape, but also on the tensor t itself - ||| Provides a compile-time guarantee that we won't be out of bounds - ||| This dependency is not needed for cubical tensors - ||| Technically, to index we only need the shapes, not the entire tensor + ||| Datatype containing information needed to index into a Tensor + ||| Unlike with cubical tensors, where the underlying tensor is not + ||| necessary, here we require the data of `t : Tensor shape a` too. + ||| Based on absolute positions public export data Index : (shape : TensorShape rank) -> (t : Tensor shape a) -> Type where Nil : {t : Tensor [] a} -> Index [] t - (::) : {cs : TensorShape k} -> - NewAxisConsistent c cs => - {t : Tensor (c :: cs) a} -> - (p : c.cont.Pos (shapeExt (extractTopExt t))) -> - Index cs (index (extractTopExt t) p) -> - Index (c :: cs) t + (::) : {as : TensorShape k} -> + ConsistentWith ax as => + {t : Tensor (ax :: as) a} -> + (p : ax.cont.Pos (shapeExt (extractTopExt t))) -> + Index as (index (extractTopExt t) p) -> + Index (ax :: as) t %name Index is, js @@ -1167,32 +1344,6 @@ namespace SetterGetter -- ts : Tensor ss a := setC (indexC tNested [i]) is x -- in fromNestedTensor $ MkT $ set (GetT tNested) (i ** ()) ts - -- public export - -- t00 : CTensor [Maybe, List] ["m", "l"] Integer - -- t00 = ># Just [10, 20, 30, 40, 50, 60, 70] - - -- public export - -- t11 : Tensor [2, 3] ["i", "j"] Integer - -- t11 = ># [[1,2,3], [4,5,6]] - - -- public export - -- t22 : CTensor [BinTree, List] ["b", "l"] Integer - -- t22 = ># Node [1,2] (Leaf [3,4]) (Leaf [5,6]) - - -- t33 : CTensor [BinTree] ["b"] Integer - -- t33 = ># Node 1 (Leaf 2) (Leaf 3) - - -- t333 : CTensor [Vect 2] ["v"] Integer - -- t333 = ># [1, 2] - - -- t44 : CTensor [] [] Integer - -- t44 = ># 13 - - -- public export - -- jj : Integer - -- jj = index t11 [1, 1] - - namespace CubicalSetterGetter public export data IndexC : Vect rank Nat -> Type where @@ -1269,7 +1420,7 @@ namespace Slice -- (rest : All IsNaperian (conts as)) => -- (nap : IsNaperian a) => -- Log a -> - -- NewAxisConsistent a as => + -- ConsistentWith a as => -- IndexNaperian as {allNap=rest} -> -- IndexNaperian (a :: as) {allNap=(toContNaperian nap :: rest)} @@ -1298,3 +1449,10 @@ namespace Slice -- Log = IndexNaperian shape -- lookup = tensorLookup -- tabulate = tensorTabulate + + + + +public export +treeExample1Test : Tensor ["myTree" ~> BinTree] Double +treeExample1Test = ># Node 60 (Node 7 (Leaf (-42)) (Leaf 46)) (Leaf 2) \ No newline at end of file diff --git a/src/Data/Tensor/Utils.idr b/src/Data/Tensor/Utils.idr index f1ac6a2..a4d4fa8 100644 --- a/src/Data/Tensor/Utils.idr +++ b/src/Data/Tensor/Utils.idr @@ -1,6 +1,7 @@ module Data.Tensor.Utils import Data.Nat -- Add import for Cast +import Data.List import System.Random import Data.Tensor.Tensor @@ -35,25 +36,26 @@ namespace CommonNames Vector c a = Tensor [c] a public export - Matrix : (row, col : Axis) -> NewAxisConsistent row [col] => (a : Type) -> Type + Matrix : (row, col : Axis) -> ConsistentWith row [col] => + (a : Type) -> Type Matrix row col a = Tensor [row, col] a namespace FillZerosOnes public export fill : Num a => {shape : TensorShape rank} -> - All TensorMonoid (conts shape) => + AllC TensorMonoid shape => a -> Tensor shape a fill x = tensorReplicate x public export zeros : Num a => {shape : TensorShape rank} -> - All TensorMonoid (conts shape) => + AllC TensorMonoid shape => Tensor shape a zeros = fill (fromInteger 0) public export ones : Num a => {shape : TensorShape rank} -> - All TensorMonoid (conts shape) => + AllC TensorMonoid shape => Tensor shape a ones = fill (fromInteger 1) @@ -113,9 +115,9 @@ namespace Concatenate public export concat : {shape : TensorShape rank} -> {l : AxisName} -> {x, y : Axis} -> IsCubical x => IsCubical y => - NewAxisConsistent (l ~~> dim x + dim y) shape => - NewAxisConsistent x shape => - NewAxisConsistent y shape => + ConsistentWith (l ~~> dim x + dim y) shape => + ConsistentWith x shape => + ConsistentWith y shape => Tensor (x :: shape) a -> Tensor (y :: shape) a -> Tensor ((l ~~> dim x + dim y) :: shape) a @@ -166,7 +168,7 @@ namespace Max max : {0 shape : TensorShape rank} -> Foldable (Tensor shape) => Ord a => Tensor shape a -> Maybe a - max = maxInList . flatten + max = max . flatten namespace OneHot public export @@ -182,7 +184,7 @@ namespace Triangular (ip : InterfaceOnPositions c.cont MOrd) => TensorMonoid c.cont => (sh : c.cont.Shp) -> Tensor [c, c] Bool - cTriBool {ip = MkI {p}} sh + cTriBool {ip = MkI p} sh = let cPositions = positions {sh=sh} pp : MOrd (c.cont.Pos sh) := p sh in outerWith (flip isSubTerm) cPositions cPositions @@ -217,7 +219,7 @@ namespace Triangular ||| Fill the elements of a tensor `t` with `fill` where `mask` is True public export maskedFill : {shape : TensorShape rank} -> - Num a => All TensorMonoid (conts shape) => + Num a => AllC TensorMonoid shape => (t : Tensor shape a) -> (mask : Tensor shape Bool) -> (fill : a) -> @@ -253,10 +255,12 @@ namespace Misc cumulativeSum : {c : Axis} -> Num a => (isCubical : IsCubical c) => Tensor [c] a -> Tensor [c] a - -- cumulativeSum {isCubical=(MkIsCubical _ n)} t - -- = let tt = map {f=Vect n} (scanl1 (+)) (#> t) - -- - -- in ?qqwer -- #> ((scanl1 (+)) (#> t)) --(#>#) + cumulativeSum {isCubical=(MkIsCubical _ n)} t + = (#>#) (scanl1 (+)) t + + -- let tt = n -- map {f=Vect n} (scanl1 (+)) (#> t) + -- + -- in ?qwerrr -- #> ((scanl1 (+)) (#> t)) --(#>#) @@ -305,11 +309,11 @@ ttt = %search tttt : Traversable (Tensor ["i" ~~> 2]) tttt = %search -testRand : IO (Tensor ["i" ~~> 2, "j" ~~> 3] Double) -testRand = do - t <- random ["i" ~~> 2, "j" ~~> 3] - printLn $ show t - pure t +-- testRand : IO (Tensor ["i" ~~> 2, "j" ~~> 3] Double) +-- testRand = do +-- t <- random ["i" ~~> 2, "j" ~~> 3] +-- printLn $ show t +-- pure t testRand2 : IO (Tensor ["i" ~~> 5] Double) testRand2 = random ["i" ~~> 5] @@ -362,14 +366,4 @@ t1 : Tensor ["i" ~~> 6] Double t1 = arange exMatrix2 : Tensor ["v" ~~> 3, "v" ~~> 3] Double -exMatrix2 = reshape $ arange {stop="v" ~~> 9} - - - -public export -tTest : Tensor ["i" ~~> 800] Double -tTest = arange - -public export -tRes : Tensor ["i" ~~> 2, "j" ~~> 400] Double -tRes = reshape tTest \ No newline at end of file +exMatrix2 = reshape $ arange {stop="l" ~~> 9} \ No newline at end of file diff --git a/src/Data/Tree.idr b/src/Data/Trees.idr similarity index 95% rename from src/Data/Tree.idr rename to src/Data/Trees.idr index 6351d59..2a6657a 100644 --- a/src/Data/Tree.idr +++ b/src/Data/Trees.idr @@ -1,4 +1,8 @@ -module Data.Tree +module Data.Trees + +-- TODO usual name convention is to name it in singular form +-- but for testing purposes it clashes with hedgehog's `Data.Tree` +-- not sure if there's a better solution... import Language.Reflection import Derive.Prelude @@ -272,12 +276,12 @@ namespace RoseTrees fs <*> xs = map {f=RoseTreeSame} (uncurry ($)) $ liftA2RoseTreeSame fs xs - public export - {a : Type} -> Display a => Display (RoseTreeSame a) where - display (Leaf x) = display x - display (Node x rts) - = let (xh ** xw ** dx) = display x - in ?whatt_1 + -- public export + -- {a : Type} -> Display a => Display (RoseTreeSame a) where + -- display (Leaf x) = display x + -- display (Node x rts) + -- = let (xh ** xw ** dx) = display x + -- in ?whatt_1 -- TODO RoseTreeLeaf, RoseTreeNode? diff --git a/src/Data/Unique/List.idr b/src/Data/Unique/List.idr index 51aa68e..5855fd4 100644 --- a/src/Data/Unique/List.idr +++ b/src/Data/Unique/List.idr @@ -48,7 +48,7 @@ namespace UniqueList decElemNotInUniqueList x [] = Yes $ NotInEmptyList x decElemNotInUniqueList x (y :: xs) = case decEq x y of Yes Refl => No $ \(NotInNonEmptyList _ _ {neq}) - => uninhabited @{uniqueUninhabited} neq + => uninhabited @{UninhabitedIsNoRefl} neq No neq => case decElemNotInUniqueList x xs of Yes prf => Yes $ NotInNonEmptyList _ prf {neq=(proofIneqIsNo neq)} No nprf => No $ \(NotInNonEmptyList _ prf') => nprf prf' diff --git a/src/Data/Unique/Vect.idr b/src/Data/Unique/Vect.idr index 09a91a8..e501cd8 100644 --- a/src/Data/Unique/Vect.idr +++ b/src/Data/Unique/Vect.idr @@ -5,303 +5,294 @@ import Decidable.Equality import Decidable.Equality.Core import Misc -%hide Misc.NotElem - -||| A vector with unique elements -||| Requires a mutual block since it is defined in terms of NotElem -namespace UniqueVect - mutual - ||| A list with unique elements, length tracked statically - ||| An element can be inserted if it is not already in the list - ||| Like a Set, but with ordering - ||| @ a The type of the elements in the list - public export - data UniqueVect : (0 n : Nat) -> (0 a : Type) -> DecEq a => Type where - Nil : {0 a : Type} -> DecEq a => UniqueVect 0 a - (::) : {0 a : Type} -> DecEq a => - (x : a) -> - (xs : UniqueVect n a) -> - {auto prf : NotElem x xs} -> - UniqueVect (S n) a +%hide Misc.NotElem.NotElem + +-- `UniqueVect` and `NotElem` are defined in terms of each other +mutual + ||| A vector that cannot contain duplicates + ||| An element can be inserted only if it is not already in the vector + ||| Mathematically can be thought of as an ordered set + ||| @ a The type of the elements in the list + public export + data UniqueVect : (0 n : Nat) -> (0 a : Type) -> DecEq a => Type where + Nil : {0 a : Type} -> DecEq a => UniqueVect 0 a + (::) : {0 a : Type} -> DecEq a => + (x : a) -> + (xs : UniqueVect n a) -> + {auto prf : NotElem x xs} -> + UniqueVect (S n) a - ||| A proof that an element is *not* found in the unique vector - public export - data NotElem : DecEq a => - (x : a) -> (xs : UniqueVect n a) -> Type where - NotInEmptyVect : {0 a : Type} -> DecEq a => (x : a) - -> NotElem {a=a} x [] - NotInNonEmptyVect : {0 a : Type} -> (de : DecEq a) => - {x, y : a} -> - (xs : UniqueVect n a) -> - (ne : NotElem x xs) -> - (neq : IsNo (decEq x y)) => - (prf : NotElem y xs) => - NotElem x (y :: xs) - - namespace All - ||| A proof that elements of a unique vector satisfy a property - public export - data All : DecEq a => (0 p : a -> Type) -> UniqueVect rank a -> Type where - Nil : DecEq a => {0 p : a -> Type} -> All p [] - (::) : DecEq a => {0 p : a -> Type} -> - {0 x : a} -> - {0 xs : UniqueVect k a} -> - NotElem x xs => - p x -> - All p xs -> - All p (x :: xs) - - - ||| A proof that an element is found in a vector with unique elements + ||| A proof that an element `x` is not found in `xs` public export - data Elem : DecEq a => (x : a) -> (xs : UniqueVect n a) -> Type where - Here : DecEq a => - {x : a} -> - {xs : UniqueVect k a} -> - (prf : NotElem x xs) => - Elem x (x :: xs) - There : DecEq a => - {x : a} -> - {xs : UniqueVect k a} -> + data NotElem : DecEq a => (x : a) -> (xs : UniqueVect n a) -> Type where + NotInEmptyVect : DecEq a => (x : a) -> NotElem {a=a} x [] + NotInNonEmptyVect : (de : DecEq a) => + {x, y : a} -> + (xs : UniqueVect n a) -> + (ne : NotElem x xs) -> + (neq : IsNo (decEq x y)) => (prf : NotElem y xs) => - (later : Elem x xs) -> - Elem x (y :: xs) + NotElem x (y :: xs) - ||| An element cannot be in an empty vector - public export - {x : a} -> DecEq a => Uninhabited (Elem x []) where - uninhabited Here impossible - uninhabited (There later) impossible - - ||| Decision procedure for unique vector's Elem - public export - decElemInUniqueVect : DecEq a => - (x : a) -> (xs : UniqueVect n a) -> Dec (Elem x xs) - decElemInUniqueVect x [] = No absurd - decElemInUniqueVect x (y :: ys) = case decEq x y of - Yes Refl => Yes Here - No neq => case decElemInUniqueVect x ys of - Yes prf => Yes $ There prf - No nprf => No $ \case - Here => neq Refl - (There later) => nprf later - - public export - notElem : DecEq a => +||| A proof that an element `x` is found in `xs` +public export +data Elem : DecEq a => (x : a) -> (xs : UniqueVect n a) -> Type where + Here : DecEq a => {x : a} -> - {xs : UniqueVect n a} -> - Not (Elem x xs) -> NotElem x xs - notElem {xs = []} f = NotInEmptyVect x - notElem {xs = (y :: ys)} f with (decEq x y) - _ | (Yes Refl) = absurd (f Here) - _ | (No neq) = NotInNonEmptyVect - {neq=(proofIneqIsNo neq)} ys (notElem (\e => f ?bb)) - - public export - toVect : DecEq a => UniqueVect n a -> Vect n a - toVect [] = [] - toVect (x :: xs) = x :: toVect xs + {xs : UniqueVect k a} -> + (prf : NotElem x xs) => + Elem x (x :: xs) + There : DecEq a => + {x : a} -> + {xs : UniqueVect k a} -> + (prf : NotElem y xs) => + (later : Elem x xs) -> + Elem x (y :: xs) - ||| Converts a vector to a unique vector, removing duplicates if they exist +namespace All + ||| A proof that elements of a unique vector satisfy a property public export - fromVect : DecEq a => Vect n a -> (m : Nat ** UniqueVect m a) - fromVect [] = (0 ** []) - fromVect (x :: xs) = - let (k ** t) = fromVect xs - in case decElemInUniqueVect x t of - Yes prf => (k ** t) - No nprf => (S k ** (::) x t {prf=notElem nprf}) - - - ||| Turn the proof that an element `x` is in a vector into the index of `x` + data All : DecEq a => (0 p : a -> Type) -> UniqueVect rank a -> Type where + Nil : DecEq a => {0 p : a -> Type} -> All p [] + (::) : DecEq a => {0 p : a -> Type} -> + {0 x : a} -> + {0 xs : UniqueVect k a} -> + NotElem x xs => + p x -> + All p xs -> + All p (x :: xs) + + +||| An element cannot be in an empty vector +public export +{x : a} -> DecEq a => Uninhabited (Elem x []) where + uninhabited Here impossible + uninhabited (There later) impossible + +||| A decision procedure for determining whether an element `x` is in `xs` +public export +decElemInUniqueVect : DecEq a => + (x : a) -> (xs : UniqueVect n a) -> Dec (Elem x xs) +decElemInUniqueVect x [] = No absurd +decElemInUniqueVect x (y :: ys) = case decEq x y of + Yes Refl => Yes Here + No neq => case decElemInUniqueVect x ys of + Yes prf => Yes $ There prf + No nprf => No $ \case + Here => neq Refl + (There later) => nprf later + +public export +notElem : DecEq a => + {x : a} -> + {xs : UniqueVect n a} -> + Not (Elem x xs) -> NotElem x xs +notElem {xs = []} f = NotInEmptyVect x +notElem {xs = (y :: ys)} f with (decEq x y) + _ | (Yes Refl) = absurd (f Here) + _ | (No neq) = NotInNonEmptyVect + {neq=(proofIneqIsNo neq)} ys (notElem (\e => f ?bb)) + +public export +toVect : DecEq a => UniqueVect n a -> Vect n a +toVect [] = [] +toVect (x :: xs) = x :: toVect xs + +||| Converts a vector to a unique vector, removing duplicates if they exist +public export +fromVect : DecEq a => Vect n a -> (m : Nat ** UniqueVect m a) +fromVect [] = (0 ** []) +fromVect (x :: xs) = + let (k ** t) = fromVect xs + in case decElemInUniqueVect x t of + Yes prf => (k ** t) + No nprf => (S k ** (::) x t {prf=notElem nprf}) + +||| Turn the proof that an element `x` is in a vector into the index of `x` +public export +indexOf : DecEq a => {0 n : Nat} -> {0 xs : UniqueVect n a} -> + Elem x xs -> Fin n +indexOf Here = FZ +indexOf (There later) = FS (indexOf later) + +public export +length : DecEq a => UniqueVect n a -> Nat +length [] = 0 +length (x :: xs) = 1 + length xs + +||| Drop all the elements up and until the element `x` from a unique vector +public export +drop : DecEq a => + (xs : UniqueVect n a) -> + (elem : Elem x xs) -> + UniqueVect (n `minus` (finToNat (FS (indexOf elem)))) a +drop {n=S k} (_ :: xs) Here = rewrite minusZeroRight k in xs +drop (_ :: xs) (There later) = drop xs later + + +public export +Test1 : UniqueVect 5 String +Test1 = ["a", "b", "c", "d", "e"] + +public export +wher : Elem "c" Test1 +wher = There $ There $ Here + +mutual + ||| Remove element from a unique vector at a given index public export - indexOf : DecEq a => {0 n : Nat} -> {0 xs : UniqueVect n a} -> - Elem x xs -> Fin n - indexOf Here = FZ - indexOf (There later) = FS (indexOf later) - + removeIndex : DecEq a => + {n : Nat} -> + (xs : UniqueVect (S n) a) -> + Fin (S n) -> + UniqueVect n a + removeIndex (x :: xs) FZ = xs + removeIndex {n = (S k)} (x :: xs) (FS i) + = (::) x (removeIndex xs i) {prf=removingElemIsStillNotElem} + + ||| Given a vector `xs` and a proof that `x` is not in `xs`, then even if + ||| we remove any elemens from `xs`, `x` will still not be in the result public export - length : DecEq a => UniqueVect n a -> Nat - length [] = 0 - length (x :: xs) = 1 + length xs - - ||| Drop all the elements up and until the element `x` from a unique vector + removingElemIsStillNotElem : DecEq a => + {n : Nat} -> + {x : a} -> + {xs : UniqueVect (S n) a} -> + {i : Fin (S n)} -> + (ne : NotElem x xs) => + NotElem x (removeIndex xs i) + removingElemIsStillNotElem {xs = (_ :: _)} {ne = (NotInNonEmptyVect _ ne)} {i = FZ} + = ne + removingElemIsStillNotElem {n = (S k)} {xs = (y :: ys)} {ne = (NotInNonEmptyVect ys ne)} {i = (FS j)} + = NotInNonEmptyVect (removeIndex ys j) removingElemIsStillNotElem {prf=removingElemIsStillNotElem} + +||| If `x` is not equal to `y`, then `x` is not in the list `[y]` +||| It seems that Idris manages to discover this proof automatically, so +||| this is not needed in practice +||| Its dual is needed, hence the %hint in for the declaration below +public export +notEqualNotElem : DecEq a => + {x, y : a} -> + (neq : IsNo (decEq x y)) -> + NotElem x [y] +notEqualNotElem _ = NotInNonEmptyVect [] (NotInEmptyVect x) + +%hint +public export +notEqualNotElem2 : DecEq a => + {x, y : a} -> + (neq : IsNo (decEq x y)) -> + NotElem y [x] +notEqualNotElem2 neq = notEqualNotElem {x=y} {y=x} (isNoSym neq) + +||| Number of elements found in any of two unique vectors +||| Effectively, union +public export +numUnique : {n, m : Nat} -> DecEq a => UniqueVect n a -> UniqueVect m a -> Nat +numUnique [] _ = m +numUnique (x :: xs) ys = case decElemInUniqueVect x ys of + Yes _ => numUnique xs ys -- found in ys, so don't count it again + No _ => 1 + numUnique xs ys -- not found in ys, so count it + +||| Number of elements found in both of the two unique vectors +||| Effectively, intersection +public export +numOverlap : {n, m : Nat} -> DecEq a => + UniqueVect n a -> UniqueVect m a -> Nat +numOverlap [] ys = 0 +numOverlap (x :: xs) ys = case decElemInUniqueVect x ys of + Yes _ => 1 + numOverlap xs ys -- found also in ys, so count it + No _ => numOverlap xs ys + +||| Number of elements that are found in one but not both of the two vectors +||| Effectively, symmetric difference +public export +numSymmetricDifference : {n, m : Nat} -> DecEq a => + UniqueVect n a -> UniqueVect m a -> Nat +numSymmetricDifference [] ys = m +numSymmetricDifference (x :: xs) ys = case decElemInUniqueVect x ys of + -- need to pattern match on Elem to propagate length information + Yes Here => numSymmetricDifference xs (removeIndex ys FZ) + Yes (There later) => numSymmetricDifference xs (removeIndex ys (FS (indexOf later))) + No _ => 1 + numSymmetricDifference xs ys + +mutual + public export infixr 5 +++ + + ||| Union public export - drop : DecEq a => + (+++) : DecEq a => (xs : UniqueVect n a) -> - (elem : Elem x xs) -> - UniqueVect (n `minus` (finToNat (FS (indexOf elem)))) a - drop {n=S k} (_ :: xs) Here = rewrite minusZeroRight k in xs - drop (_ :: xs) (There later) = drop xs later - - - public export - Test1 : UniqueVect 5 String - Test1 = ["a", "b", "c", "d", "e"] - - public export - wher : Elem "c" Test1 - wher = There $ There $ Here - - mutual - ||| Remove element from a unique vector at a given index - public export - removeIndex : DecEq a => - {n : Nat} -> - (xs : UniqueVect (S n) a) -> - Fin (S n) -> - UniqueVect n a - removeIndex (x :: xs) FZ = xs - removeIndex {n = (S k)} (x :: xs) (FS i) - = (::) x (removeIndex xs i) {prf=removingElemIsStillNotElem} - - ||| Given a vector `xs` and a proof that `x` is not in `xs`, then even if - ||| we remove any elemens from `xs`, `x` will still not be in the result - public export - removingElemIsStillNotElem : DecEq a => - {n : Nat} -> - {x : a} -> - {xs : UniqueVect (S n) a} -> - {i : Fin (S n)} -> - (ne : NotElem x xs) => - NotElem x (removeIndex xs i) - removingElemIsStillNotElem {xs = (_ :: _)} {ne = (NotInNonEmptyVect _ ne)} {i = FZ} - = ne - removingElemIsStillNotElem {n = (S k)} {xs = (y :: ys)} {ne = (NotInNonEmptyVect ys ne)} {i = (FS j)} - = NotInNonEmptyVect (removeIndex ys j) removingElemIsStillNotElem {prf=removingElemIsStillNotElem} - - - - ||| If `x` is not equal to `y`, then `x` is not in the list `[y]` - ||| It seems that Idris manages to discover this proof automatically, so - ||| this is not needed in practice - ||| Its dual is needed, hence the %hint in for the declaration below - public export - notEqualNotElem : DecEq a => - {x, y : a} -> - (neq : IsNo (decEq x y)) -> - NotElem x [y] - notEqualNotElem _ = NotInNonEmptyVect [] (NotInEmptyVect x) + (ys : UniqueVect m a) -> + UniqueVect (numUnique xs ys) a + [] +++ ys = ys + (x :: xs) +++ ys with (decElemInUniqueVect x ys) + _ | (Yes prf) = xs +++ ys -- x :: (xs +++ ys) + _ | (No nprf) = (::) x (xs +++ ys) {prf=expandUnique {prfy=notElem nprf}} - %hint - public export - notEqualNotElem2 : DecEq a => - {x, y : a} -> - (neq : IsNo (decEq x y)) -> - NotElem y [x] - notEqualNotElem2 neq = notEqualNotElem {x=y} {y=x} (isNoSym neq) - - ||| Number of elements found in any of two unique vectors - ||| Effectively, union - public export - numUnique : {n, m : Nat} -> DecEq a => UniqueVect n a -> UniqueVect m a -> Nat - numUnique [] _ = m - numUnique (x :: xs) ys = case decElemInUniqueVect x ys of - Yes _ => numUnique xs ys -- found in ys, so don't count it again - No _ => 1 + numUnique xs ys -- not found in ys, so count it - - ||| Number of elements found in both of the two unique vectors - ||| Effectively, intersection + ||| If `x` is not in `xs` nor `ys`, then it also won't be in `xs +++ ys` public export - numOverlap : {n, m : Nat} -> DecEq a => - UniqueVect n a -> UniqueVect m a -> Nat - numOverlap [] ys = 0 - numOverlap (x :: xs) ys = case decElemInUniqueVect x ys of - Yes _ => 1 + numOverlap xs ys -- found also in ys, so count it - No _ => numOverlap xs ys - - ||| Number of elements that are found in one but not both of the two vectors - ||| Effectively, symmetric difference - public export - numSymmetricDifference : {n, m : Nat} -> DecEq a => - UniqueVect n a -> UniqueVect m a -> Nat - numSymmetricDifference [] ys = m - numSymmetricDifference (x :: xs) ys = case decElemInUniqueVect x ys of - -- need to pattern match on Elem to propagate length information - Yes Here => numSymmetricDifference xs (removeIndex ys FZ) - Yes (There later) => numSymmetricDifference xs (removeIndex ys (FS (indexOf later))) - No _ => 1 + numSymmetricDifference xs ys - - - - mutual - public export infixr 5 +++ + expandUnique : DecEq a => + {x : a} -> + {xs : UniqueVect n a} -> + {ys : UniqueVect m a} -> + (prfx : NotElem x xs) => + (prfy : NotElem x ys) => + NotElem x (xs +++ ys) + -- todo implement this - ||| Union - public export - (+++) : DecEq a => - (xs : UniqueVect n a) -> - (ys : UniqueVect m a) -> - UniqueVect (numUnique xs ys) a - [] +++ ys = ys - (x :: xs) +++ ys with (decElemInUniqueVect x ys) - _ | (Yes prf) = xs +++ ys -- x :: (xs +++ ys) - _ | (No nprf) = (::) x (xs +++ ys) {prf=expandUnique {prfy=notElem nprf}} - - ||| If `x` is not in `xs` nor `ys`, then it also won't be in `xs +++ ys` - public export - expandUnique : DecEq a => - {x : a} -> - {xs : UniqueVect n a} -> - {ys : UniqueVect m a} -> - (prfx : NotElem x xs) => - (prfy : NotElem x ys) => - NotElem x (xs +++ ys) - -- todo implement this - - mutual - public export - intersect : DecEq a => - (xs : UniqueVect n a) -> - (ys : UniqueVect m a) -> - UniqueVect (numOverlap xs ys) a - intersect [] ys = [] - intersect (x :: xs) ys with (decElemInUniqueVect x ys) - _ | (Yes prf) = (::) x (intersect xs ys) {prf=notElemIntersect} - _ | (No nprf) = intersect xs ys - - ||| If `x` is not in `xs`, then we can intersect `xs` with any other list, - ||| and `x` still wont' be in the result (even if `x` was in the other list) - public export - notElemIntersect : DecEq a => - {x : a} -> - {xs : UniqueVect n a} -> - {ys : UniqueVect m a} -> - (prfx : NotElem x xs) => - (prfy : Elem x ys) => - NotElem x (intersect xs ys) - - - ||| All elements of the intersection of two vectors `xs` and `ys` - ||| will be elements of `xs` +mutual public export - allElemIntersectFst : DecEq a => + intersect : DecEq a => (xs : UniqueVect n a) -> (ys : UniqueVect m a) -> - All (\x => Elem x xs) (intersect xs ys) - allElemIntersectFst = ?allElemIntersect_rhs - - ||| All elements of the intersection of two vectors `xs` and `ys` - ||| will be elements of `ys` + UniqueVect (numOverlap xs ys) a + intersect [] ys = [] + intersect (x :: xs) ys with (decElemInUniqueVect x ys) + _ | (Yes prf) = (::) x (intersect xs ys) {prf=notElemIntersect} + _ | (No nprf) = intersect xs ys + + ||| If `x` is not in `xs`, then we can intersect `xs` with any other list, + ||| and `x` still wont' be in the result (even if `x` was in the other list) + public export + notElemIntersect : DecEq a => + {x : a} -> + {xs : UniqueVect n a} -> + {ys : UniqueVect m a} -> + (prfx : NotElem x xs) => + (prfy : Elem x ys) => + NotElem x (intersect xs ys) + + +||| All elements of the intersection of two vectors `xs` and `ys` +||| will be elements of `xs` +public export +allElemIntersectFst : DecEq a => + (xs : UniqueVect n a) -> + (ys : UniqueVect m a) -> + All (\x => Elem x xs) (intersect xs ys) +allElemIntersectFst = ?allElemIntersect_rhs + +||| All elements of the intersection of two vectors `xs` and `ys` +||| will be elements of `ys` +public export +allElemIntersectSnd : DecEq a => + (xs : UniqueVect n a) -> + (ys : UniqueVect m a) -> + All (\x => Elem x ys) (intersect xs ys) +allElemIntersectSnd = ?allElemIntersect_rhs2 + +mutual public export - allElemIntersectSnd : DecEq a => + symmetricDifference : DecEq a => {n, m : Nat} -> (xs : UniqueVect n a) -> (ys : UniqueVect m a) -> - All (\x => Elem x ys) (intersect xs ys) - allElemIntersectSnd = ?allElemIntersect_rhs2 - - mutual - public export - symmetricDifference : DecEq a => {n, m : Nat} -> - (xs : UniqueVect n a) -> - (ys : UniqueVect m a) -> - UniqueVect (numSymmetricDifference xs ys) a - symmetricDifference [] ys = ys - symmetricDifference (x :: xs) ys = ?aaa - - -- with (decElemInUniqueVect x ys) - -- _ | Yes p = ?aaa - -- _ | No neq = ?bbb + UniqueVect (numSymmetricDifference xs ys) a + symmetricDifference [] ys = ys + symmetricDifference (x :: xs) ys = ?aaa + + -- with (decElemInUniqueVect x ys) + -- _ | Yes p = ?aaa + -- _ | No neq = ?bbb diff --git a/src/Misc.idr b/src/Misc.idr index 95e6ccb..a4397e6 100644 --- a/src/Misc.idr +++ b/src/Misc.idr @@ -1,6 +1,7 @@ module Misc import Data.Nat +import Data.List.Elem import Data.Vect import Data.Vect.Elem import System.Random @@ -16,37 +17,100 @@ import Data.List %hide Builtin.infixr.(#) %hide Data.Vect.Quantifiers.All.index +{------------------------------------------------------------------------------- +{------------------------------------------------------------------------------- +Various utilities necessary for TensorType, but that don't fit anywhere else +Does not depend on any other file within this project + +Some of these feel like they should be in the Idris standard library + +-------------------------------------------------------------------------------} +-------------------------------------------------------------------------------} + namespace IsNo - ||| IsNo is a type Idris can automatically synthesise, in contrast to - ||| in contrast to `Not (x = y)` + ||| The proof that a decidable property leads to a contradiction + ||| `IsNo` is a type Idris can automatically synthesise, unlike `Not` + ||| See example below public export data IsNo : Dec a -> Type where - ItIsNo : {prop : Type} -> {contra : Not prop} -> IsNo (No {prop=prop} contra) + ItIsNo : {prop : Type} -> + {contra : Not prop} -> + IsNo (No {prop=prop} contra) - ||| Can this be simplified? - public export - isNoSym : DecEq a => {x, y : a} -> IsNo (decEq x y) -> IsNo (decEq y x) - isNoSym z with (decEq x y) | (decEq y x) - _ | (No contra1) | (Yes prf) = absurd (contra1 (sym prf)) - _ | _ | (No contra) = ItIsNo + failing + thisOneFails : Not ("i" = "j") + thisOneFails = %search + thisOneDoesnt : IsNo (decEq "i" "j") + thisOneDoesnt = %search + public export - [uniqueUninhabited] {0 a : Type} -> {x : a} -> (de : DecEq a) => - Uninhabited (IsNo (Equality.decEq x x)) where + [UninhabitedIsNoRefl] {x : a} -> DecEq a => + Uninhabited (IsNo (decEq x x)) where uninhabited y with (decEq x x) _ | (Yes _) with (y) _ | ItIsNo impossible _ | (No contra) = contra Refl + public export + isNoSym : DecEq a => {x, y : a} -> IsNo (decEq x y) -> IsNo (decEq y x) + isNoSym z with (decEq x y) | (decEq y x) + _ | (No contra1) | (Yes prf) = absurd (contra1 (sym prf)) + _ | _ | (No contra) = ItIsNo ||| Proof of inequality yields IsNo public export proofIneqIsNo : {x, y : a} -> DecEq a => - Not (x = y) -> (IsNo (Equality.decEq x y)) + Not (x = y) -> IsNo (decEq x y) proofIneqIsNo f with (decEq x y) _ | (Yes prf) = absurd (f prf) _ | (No contra) = ItIsNo +namespace Maybe + public export + data IsNothing : Maybe a -> Type where + ItIsNothing : IsNothing Nothing + + public export + maybeVoidIsNothing : (x : Maybe Void) -> IsNothing x + maybeVoidIsNothing Nothing = ItIsNothing + maybeVoidIsNothing (Just v) = absurd v + + public export + Uninhabited (IsNothing (Just x)) where + uninhabited ItIsNothing impossible + + +namespace NotElem + public export + data NotElem : DecEq a => (x : a) -> (xs : Vect n a) -> Type where + NotInEmptyVect : DecEq a => {0 x : a} -> NotElem x [] + NotInNonEmptyVect : DecEq a => {0 x, y : a} -> + (xs : Vect n a) -> + IsNo (decEq x y) -> + (ne : NotElem x xs) => + NotElem x (y :: xs) + + public export + notEqualNotElem : DecEq a => + {0 x, y : a} -> + (neq : IsNo (decEq x y)) -> + NotElem x [y] + notEqualNotElem neq = NotInNonEmptyVect [] neq + + ||| If an element `i` is not in the singleton list `[j]`, then `j` is not in + ||| the singleton list `[i]` + public export + notElemSym : DecEq a => {i, j : a} -> NotElem i [j] -> NotElem j [i] + notElemSym (NotInNonEmptyVect [] isNo) = notEqualNotElem (isNoSym isNo) + + ||| If an element `i` is in the singleton list `[j]`, then `j` is in the + ||| singleton list `[i]` + public export + elemSym : DecEq a => {i, j : a} -> Vect.Elem.Elem i [j] -> + Vect.Elem.Elem j [i] + elemSym Here = Here + namespace Applicative ||| Definition of liftA2 in terms of (<*>) @@ -67,7 +131,6 @@ namespace Applicative fromInteger = pure . fromInteger - namespace VectFoldable ||| Implementation of Foldable for Vect that is denotationally equivalent to ||| one in Data.Vect, but which does not use `foldrImpl` and therefore @@ -82,76 +145,161 @@ namespace VectFoldable toList' : Vect n a -> List a toList' = foldr @{straightforward} (::) [] -||| Drop the first i elements of a vector -||| Analogous to Data.Vect.drop, except the index is Fin n instead of Nat -public export -drop : (i : Fin (S n)) -> Vect n a -> Vect (minus n (finToNat i)) a -drop FZ xs = rewrite minusZeroRight n in xs -drop (FS i) (x :: xs) = drop i xs + public export + fromList' : (xs : List a) -> Vect (length xs) a + fromList' [] = [] + fromList' (x :: xs) = x :: fromList' xs -namespace DropElem - ||| Drop all the elements up and until the element `x` from a vector +||| Duplicate of utilities for Data.Vect in their Naperian form +namespace Vect public export - drop : DecEq a => - (xs : Vect n a) -> - (elem : Elem x xs) -> - Vect (n `minus` (finToNat (FS (elemToFin elem)))) a - drop {n=S k} (_ :: xs) Here = rewrite minusZeroRight k in xs - drop (_ :: xs) (There later) = drop xs later + sum : Num a => Vect n a -> a + sum xs = foldr @{straightforward} (+) (fromInteger 0) xs + + -- Because of the way foldr for Vect is implemented in Idris + -- we have to use this approach below, otherwise allSuccThenProdSucc breaks + public export + prod : Num a => Vect n a -> a + prod xs = foldr @{straightforward} (*) (fromInteger 1) xs + -- prod [] = fromInteger 1 + -- prod (x :: xs) = x * prod xs + public export + max : Ord a => Vect n a -> Maybe a + max [] = Nothing + max (x :: xs) = case max xs of + Nothing => Just x + Just y => Just (max x y) -public export -data NotElem : DecEq a => (x : a) -> (xs : Vect n a) -> Type where - NotInEmptyVect : DecEq a => {0 x : a} -> NotElem x [] - NotInNonEmptyVect : DecEq a => {0 x, y : a} -> - (xs : Vect n a) -> - IsNo (decEq x y) -> - (ne : NotElem x xs) => - NotElem x (y :: xs) + public export + argmax : Ord a => IsSucc n => Vect n a -> Fin n + argmax [x] = FZ + argmax (x :: x' :: xs) = + let maxRest = argmax (x' :: xs) + in case x > index maxRest (x' :: xs) of + True => FZ + False => FS maxRest + + public export + argmin : Ord a => IsSucc n => Vect n a -> Fin n + argmin = argmax @{Reverse} + + ||| Dual to concat from Data.Vect + public export + unConcat : {n, m : Nat} -> Vect (n * m) a -> Vect n (Vect m a) + unConcat {n = 0} _ = [] + unConcat {n = (S k)} xs = let (f, s) = splitAt m xs + in f :: unConcat s -public export -notEqualNotElem : DecEq a => - {0 x, y : a} -> - (neq : IsNo (decEq x y)) -> - NotElem x [y] -notEqualNotElem neq = NotInNonEmptyVect [] neq + ||| Trim a specified trailing value + public export + dropFromEnd : Eq a => a -> Vect n a -> List a + dropFromEnd c row = reverse (dropWhile (== c) (reverse (toList row))) + + ||| Combination of `cons` and `snoc`: adds an element in front, and at the end + public export + consSnoc : Vect n a -> a -> a -> Vect (2 + n) a + consSnoc xs a b = a :: snoc xs b + + ||| Pad a vector with a specified element to exactly `targetSize` + public export + padToSize : Vect size a -> (targetSize : Nat) -> a -> + LTE size targetSize => + Vect targetSize a + padToSize [] Z c = [] + padToSize [] (S k) c = c :: padToSize [] k c + padToSize (x :: xs) (S k) c = x :: padToSize xs k c @{fromLteSucc %search} + + ||| Drop the first i elements of a vector + ||| Analogous to Data.Vect.drop, except the index is Fin n instead of Nat + public export + drop : (i : Fin (S n)) -> Vect n a -> Vect (minus n (finToNat i)) a + drop FZ xs = rewrite minusZeroRight n in xs + drop (FS i) (x :: xs) = drop i xs + + namespace DropElem + ||| Drop all the elements up and until the element `x` from a vector + public export + drop : DecEq a => + (xs : Vect n a) -> + (elem : Elem x xs) -> + Vect (n `minus` (finToNat (FS (elemToFin elem)))) a + drop {n=S k} (_ :: xs) Here = rewrite minusZeroRight k in xs + drop (_ :: xs) (There later) = drop xs later -||| If an element `i` is not in the singleton list `[j]`, then `j` is not in -||| the singleton list `[i]` -public export -notElemSym : DecEq a => {i, j : a} -> NotElem i [j] -> NotElem j [i] -notElemSym (NotInNonEmptyVect [] isNo) = notEqualNotElem (isNoSym isNo) +namespace List + public export + sum : Num a => List a -> a + sum = foldr (+) (fromInteger 0) -||| If an element `i` is in the singleton list `[j]`, then `j` is in the -||| singleton list `[i]` -public export -elemSym : DecEq a => {i, j : a} -> Elem i [j] -> Elem j [i] -elemSym Here = Here + public export + prod : Num a => List a -> a + prod = foldr (*) (fromInteger 1) -||| This already exists in Data.Vect.Elem, but it is not marked with %hint -emptyIsUninhabited : NotElem "i" [] -emptyIsUninhabited = NotInEmptyVect + public export + listZip : List a -> List b -> List (a, b) + listZip (x :: xs) (y :: ys) = (x, y) :: listZip xs ys + listZip _ _ = [] -fe' : Not ("i" = "j") -fe' = ?fe'_rhs + ||| Map each element along with its zero-based position in the list. + public export + mapWithIndex : (Nat -> a -> b) -> List a -> List b + mapWithIndex f = go 0 + where + go : Nat -> List a -> List b + go _ [] = [] + go i (x :: xs) = f i x :: go (S i) xs + + ||| Split a list into consecutive chunks of size `n` (clamped to at least 1). + ||| The final chunk may be shorter than `n`, this is why the length of the + ||| list is needed as upper bound. + public export + chunksOf : (n : Nat) -> List a -> List (List a) + chunksOf n xs = go (max 1 n) xs (length xs) + where + go : Nat -> List a -> (len : Nat) -> List (List a) + go _ [] _ = [] + go _ ys@(_ :: _) Z = [ys] + go sz ys@(_ :: _) (S f) = case splitAt sz ys of + (h, t) => h :: go sz t f + + public export + max : Ord a => List a -> Maybe a + max [] = Nothing + max (x :: xs) = case max xs of + Nothing => Just x + Just y => Just (max x y) + namespace NonEmpty + public export + max : Ord a => (xs : List a) -> (ne : NonEmpty xs) => a + max [x] {ne=IsNonEmpty} = x + max (x :: y :: xs) {ne=IsNonEmpty} = max x (max (y :: xs)) -fe : IsNo (decEq "i" "j") -fe = ItIsNo + ||| Trim a specified trailing value + public export + dropFromEnd : Eq a => a -> List a -> List a + dropFromEnd c row = reverse (dropWhile (== c) (reverse row)) -ne : NotElem "i" ["j"] -ne = NotInNonEmptyVect [] ItIsNo + ||| Combination of `cons` and `snoc`: adds an element in front, and at the end + public export + consSnoc : List a -> a -> a -> List a + consSnoc xs x y = x :: snoc xs y + ||| Pad a list with a specified element to at least `targetSize` + public export + padToSize : Nat -> a -> List a -> List a + padToSize targetSize padValue xs = + xs ++ replicate (minus targetSize (length xs)) padValue -||| Pointwise Num structure for Applicative functors -public export -[applicativeNum] Num a => Applicative f => Num (f a) where - xs + ys = uncurry (+) <$> liftA2 xs ys - xs * ys = uncurry (*) <$> liftA2 xs ys - fromInteger = pure . fromInteger -||| Duplicate of utilities for Data.Vect in their Naperian form + ||| Drop all the elements after the element `x` from a list + public export + dropAfterElem : (xs : List a) -> (elem : Elem x xs) -> List a + dropAfterElem (x :: _) Here = [x] + dropAfterElem (y :: xs) (There p) = y :: dropAfterElem xs p + namespace VectNaperianUtils ||| Analogue of `(::)` public export @@ -180,62 +328,27 @@ namespace VectNaperianUtils takeFin FZ _ = [] takeFin (FS s) (x :: xs) = x :: takeFin s xs -namespace Vect public export - sum : Num a => Vect n a -> a - sum xs = foldr (+) (fromInteger 0) xs - - -- Because of the way foldr for Vect is implemented in Idris - -- we have to use this approach below, otherwise allSuccThenProdSucc breaks - public export - prod : Num a => Vect n a -> a - prod xs = foldr @{straightforward} (*) (fromInteger 1) xs - -- prod [] = fromInteger 1 - -- prod (x :: xs) = x * prod xs + sum : Num a => {n : Nat} -> (Fin n -> a) -> a + sum {n = 0} _ = 0 + sum {n = (S k)} content = content FZ + sum (content . FS) public export - argmax : Ord a => IsSucc n => Vect n a -> Fin n - argmax [x] = FZ - argmax (x :: x' :: xs) = if x > index maxRest (x' :: xs) then FZ else FS maxRest - where maxRest = argmax (x' :: xs) - - public export - argmin : Ord a => IsSucc n => Vect n a -> Fin n - argmin = argmax @{Reverse} - - ||| Dual to concat from Data.Vect - public export - unConcat : {n, m : Nat} -> Vect (n * m) a -> Vect n (Vect m a) - unConcat {n = 0} _ = [] - unConcat {n = (S k)} xs = let (f, s) = splitAt m xs - in f :: unConcat s - + prod : Num a => {n : Nat} -> (Fin n -> a) -> a + prod = prod . tabulate - -namespace List public export - sum : Num a => List a -> a - sum = foldr (+) (fromInteger 0) + toList : {n : Nat} -> (Fin n -> a) -> List a + toList = toList' . tabulate +namespace FinArithmetic public export - prod : Num a => List a -> a - prod = foldr (*) (fromInteger 1) - - public export - listZip : List a -> List b -> List (a, b) - listZip (x :: xs) (y :: ys) = (x, y) :: listZip xs ys - listZip _ _ = [] - - public export - maxInList : Ord a => List a -> Maybe a - maxInList [] = Nothing - maxInList [x] = Just x - maxInList (x :: xs) = do - mx <- maxInList xs - pure (max x mx) + minusSuccLTE : {n, m : Nat} -> LTE n m -> + minus (S m) n = S (minus m n) + minusSuccLTE {m = 0, n = 0} LTEZero = Refl + minusSuccLTE {m = (S k), n = 0} LTEZero = Refl + minusSuccLTE {m = (S k), n = (S left)} (LTESucc x) = minusSuccLTE x - -namespace FinArithmetic ||| Like weakenN from Data.Fin, but where n is on the other side of + public export weakenN' : (0 n : Nat) -> Fin m -> Fin (n + m) @@ -401,22 +514,6 @@ public export Show a => ((x : a) -> Show (b x)) => Show (DPair a b) where show = mkDepPairShow -||| Interface describing how a type can be displayed as a 2d grid of characters -public export -interface Display (a : Type) where - display : (x : a) -> (h : Nat ** w : Nat ** Vect h ((Vect w) Char)) - --- ||| Any type that implements Display can be shown as a string --- public export --- {a : Type} -> Display a => Show a where --- show x = let (h ** w ** xs) = display x --- ss = toList (intersperse "\n" (pack . toList <$> xs)) -- add intercalate here, and newline --- in fastUnlines ss - --- public export --- Display Char where --- display x = (1 ** 1 ** [[x]]) - -- public export -- Num Unit where -- fromInteger _ = () @@ -624,7 +721,6 @@ public export updateAt : Eq a => (a -> b) -> (a, b) -> (a -> b) updateAt f (i, val) i' = if i == i' then val else f i' - ||| Graph of a dependent function public export graph : {t : a -> Type} -> @@ -704,3 +800,31 @@ namespace Linearity ll2 : {0 n : Nat} -> Vect n a -> Nat ll2 [] = 0 ll2 {n=S t} (x :: xs) = 1 + ll2 xs + + + +public export +testFun : Nat -> (m : Nat ** Vect m Nat) + +testFun2 : Nat -> Vect m Nat + +consume : Vect m a -> Type + +composed : (p : a -> Bool) -> + (xs : Vect n a) -> + consume (snd (filter p xs)) +composed p xs = ?composed_rhs + +-- public export +-- filter : (elem -> Bool) -> Vect len elem -> (p ** Vect p elem) +-- filter p [] = ( _ ** [] ) +-- filter p (x::xs) = +-- let (_ ** tail) = filter p xs +-- in if p x then +-- (_ ** x::tail) +-- else +-- (_ ** tail) + +public export +filter2 : (a -> Bool) -> Vect len a -> Vect p a +filter2 f xs = ?filter2_rhs diff --git a/src/NN/Architectures/Affine.idr b/src/NN/Architectures/Affine.idr index b5cc9a4..feac5e1 100644 --- a/src/NN/Architectures/Affine.idr +++ b/src/NN/Architectures/Affine.idr @@ -8,7 +8,7 @@ import Data.Para public export record AffineLayerParams (x, y : Axis) - {auto ac : NewAxisConsistent y [x]} + {auto ac : y `ConsistentWith` [x]} (a : Type) where constructor MkParams weights : Tensor [y, x] a @@ -16,7 +16,7 @@ record AffineLayerParams public export affineImpl : {x, y : Axis} -> - NewAxisConsistent y [x] => + y `ConsistentWith` [x] => Num a => AllAlgebra [x] a => TensorMonoid x.cont => TensorMonoid y.cont => @@ -26,7 +26,7 @@ affineImpl (input ** (MkParams weights bias)) public export affinePara : {x, y : Axis} -> {a : Type} -> Num a => - NewAxisConsistent y [x] => + y `ConsistentWith` [x] => AllAlgebra [x] a => TensorMonoid x.cont => TensorMonoid y.cont => Tensor [x] a -\-> Tensor [y] a diff --git a/src/NN/Architectures/LossFunctions.idr b/src/NN/Architectures/LossFunctions.idr index 709778b..e993c59 100644 --- a/src/NN/Architectures/LossFunctions.idr +++ b/src/NN/Architectures/LossFunctions.idr @@ -101,10 +101,11 @@ public export SquaredError : {a : Type} -> Num a => Neg a => Loss (Const a) {l=Const a} SquaredError = Additive.Morphism.Instances.SquaredDifference + public export Sum : {n : Axis} -> IsCubical n => Num a => - TensorMonoid n.cont => - (Const (Tensor [n] a)) =%> (Const (Tensor [] a)) + TensorMonoid n.cont => + Const (Tensor [n] a) =%> Const (Tensor [] a) Sum @{MkIsCubical _ n} = !%+ \t => (># reduce t ** \a' => fill (#> a')) public export @@ -114,7 +115,7 @@ Div : {a : Type} -> Num a => Fractional a => Div divBy = !%+ \x => (x <&> (/ divBy) ** \x' => x' <&> (/ divBy)) public export -MeanSquaredError : {n : Axis} -> IsCubical n => TensorMonoid n.cont => +MeanSquaredError : IsCubical n => TensorMonoid n.cont => {a : Type} -> Num a => Neg a => Fractional a => Cast Nat a => Loss (Const (Tensor [n] a)) {l=Const (Tensor [] a)} MeanSquaredError @{MkIsCubical _ n} = SquaredError %>> Sum %>> Div (cast n) diff --git a/src/NN/Architectures/Softargmax.idr b/src/NN/Architectures/Softargmax.idr index ab1e352..130c9cf 100644 --- a/src/NN/Architectures/Softargmax.idr +++ b/src/NN/Architectures/Softargmax.idr @@ -34,7 +34,7 @@ logSoftargmax t = case logSumExp t of ||| When `temperature=0` it reduces to `argmax` public export softargmaxImpl : {i : Axis} -> Fractional a => Exp a => Ord a => Neg a => - Foldable (Tensor [i]) => + IsFoldable i .cont => (allAlg : AllAlgebra [i] a) => {default 1 temperature : a} -> Tensor [i] a -> Tensor [i] a @@ -42,12 +42,13 @@ softargmaxImpl {temperature} t = exp <$> logSoftargmax (t <&> (/ temperature)) ||| Softargmax as a parametric map, with temperature as a parameter -||| TODO the output type should be a distribution tensor, since distributions -||| are applicative? https://glaive-research.org/2025/02/11/Generalized-Transformers-from-Applicative-Functors.html +||| TODO since distribution is an applicative functor (https://glaive-research.org/2025/02/11/Generalized-Transformers-from-Applicative-Functors.html) +||| is there a meaningful notion of the "distribution container"? +||| Is there a sense in which `Dist` is a functor on containers? public export softargmax : {i : Axis} -> {a : Type} -> Fractional a => Exp a => Ord a => Neg a => - Foldable (Tensor [i]) => + IsFoldable i.cont => (allAlg : AllAlgebra [i] a) => Tensor [i] a -\-> Tensor [i] a softargmax = MkPara @@ -56,20 +57,12 @@ softargmax = MkPara -- `Control.Monad.Distribution` and softargmax should probably be merged? +-- todo this is missing beause of a show instance for tensors +-- needs an assert total because it goes through tensors public export {i : Nat} -> Show (Dist i) where - show (MkDist xs) = show (softargmaxImpl {i="softmaxTemp" ~~> i} (># xs)) + show (MkDist xs) = assert_total $ + show @{(?todoTensorShow)} (softargmaxImpl {i="softmaxTemp" ~~> i} (># xs)) inpp : Tensor ["ieva" ~~> 3] Double -inpp = ># [1000, 999, 998] - --- TODO namedSoftargmax --- namedSoftmax : {axis : Type -> Type} --- -> {shape : Vect n ApplF} -> {a : Type} --- -> Functor axis --- => Elem axis shape --- -> TensorA shape a --- -> TensorA shape a --- namedSoftmax {shape = []} axis t impossible -- can't be in vector if vector empty --- namedSoftmax {shape = (axis :: ss)} Here (GTS x) = GTS (?sm <$> x) --- namedSoftmax {shape = (s :: ss)} (There later) (GTS x) = GTS ?namedSoftmax_rhs_4 +inpp = ># [1000, 999, 998] \ No newline at end of file diff --git a/src/NN/Architectures/Transformer/Attention.idr b/src/NN/Architectures/Transformer/Attention.idr index 18c8328..48f7001 100644 --- a/src/NN/Architectures/Transformer/Attention.idr +++ b/src/NN/Architectures/Transformer/Attention.idr @@ -8,9 +8,9 @@ import NN.Architectures.Softargmax public export crossAttention : {a : Type} -> Num a => {inputStructure, crossStructure, features : Axis} -> - (acif : NewAxisConsistent inputStructure [features]) => - (accf : NewAxisConsistent crossStructure [features]) => - (acci : NewAxisConsistent crossStructure [inputStructure]) => + (acif : inputStructure `ConsistentWith` [features]) => + (accf : crossStructure `ConsistentWith` [features]) => + (acci : crossStructure `ConsistentWith` [inputStructure]) => TensorMonoid inputStructure.cont => TensorMonoid features.cont => (allAlg : AllAlgebra [inputStructure, features] a) => {default id causalMask : Tensor [crossStructure, inputStructure] a -> @@ -29,9 +29,9 @@ crossAttention {allAlg=Cons {rest=xx}, causalMask} softargmax q v k = public export selfAttention : {a : Type} -> Num a => {inputStructure, features : Axis} -> - NewAxisConsistent inputStructure [features] => - (TensorMonoid inputStructure.cont) => - (TensorMonoid features.cont) => + inputStructure `ConsistentWith` [features] => + TensorMonoid inputStructure.cont => + TensorMonoid features.cont => (allAlg : AllAlgebra [inputStructure, features] a) => {default id causalMask : Tensor [inputStructure, inputStructure] a -> Tensor [inputStructure, inputStructure] a} -> @@ -52,9 +52,9 @@ record SelfAttentionParams (features : Axis) (a : Type) where public export SAImpl : {a : Type} -> Num a => {inputStructure, features : Axis} -> - (ac : NewAxisConsistent inputStructure [features]) => - (TensorMonoid inputStructure.cont) => - (TensorMonoid features.cont) => + (ac : inputStructure `ConsistentWith` [features]) => + TensorMonoid inputStructure.cont => + TensorMonoid features.cont => (allAlg : AllAlgebra [inputStructure, features] a) => {default id causalMask : Tensor [inputStructure, inputStructure] a -> Tensor [inputStructure, inputStructure] a} -> @@ -72,8 +72,8 @@ SAImpl {allAlg = Cons} {causalMask} softargmax (input ** (MkSAParams queryMat va public export SelfAttention : {a : Type} -> Num a => {inputStructure, features : Axis} -> - NewAxisConsistent inputStructure [features] => - (TensorMonoid inputStructure.cont) => (TensorMonoid features.cont) => + inputStructure `ConsistentWith` [features] => + TensorMonoid inputStructure.cont => TensorMonoid features.cont => (allAlg : AllAlgebra [inputStructure, features] a) => {default id causalMask : Tensor [inputStructure, inputStructure] a -> Tensor [inputStructure, inputStructure] a} -> diff --git a/src/NN/Architectures/Transformer/Definition.idr b/src/NN/Architectures/Transformer/Definition.idr index d6092bc..1f52890 100644 --- a/src/NN/Architectures/Transformer/Definition.idr +++ b/src/NN/Architectures/Transformer/Definition.idr @@ -15,9 +15,9 @@ import NN.Architectures.Utils public export Transformer : {a : Type} -> Num a => Ord a => {inputStructure, features : Axis} -> - (ac : NewAxisConsistent inputStructure [features]) => - (TensorMonoid inputStructure.cont) => - (TensorMonoid features.cont) => + (ac : inputStructure `ConsistentWith` [features]) => + TensorMonoid inputStructure.cont => + TensorMonoid features.cont => (allAlg : AllAlgebra [inputStructure, features] a) => {default id causalMask : Tensor [inputStructure, inputStructure] a -> Tensor [inputStructure, inputStructure] a} -> diff --git a/src/NN/Architectures/Utils.idr b/src/NN/Architectures/Utils.idr index 43e8466..8524f62 100644 --- a/src/NN/Architectures/Utils.idr +++ b/src/NN/Architectures/Utils.idr @@ -7,7 +7,7 @@ import Data.Tensor public export paraMapFirstAxis : {c : Axis} -> {cs : TensorShape rank} -> {ds : TensorShape rank'} -> - NewAxisConsistent c cs => NewAxisConsistent c ds => + c `ConsistentWith` cs => c `ConsistentWith` ds => Num a => (pf : Tensor cs a -\-> Tensor ds a) -> (nonDep : IsNotDependent pf) => diff --git a/tensortype.ipkg b/tensortype.ipkg index 30bb09f..28faab6 100644 --- a/tensortype.ipkg +++ b/tensortype.ipkg @@ -17,6 +17,7 @@ depends = contrib , elab-util , hashmap , finite + , timeit -- modules to install modules = Data.Container.Base @@ -26,11 +27,13 @@ modules = Data.Container.Base , Data.Container.Base.Extension.Instances , Data.Container.Base.Morphism.Definition , Data.Container.Base.Morphism.Instances - , Data.Container.Base.Concrete.Definition - , Data.Container.Base.Concrete.Instances + , Data.Container.Base.Properties.Definitions + , Data.Container.Base.Properties.Instances , Data.Container.Base.Product.Definitions , Data.Container.Base.Product.Interfaces , Data.Container.Base.Product.InterfaceImplementations + , Data.Container.Base.Display2D.CharacterMap + , Data.Container.Base.Display2D.Display2D , Data.Container.Base.Definitions , Data.Container.Base.Instances , Data.Container.Base.InstanceInterfaces @@ -43,20 +46,22 @@ modules = Data.Container.Base , Data.Container.Additive.Morphism.Definition , Data.Container.Additive.Morphism.Instances , Data.Container.Additive.Extension.Definition + , Data.Container.Additive.Properties.Definitions , Data.Container.Additive.Product.Definitions , Data.Container.Additive.Quantifiers , Data.Container.Applicative , Data.Container.Applicative.Object.Instances , Data.Container.Applicative.Extension.Instances - , Data.Container.Applicative.Concrete.Instances + , Data.Container.Applicative.Properties.Instances , Data.Container.Applicative.Product.Interfaces , Data.Container.Applicative.TreeUtils , Data.Container.SubTerm , Data.Tensor - , Data.Tensor.Axis + , Data.Tensor.Shape.Axis + , Data.Tensor.Shape.Shape , Data.Tensor.Tensor , Data.Tensor.Utils @@ -72,12 +77,14 @@ modules = Data.Container.Base , Data.Para , Data.Functor.Naperian + , Data.Functor.Products , Data.Layout - , Data.Tree + , Data.Trees , Data.Unique.Vect , Data.Functor.Algebra , Data.Num , Data.ComMonoid + , Data.ScientificNotation , Data.Autodiff.Ops diff --git a/tests/src/Display2D/Expected.idr b/tests/src/Display2D/Expected.idr new file mode 100644 index 0000000..ed83d7a --- /dev/null +++ b/tests/src/Display2D/Expected.idr @@ -0,0 +1,565 @@ +module Display2D.Expected + +import Data.Tensor +import Hedgehog + +import Display2D.Instances + +tensorVectorExpected : String +tensorVectorExpected = "[0.0 1.0 2.0 3.0 4.0 5.0]" + +tensorMatrixExpected : String +tensorMatrixExpected = """ + [[ 0.0 1.0 2.0 3.0] + [ 4.0 5.0 6.0 7.0] + [ 8.0 9.0 10.0 11.0]] + """ + +tensor3DExpected : String +tensor3DExpected = """ + [[[ 0.0 1.0] + [ 2.0 3.0] + [ 4.0 5.0]] + + [[ 6.0 7.0] + [ 8.0 9.0] + [10.0 11.0]]] + """ + +tensor4DExpected : String +tensor4DExpected = """ + [[[[ 0.0 1.0 2.0 3.0 4.0] + [ 5.0 6.0 7.0 8.0 9.0] + [ 10.0 11.0 12.0 13.0 14.0] + [ 15.0 16.0 17.0 18.0 19.0]] + + [[ 20.0 21.0 22.0 23.0 24.0] + [ 25.0 26.0 27.0 28.0 29.0] + [ 30.0 31.0 32.0 33.0 34.0] + [ 35.0 36.0 37.0 38.0 39.0]] + + [[ 40.0 41.0 42.0 43.0 44.0] + [ 45.0 46.0 47.0 48.0 49.0] + [ 50.0 51.0 52.0 53.0 54.0] + [ 55.0 56.0 57.0 58.0 59.0]]] + + + [[[ 60.0 61.0 62.0 63.0 64.0] + [ 65.0 66.0 67.0 68.0 69.0] + [ 70.0 71.0 72.0 73.0 74.0] + [ 75.0 76.0 77.0 78.0 79.0]] + + [[ 80.0 81.0 82.0 83.0 84.0] + [ 85.0 86.0 87.0 88.0 89.0] + [ 90.0 91.0 92.0 93.0 94.0] + [ 95.0 96.0 97.0 98.0 99.0]] + + [[100.0 101.0 102.0 103.0 104.0] + [105.0 106.0 107.0 108.0 109.0] + [110.0 111.0 112.0 113.0 114.0] + [115.0 116.0 117.0 118.0 119.0]]]] + """ + +tensor5DExpected : String +tensor5DExpected = """ + [[[[[ 0.0 1.0 2.0 3.0 4.0 5.0] + [ 6.0 7.0 8.0 9.0 10.0 11.0] + [ 12.0 13.0 14.0 15.0 16.0 17.0] + [ 18.0 19.0 20.0 21.0 22.0 23.0] + [ 24.0 25.0 26.0 27.0 28.0 29.0]] + + [[ 30.0 31.0 32.0 33.0 34.0 35.0] + [ 36.0 37.0 38.0 39.0 40.0 41.0] + [ 42.0 43.0 44.0 45.0 46.0 47.0] + [ 48.0 49.0 50.0 51.0 52.0 53.0] + [ 54.0 55.0 56.0 57.0 58.0 59.0]] + + [[ 60.0 61.0 62.0 63.0 64.0 65.0] + [ 66.0 67.0 68.0 69.0 70.0 71.0] + [ 72.0 73.0 74.0 75.0 76.0 77.0] + [ 78.0 79.0 80.0 81.0 82.0 83.0] + [ 84.0 85.0 86.0 87.0 88.0 89.0]] + + [[ 90.0 91.0 92.0 93.0 94.0 95.0] + [ 96.0 97.0 98.0 99.0 100.0 101.0] + [102.0 103.0 104.0 105.0 106.0 107.0] + [108.0 109.0 110.0 111.0 112.0 113.0] + [114.0 115.0 116.0 117.0 118.0 119.0]]] + + + [[[120.0 121.0 122.0 123.0 124.0 125.0] + [126.0 127.0 128.0 129.0 130.0 131.0] + [132.0 133.0 134.0 135.0 136.0 137.0] + [138.0 139.0 140.0 141.0 142.0 143.0] + [144.0 145.0 146.0 147.0 148.0 149.0]] + + [[150.0 151.0 152.0 153.0 154.0 155.0] + [156.0 157.0 158.0 159.0 160.0 161.0] + [162.0 163.0 164.0 165.0 166.0 167.0] + [168.0 169.0 170.0 171.0 172.0 173.0] + [174.0 175.0 176.0 177.0 178.0 179.0]] + + [[180.0 181.0 182.0 183.0 184.0 185.0] + [186.0 187.0 188.0 189.0 190.0 191.0] + [192.0 193.0 194.0 195.0 196.0 197.0] + [198.0 199.0 200.0 201.0 202.0 203.0] + [204.0 205.0 206.0 207.0 208.0 209.0]] + + [[210.0 211.0 212.0 213.0 214.0 215.0] + [216.0 217.0 218.0 219.0 220.0 221.0] + [222.0 223.0 224.0 225.0 226.0 227.0] + [228.0 229.0 230.0 231.0 232.0 233.0] + [234.0 235.0 236.0 237.0 238.0 239.0]]] + + + [[[240.0 241.0 242.0 243.0 244.0 245.0] + [246.0 247.0 248.0 249.0 250.0 251.0] + [252.0 253.0 254.0 255.0 256.0 257.0] + [258.0 259.0 260.0 261.0 262.0 263.0] + [264.0 265.0 266.0 267.0 268.0 269.0]] + + [[270.0 271.0 272.0 273.0 274.0 275.0] + [276.0 277.0 278.0 279.0 280.0 281.0] + [282.0 283.0 284.0 285.0 286.0 287.0] + [288.0 289.0 290.0 291.0 292.0 293.0] + [294.0 295.0 296.0 297.0 298.0 299.0]] + + [[300.0 301.0 302.0 303.0 304.0 305.0] + [306.0 307.0 308.0 309.0 310.0 311.0] + [312.0 313.0 314.0 315.0 316.0 317.0] + [318.0 319.0 320.0 321.0 322.0 323.0] + [324.0 325.0 326.0 327.0 328.0 329.0]] + + [[330.0 331.0 332.0 333.0 334.0 335.0] + [336.0 337.0 338.0 339.0 340.0 341.0] + [342.0 343.0 344.0 345.0 346.0 347.0] + [348.0 349.0 350.0 351.0 352.0 353.0] + [354.0 355.0 356.0 357.0 358.0 359.0]]]] + + + + [[[[360.0 361.0 362.0 363.0 364.0 365.0] + [366.0 367.0 368.0 369.0 370.0 371.0] + [372.0 373.0 374.0 375.0 376.0 377.0] + [378.0 379.0 380.0 381.0 382.0 383.0] + [384.0 385.0 386.0 387.0 388.0 389.0]] + + [[390.0 391.0 392.0 393.0 394.0 395.0] + [396.0 397.0 398.0 399.0 400.0 401.0] + [402.0 403.0 404.0 405.0 406.0 407.0] + [408.0 409.0 410.0 411.0 412.0 413.0] + [414.0 415.0 416.0 417.0 418.0 419.0]] + + [[420.0 421.0 422.0 423.0 424.0 425.0] + [426.0 427.0 428.0 429.0 430.0 431.0] + [432.0 433.0 434.0 435.0 436.0 437.0] + [438.0 439.0 440.0 441.0 442.0 443.0] + [444.0 445.0 446.0 447.0 448.0 449.0]] + + [[450.0 451.0 452.0 453.0 454.0 455.0] + [456.0 457.0 458.0 459.0 460.0 461.0] + [462.0 463.0 464.0 465.0 466.0 467.0] + [468.0 469.0 470.0 471.0 472.0 473.0] + [474.0 475.0 476.0 477.0 478.0 479.0]]] + + + [[[480.0 481.0 482.0 483.0 484.0 485.0] + [486.0 487.0 488.0 489.0 490.0 491.0] + [492.0 493.0 494.0 495.0 496.0 497.0] + [498.0 499.0 500.0 501.0 502.0 503.0] + [504.0 505.0 506.0 507.0 508.0 509.0]] + + [[510.0 511.0 512.0 513.0 514.0 515.0] + [516.0 517.0 518.0 519.0 520.0 521.0] + [522.0 523.0 524.0 525.0 526.0 527.0] + [528.0 529.0 530.0 531.0 532.0 533.0] + [534.0 535.0 536.0 537.0 538.0 539.0]] + + [[540.0 541.0 542.0 543.0 544.0 545.0] + [546.0 547.0 548.0 549.0 550.0 551.0] + [552.0 553.0 554.0 555.0 556.0 557.0] + [558.0 559.0 560.0 561.0 562.0 563.0] + [564.0 565.0 566.0 567.0 568.0 569.0]] + + [[570.0 571.0 572.0 573.0 574.0 575.0] + [576.0 577.0 578.0 579.0 580.0 581.0] + [582.0 583.0 584.0 585.0 586.0 587.0] + [588.0 589.0 590.0 591.0 592.0 593.0] + [594.0 595.0 596.0 597.0 598.0 599.0]]] + + + [[[600.0 601.0 602.0 603.0 604.0 605.0] + [606.0 607.0 608.0 609.0 610.0 611.0] + [612.0 613.0 614.0 615.0 616.0 617.0] + [618.0 619.0 620.0 621.0 622.0 623.0] + [624.0 625.0 626.0 627.0 628.0 629.0]] + + [[630.0 631.0 632.0 633.0 634.0 635.0] + [636.0 637.0 638.0 639.0 640.0 641.0] + [642.0 643.0 644.0 645.0 646.0 647.0] + [648.0 649.0 650.0 651.0 652.0 653.0] + [654.0 655.0 656.0 657.0 658.0 659.0]] + + [[660.0 661.0 662.0 663.0 664.0 665.0] + [666.0 667.0 668.0 669.0 670.0 671.0] + [672.0 673.0 674.0 675.0 676.0 677.0] + [678.0 679.0 680.0 681.0 682.0 683.0] + [684.0 685.0 686.0 687.0 688.0 689.0]] + + [[690.0 691.0 692.0 693.0 694.0 695.0] + [696.0 697.0 698.0 699.0 700.0 701.0] + [702.0 703.0 704.0 705.0 706.0 707.0] + [708.0 709.0 710.0 711.0 712.0 713.0] + [714.0 715.0 716.0 717.0 718.0 719.0]]]]] + """ + + +public export +cubicalTensorGroup : Group +cubicalTensorGroup = MkGroup "Cubical tensor printing" + [ ("Print vector ", property1 $ show tensorVector === tensorVectorExpected) + , ("Print matrix ", property1 $ show tensorMatrix === tensorMatrixExpected) + , ("Print 3D tensor ", property1 $ show tensor3D === tensor3DExpected) + , ("Print 4D tensor ", property1 $ show tensor4D === tensor4DExpected) + , ("Print 5D tensor ", property1 $ show tensor5D === tensor5DExpected) ] + + + +longVectorExpected : String +longVectorExpected = """ + [3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0] + """ + +longVectorReshapedExpected : String +longVectorReshapedExpected = """ + [[3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0] + [3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 + 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0] ] + """ + +longVector2Expected : String +longVector2Expected = """ + [ 0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 11.0 + 12.0 13.0 14.0 15.0 16.0 17.0 18.0 19.0 20.0 21.0 22.0 23.0 + 24.0 25.0 26.0 27.0 28.0 29.0 30.0 31.0 32.0 33.0 34.0 35.0 + 36.0 37.0 38.0 39.0 40.0 41.0 42.0 43.0 44.0 45.0 46.0 47.0 + 48.0 49.0 50.0 51.0 52.0 53.0 54.0 55.0 56.0 57.0 58.0 59.0 + 60.0 61.0 62.0 63.0 64.0 65.0 66.0 67.0 68.0 69.0 70.0 71.0 + 72.0 73.0 74.0 75.0 76.0 77.0 78.0 79.0 80.0 81.0 82.0 83.0 + 84.0 85.0 86.0 87.0 88.0 89.0 90.0 91.0 92.0 93.0 94.0 95.0 + 96.0 97.0 98.0 99.0 100.0 101.0 102.0 103.0 104.0 105.0 106.0 107.0 + 108.0 109.0 110.0 111.0 112.0 113.0 114.0 115.0 116.0 117.0 118.0 119.0 + 120.0 121.0 122.0 123.0 124.0 125.0 126.0 127.0 128.0 129.0 130.0 131.0 + 132.0 133.0 134.0 135.0 136.0 137.0 138.0 139.0 140.0 141.0 142.0 143.0 + 144.0 145.0 146.0 147.0 148.0 149.0 150.0 151.0 152.0 153.0 154.0 155.0 + 156.0 157.0 158.0 159.0 160.0 161.0 162.0 163.0 164.0 165.0 166.0 167.0 + 168.0 169.0 170.0 171.0 172.0 173.0 174.0 175.0 176.0 177.0 178.0 179.0 + 180.0 181.0 182.0 183.0 184.0 185.0 186.0 187.0 188.0 189.0 190.0 191.0 + 192.0 193.0 194.0 195.0 196.0 197.0 198.0 199.0 200.0 201.0 202.0 203.0 + 204.0 205.0 206.0 207.0 208.0 209.0 210.0 211.0 212.0 213.0 214.0 215.0 + 216.0 217.0 218.0 219.0 220.0 221.0 222.0 223.0 224.0 225.0 226.0 227.0 + 228.0 229.0 230.0 231.0 232.0 233.0 234.0 235.0 236.0 237.0 238.0 239.0 + 240.0 241.0 242.0 243.0 244.0 245.0 246.0 247.0 248.0 249.0 250.0 251.0 + 252.0 253.0 254.0 255.0 256.0 257.0 258.0 259.0 260.0 261.0 262.0 263.0 + 264.0 265.0 266.0 267.0 268.0 269.0 270.0 271.0 272.0 273.0 274.0 275.0 + 276.0 277.0 278.0 279.0 280.0 281.0 282.0 283.0 284.0 285.0 286.0 287.0 + 288.0 289.0 290.0 291.0 292.0 293.0 294.0 295.0 296.0 297.0 298.0 299.0] + """ + +longTensorsGroup : Group +longTensorsGroup = MkGroup "Long tensors printing" + [ ("Print long vector", property1 $ show longVector === longVectorExpected) + , ("Print long vector reshaped", property1 $ show longVectorReshaped === longVectorReshapedExpected) + , ("Print long vector 2 ", property1 $ show longVector2 === longVector2Expected) ] + +vectorDecimalExpected : String +vectorDecimalExpected = "[ 0.0 1.2346 1.0e-07]" + +matrixDecimalExpected : String +matrixDecimalExpected = """ + [[ 0.01 0.001 0.0001 1.0e-05 1.0e-06] + [ 10.0 20.0 30.0 40.0 50.0]] + """ + +matrixDecimal2Expected : String +matrixDecimal2Expected = """ + [[ 1000.0 10000.0 100000.0 1.0e+06 1.0e+07] + [ 40.0 50.0 60.0 70.0 80.0]] + """ + +cubicalTensorsDecimalGroup : Group +cubicalTensorsDecimalGroup = MkGroup "Cubical tensors decimal printing" + [ ("Print vector decimal", property1 $ show vectorDecimal === vectorDecimalExpected) + , ("Print matrix decimal", property1 $ show matrixDecimal === matrixDecimalExpected) + , ("Print matrix decimal 2", property1 $ show matrixDecimal2 === matrixDecimal2Expected) ] + + +treeExample1Expected : String +treeExample1Expected = """ + 60.0 + │ + ├─ 7.0 + │ │ + │ ├─ -42.0 + │ │ + │ └─ 46.0 + │ + └─ 2.0 + """ + +treeExample2Expected : String +treeExample2Expected = """ + 5.0 + │ + ├─ 100.0 + │ + └─ 4.0 + """ + +treeExample3Expected : String +treeExample3Expected = """ + [4.0 1.0] + │ + ├─ [17.0 4.0] + │ │ + │ ├─ · + │ │ + │ └─ · + │ + └─ · + """ + +treeExample4Expected : String +treeExample4Expected = """ + ╔════════════════╗ + ║· ║ + ║│ ║ + ║├─ [1.0 2.0 3.0]║ + ║│ ║ + ║└─ [4.0 5.0 6.0]║ + ╚════════════════╝ + │ + ├─ · + │ + └─ ╔═══════════════════╗ + ║[178.0 -43.0 63.0]║ + ╚═══════════════════╝ + │ + ├─ · + │ + └─ · + """ + +treeExample5Expected : String +treeExample5Expected = """ + · + │ + ├─ · + │ │ + │ ├─ [ 1.0 -1.0] + │ │ + │ └─ · + │ │ + │ ├─ [0.5 1.2] + │ │ + │ └─ [ 0.3 -0.2] + │ + └─ [-0.3 1.2] + """ + +treeExample6Expected : String +treeExample6Expected = """ + ╔═════════════════════════════════════════════════════════════════════════╗ + ║[3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 ║ + ║ 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0] ║ + ╚═════════════════════════════════════════════════════════════════════════╝ + │ + ├─ ╔═════════════════════════════════════════════════════════════════════════╗ + │ ║[1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ║ + │ ║ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0] ║ + │ ╚═════════════════════════════════════════════════════════════════════════╝ + │ + └─ ╔═════════════════════════════════════════════════════════════════════════╗ + ║[2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ║ + ║ 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0] ║ + ╚═════════════════════════════════════════════════════════════════════════╝ + """ + +treeExample7Expected : String +treeExample7Expected = """ + · + │ + ├─ · + │ │ + │ ├─ -42.0 + │ │ + │ └─ 46.0 + │ + └─ 2.0 + """ + +treeExample8Expected : String +treeExample8Expected = """ + 60.0 + │ + ├─ 7.0 + │ │ + │ ├─ · + │ │ + │ └─ · + │ + └─ · + """ + +treeTensorsGroup : Group +treeTensorsGroup = MkGroup "Tree tensors printing" + [ ("Print tree example 1", property1 $ show treeExample1 === treeExample1Expected) + , ("Print tree example 2", property1 $ show treeExample2 === treeExample2Expected) + , ("Print tree example 3", property1 $ show treeExample3 === treeExample3Expected) + , ("Print tree example 4", property1 $ show treeExample4 === treeExample4Expected) + , ("Print tree example 5", property1 $ show treeExample5 === treeExample5Expected) + , ("Print tree example 6", property1 $ show treeExample6 === treeExample6Expected) + , ("Print tree example 7", property1 $ show treeExample7 === treeExample7Expected) + , ("Print tree example 8", property1 $ show treeExample8 === treeExample8Expected) ] + + +listExample1Expected : String +listExample1Expected = """ + [[1.0, 2.0, 3.0] + ,[4.0, 5.0, 6.0]] + """ + +listExample2Expected : String +listExample2Expected = """ + [[1.0, 2.0, 3.0] + ,[4.0] + ,[5.0, 6.0, 7.0, 8.0]] + """ + +listExample3Expected : String +listExample3Expected = """ + [[1.0 2.0] + ,[3.0 4.0] + ,[5.0 6.0] + ,[7.0 8.0]] + """ + +listExample4Expected : String +listExample4Expected = """ + [[[1.0, 2.0, 3.0, 4.0] + ,[6.0] + ,[7.0, 8.0, 9.0] ] + + ,[[4.0, 5.0] + ,[9.0, 9.0, 9.0, 9.0, 9.0] + ,[4.0, 2.0, 0.0] + ,[999.0] ]] + """ + +listExample5Expected : String +listExample5Expected = """ + [4.0 + ,╔══════╗ + ║· ║ + ║│ ║ + ║├─ 3.0║ + ║│ ║ + ║└─ 4.0║ + ╚══════╝ + ,╔═════════╗ + ║· ║ + ║│ ║ + ║├─ · ║ + ║│ │ ║ + ║│ ├─ 5.0║ + ║│ │ ║ + ║│ └─ 6.0║ + ║│ ║ + ║└─ 7.0 ║ + ╚═════════╝ + ,╔══════╗ + ║· ║ + ║│ ║ + ║├─ 7.0║ + ║│ ║ + ║└─ 8.0║ + ╚══════╝ ] + """ + +listTensorsGroup : Group +listTensorsGroup = MkGroup "List tensors printing" + [ ("Print list example 1", property1 $ show listExample1 === listExample1Expected) + , ("Print list example 2", property1 $ show listExample2 === listExample2Expected) + , ("Print list example 3", property1 $ show listExample3 === listExample3Expected) + , ("Print list example 4", property1 $ show listExample4 === listExample4Expected) + , ("Print list example 5", property1 $ show listExample5 === listExample5Expected) ] + + +public export +runTests : IO () +runTests = test + [ cubicalTensorGroup + , longTensorsGroup + , cubicalTensorsDecimalGroup + , treeTensorsGroup + , listTensorsGroup ] \ No newline at end of file diff --git a/tests/src/Display2D/Instances.idr b/tests/src/Display2D/Instances.idr new file mode 100644 index 0000000..80938d0 --- /dev/null +++ b/tests/src/Display2D/Instances.idr @@ -0,0 +1,190 @@ +module Display2D.Instances + +import Data.Tensor + +namespace CubicalTensors + public export + tensorVector : Tensor ["l" ~~> 6] Double + tensorVector = arange + + public export + tensorMatrix : Tensor ["j" ~~> 3, "k" ~~> 4] Double + tensorMatrix = reshape $ arange {stop="total" ~~> 12} + + public export + tensor3D : Tensor ["i" ~~> 2, "j" ~~> 3, "k" ~~> 2] Double + tensor3D = reshape $ arange {stop="total" ~~> 12} + + public export + tensor4D : Tensor ["i" ~~> 2, "j" ~~> 3, "k" ~~> 4, "l" ~~> 5] Double + tensor4D = reshape $ arange {stop="total" ~~> 120} + + public export + tensor5D : Tensor ["i" ~~> 2, "j" ~~> 3, "k" ~~> 4, "l" ~~> 5, "m" ~~> 6] Double + tensor5D = reshape $ arange {stop="total" ~~> 720} + + +namespace LongTensors + public export + vectTest : Vect 156 Double + vectTest = replicate 156 3 + + public export + longVector : Tensor ["i" ~~> 156] Double + longVector = ># vectTest + + public export + longVectorReshaped : Tensor ["j" ~~> 2, "k" ~~> 78] Double + longVectorReshaped = reshape longVector + + public export + longVector2 : Tensor ["i" ~~> 300] Double + longVector2 = arange + +namespace CubicalTensorsDecimal + public export + vectorDecimal : Tensor ["i" ~~> 3] Double + vectorDecimal = ># [0.000, 1.23456, 0.0000001] + + public export + matrixDecimal : Tensor ["j" ~~> 2, "k" ~~> 5] Double + matrixDecimal = ># [ [ 0.01, 0.001, 0.0001, 0.00001, 0.000001] + , [ 10, 20, 30, 40, 50] ] + + public export + matrixDecimal2 : Tensor ["j" ~~> 2, "k" ~~> 5] Double + matrixDecimal2 = ># [ [1000, 10000, 100000, 1000000, 10000000] + , [40, 50, 60, 70, 80]] + + +namespace TreeTensors + public export + treeExample1 : Tensor ["myTree" ~> BinTree] Double + treeExample1 = ># Node 60 (Node 7 (Leaf (-42)) (Leaf 46)) (Leaf 2) + + public export + treeExample2 : Tensor ["myTree" ~> BinTree] Double + treeExample2 = ># Node 5 (Leaf 100) (Leaf 4) + + public export + treeExample3 : Tensor ["myTree" ~> BinTreeNode, "j" ~> Vect 2] Double + treeExample3 = ># Node [4,1] (Node [17, 4] Leaf' Leaf') Leaf' + + public export + treeExample4 : Tensor ["myTree" ~> BinTreeNode, + "myTreeLeaf" ~> BinTreeLeaf, + "k" ~> Vect 3] Double + treeExample4 = ># + Node (Node' + (Leaf [1,2,3]) + (Leaf [4,5,6])) + Leaf' + (Node (Leaf [178, -43, 63]) Leaf' Leaf') + + public export + treeExample5 : Tensor ["myTree" ~> BinTreeLeaf, "v" ~~> 2] Double + treeExample5 = ># Node' (Node' (Leaf [1, -1]) + (Node' (Leaf [0.5, 1.2]) + (Leaf [0.3, -0.2]))) + (Leaf [-0.3, 1.2]) + + public export + treeExample6 : Tensor ["myTree" ~> BinTree, "j" ~~> 300] Double + treeExample6 = ># Node (replicate 300 3) (Leaf $ replicate 300 1) (Leaf $ replicate 300 2) + + public export + treeExample7 : Tensor ["myTree" ~> BinTreeLeaf] Double + treeExample7 = ># Node' (Node' (Leaf (-42)) (Leaf 46)) (Leaf 2) + + public export + treeExample8 : Tensor ["myTree" ~> BinTreeNode] Double + treeExample8 = ># Node 60 (Node 7 Leaf' Leaf') Leaf' + + +namespace ListTensors + public export + listExample1 : Tensor ["i" ~> List, "i" ~> List] Double + listExample1 = ># [ [1,2,3] + , [4,5,6] ] + + public export + listExample2 : Tensor ["i" ~~> 3, "j" ~> List] Double + listExample2 = ># [ [1,2,3] + , [4] + , [5,6,7,8] ] + + public export + listExample3 : Tensor ["i" ~> List, "j" ~~> 2] Double + listExample3 = ># [ [1,2] + , [3,4] + , [5,6] + , [7,8] ] + + + public export + listExample4 : Tensor ["i" ~> List, "j" ~> List, "k" ~> List] Double + listExample4 = ># [ [[1,2,3,4], [6], [7,8,9]] + , [[4,5], [9,9,9,9,9], [4,2,0], [999]]] + + public export + listExample5 : Tensor ["i" ~> List, "j" ~> BinTreeLeaf] Double + listExample5 = ># [ Leaf 4 + , Node' (Leaf 3) (Leaf 4) + , Node' (Node' (Leaf 5) (Leaf 6)) (Leaf 7) + , Node' (Leaf 7) (Leaf 8) ] + + +separator : String +separator = "-------------------------------------------------------" + + +public export +printAllTestInstances : IO () +printAllTestInstances = do + printLn tensorVector + putStrLn separator + printLn tensorMatrix + putStrLn separator + printLn tensor3D + putStrLn separator + printLn tensor4D + putStrLn separator + printLn tensor5D + putStrLn separator + printLn longVector + putStrLn separator + printLn longVectorReshaped + putStrLn separator + printLn longVector2 + putStrLn separator + printLn vectorDecimal + putStrLn separator + printLn matrixDecimal + putStrLn separator + printLn matrixDecimal2 + putStrLn separator + printLn treeExample1 + putStrLn separator + printLn treeExample2 + putStrLn separator + printLn treeExample3 + putStrLn separator + printLn treeExample4 + putStrLn separator + printLn treeExample5 + putStrLn separator + printLn treeExample6 + putStrLn separator + printLn treeExample7 + putStrLn separator + printLn treeExample8 + putStrLn separator + printLn listExample1 + putStrLn separator + printLn listExample2 + putStrLn separator + printLn listExample3 + putStrLn separator + printLn listExample4 + putStrLn separator + printLn listExample5 diff --git a/tests/src/Main.idr b/tests/src/Main.idr new file mode 100644 index 0000000..89c855e --- /dev/null +++ b/tests/src/Main.idr @@ -0,0 +1,9 @@ +module Main + +import Hedgehog +import Display2D.Instances +import Display2D.Expected + +public export +main : IO () +main = runTests \ No newline at end of file diff --git a/tests/tensortype-tests.ipkg b/tests/tensortype-tests.ipkg index a810238..e65e64a 100644 --- a/tests/tensortype-tests.ipkg +++ b/tests/tensortype-tests.ipkg @@ -14,18 +14,20 @@ authors = "Bruno Gavranovic" -- packages to add to search path depends = tensortype + , hedgehog -- modules to install -modules = Tensor.Basic, - Tensor.Interfaces +modules = Display2D.Instances + , Display2D.Expected + , Main -- main file (i.e. file to load at REPL) --- main = +main = Main -- name of executable --- executable = +executable = "tensortype-tests" -- opts = --- sourcedir = "src +sourcedir = "src" -- builddir = -- outputdir =