Merged
248 changes: 169 additions & 79 deletions README.md

Large diffs are not rendered by default.

60 changes: 35 additions & 25 deletions examples/Attention.idr
@@ -9,50 +9,60 @@ Attention example
Will run self attention as usual, on matrices, and then on trees
-------------------------------------------------------------------------------}

||| We start by dealing with ordinary matrices, and define the matrix axes

||| The length of the input sequence
SeqLen : Axis
SeqLen = "seqLen" ~~> 3

||| The number of tokens for each element in the sequence
NumTokens : Axis
NumTokens = "numTokens" ~~> 4

||| We'll first instantiate self attention as a parametric map on matrices
SelfAttentionMat : {n, d : Nat} ->
{default False causalMask : Bool} ->
Tensor ["seqLen" ~~> n, "numTokens" ~~> d] Double -\->
Tensor ["seqLen" ~~> n, "numTokens" ~~> d] Double
SelfAttentionMat : {default False causalMask : Bool} ->
Tensor [SeqLen, NumTokens] Double -\->
Tensor [SeqLen, NumTokens] Double
SelfAttentionMat {causalMask} = case causalMask of
False => SelfAttention softargmaxImpl
True => SelfAttention {causalMask=Attention.causalMask} softargmaxImpl

||| Let's fix a simple input matrix
inputMatrix : Tensor ["seqLen" ~~> 3, "numTokens" ~~> 2] Double
inputMatrix = ># [ [1, 3]
, [2, -3]
, [0, 0.3]]
inputMatrix : Tensor [SeqLen, NumTokens] Double
inputMatrix = ># [ [1, 3, 3, 2]
, [2, -3, 2, 1]
, [0, 0.3, 10, 9]]

||| Let's fix attention parameters for the query, key and value matrices.
||| For instance, a matrix of ones, a triangular matrix, and a matrix of threes
params : {d : Nat} -> SelfAttentionParams ("numTokens" ~~> d) {a=Double}
params : SelfAttentionParams NumTokens {a=Double}
params = MkSAParams ones tri (ones <&> (*3))

||| Now we can run self attention on the input matrix
||| This value can be inspected in REPL, or otherwise
outputMatrix : Tensor ["seqLen" ~~> 3, "numTokens" ~~> 2] Double
||| This value can be inspected in the REPL via `:exec printLn outputMatrix`
outputMatrix : Tensor [SeqLen, NumTokens] Double
outputMatrix = Run (SelfAttentionMat {causalMask=True}) inputMatrix params


||| Now we'll instantiate self attention as a parametric map on trees and use
||| container tensors for this. Here we'll study attention where the input
||| structure isn't a sequence, but a tree, but we'll keep the feature structure
||| as a sequence
||| That is, instead of `CTensor [Vect n, Vect d] Double`
||| we'll have `CTensor [BinTreeLeaf, Vect d] Double`
SelfAttentionTree : {d : Nat} ->
Tensor ["inputStructure" ~> BinTreeLeaf, "numTokens" ~> Vect d] Double -\->
Tensor ["inputStructure" ~> BinTreeLeaf, "numTokens" ~> Vect d] Double
||| Now we instantiate self attention operating not on a vector of vectors
||| (i.e. a matrix), but on a *tree* of vectors. We'll first define the axis
InputStructure : Axis
InputStructure = "inputStructure" ~> BinTreeLeaf

||| We keep features as a vector, and define the self-attention map
||| Notably, we do not use any causal mask
SelfAttentionTree : Tensor [InputStructure, NumTokens] Double -\->
Tensor [InputStructure, NumTokens] Double
SelfAttentionTree = SelfAttention softargmaxImpl

||| We fix a simple input tree
||| Notably, the set of parameters can be the same as the one for matrices
inputTree : Tensor ["inputStructure" ~> BinTreeLeaf, "numTokens" ~> Vect 2] Double
inputTree = ># Node' (Node' (Leaf [1, -1])
(Leaf [0.5, 1.2]))
(Leaf [-0.3, 1.2])
inputTree : Tensor [InputStructure, NumTokens] Double
inputTree = ># Node' (Node' (Leaf [1, -1, 3, 2])
(Node' (Leaf [0.5, 1.2, 2, 1])
(Leaf [0.3, 4, 10, 9])))
(Leaf [-0.3, 1.2, -13, -0.3])

||| We can run self attention on the tree, and inspect the result
outputTree : Tensor ["inputStructure" ~> BinTreeLeaf, "numTokens" ~> Vect 2] Double
outputTree : Tensor [InputStructure, NumTokens] Double
outputTree = Run SelfAttentionTree inputTree params
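For reference, the standard formulation of self attention that these examples instantiate can be written as follows. This is a sketch of the textbook definition, not a statement about how `SelfAttention` is implemented in this repository: the scaling by $\sqrt{d}$ and the additive form of the mask are assumptions. Here $W_Q$, $W_K$, $W_V$ play the role of the query, key and value matrices passed to `MkSAParams`, and softargmax is applied along each row.

```latex
\[
Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V,
\]
\[
\mathrm{SelfAttention}(X) = \mathrm{softargmax}\!\left(\frac{Q K^{\top}}{\sqrt{d}} + M\right) V,
\qquad
M_{ij} =
\begin{cases}
-\infty & \text{if a causal mask is used and } j > i,\\
0 & \text{otherwise.}
\end{cases}
\]
```

Without a causal mask, $M = 0$ and every position attends to every other position, which is the case used for the tree example.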
61 changes: 40 additions & 21 deletions examples/BasicExamples.idr
@@ -8,23 +8,28 @@ import Data.Tensor
-- Examples of standard, cubical tensors
----------------------------------------

||| Now you can construct Tensors directly:
t0 : Tensor ["j" ~~> 3, "k" ~~> 4] Double
t0 = ># [ [0, 1, 2, 3]
, [4, 5, 6, 7]
, [8, 9, 10, 11]]
||| This is a vector of size `5`, with its axis named "i".
myVector : Tensor ["i" ~~> 5] Double
myVector = ># [1,2,3,4,5]

||| This is a matrix of size `3x4`, with its axes named "j" and "k".
myMatrix : Tensor ["j" ~~> 3, "k" ~~> 4] Double
myMatrix = ># [ [0, 1, 2, 3]
, [4, 5, 6, 7]
, [8, 9, 10, 11]]

{--------------------
Here `>#` behaves like a constructor: it takes a concrete value and turns it into a tensor of the appropriate shape. (It should be read visually as a 'map' (`>`) into a 'tensor' (`#`).)

You can also use functions analogous to numpy's, such as `np.arange` and `np.reshape`:
--------------------}
t1 : Tensor ["i" ~~> 6] Double
t1 : Tensor ["l" ~~> 6] Double
t1 = arange

t2 : Tensor ["i" ~~> 2, "j" ~~> 3] Double
t2 = reshape t1


{-
where the difference from numpy is that these operations are typechecked,
meaning they fail at compile time if you supply an array with the wrong shape.
@@ -49,20 +54,20 @@ failing

||| You can perform all sorts of familiar numeric operations:
exampleSum : Tensor ["j" ~~> 3, "k" ~~> 4] Double
exampleSum = t0 + t0
exampleSum = myMatrix + myMatrix

exampleOp : Tensor ["j" ~~> 3, "k" ~~> 4] Double
exampleOp = abs (- (t0 * t0) <&> (+7))
exampleOp = abs (- (myMatrix * myMatrix) <&> (+7))

||| including standard linear algebra
dotExample : Tensor [] Double
dotExample = dot t1 (t1 <&> (+5))

matMulExample : Tensor ["i" ~~> 2, "k" ~~> 4] Double
matMulExample = matMul t2 t0
matMulExample = matMul t2 myMatrix

transposeExample : Tensor ["k" ~~> 4, "j" ~~> 3] Double
transposeExample = transposeMatrix t0
transposeExample = transposeMatrix myMatrix

{--------------------
which all have their types checked at compile-time. For instance, you can't
@@ -71,25 +76,25 @@ dimensions of matrices don't match.
--------------------}
failing
sumFail : Tensor ["j" ~~> 3, "k" ~~> 4] Double
sumFail = t0 + t1
sumFail = myMatrix + t1

failing
matMulFail : Tensor ["i" ~~> 7] Double
matMulFail = matMul t0 t1
matMulFail = matMul myMatrix t1
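
The same compile-time behaviour can be reproduced without this library. The sketch below uses only `Data.Vect` from Idris 2's base package, so it says nothing about how `Tensor` is implemented; `addV`, `okSum` and `badSum` are hypothetical names introduced for illustration.

```idris
import Data.Vect

-- Elementwise addition: both arguments must share the same length n
addV : Vect n Double -> Vect n Double -> Vect n Double
addV = zipWith (+)

-- Lengths agree (3 and 3), so this typechecks
okSum : Vect 3 Double
okSum = addV [1, 2, 3] [4, 5, 6]

-- Lengths disagree (3 and 2); the `failing` block asserts that this
-- definition is rejected by the typechecker
failing
  badSum : Vect 3 Double
  badSum = addV [1, 2, 3] [4, 5]
```

As in the examples above, the `failing` block only compiles if its contents do not, which turns a shape mismatch into a checked expectation.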

||| Like in numpy, you can safely index into tensors, set values of tensors, and perform slicing:
||| This retrieves the value of `myMatrix` at location [1, 2]
indexExample : Double
indexExample = t0 @@ [1, 2]
indexExample = myMatrix @@ [1, 2]

-- TODO needs to be fixed
-- ||| Sets the value of t0 at location [1, 3] to 99
-- ||| Sets the value of `myMatrix` at location [1, 3] to 99
-- setExample : Tensor [3, 4]
-- setExample = set t0 [1, 3] 99
-- setExample = set myMatrix [1, 3] 99

-- ||| Takes the first two rows, and 1st column of `myMatrix`
-- sliceExample : Tensor ["j" ~~> 2, "k" ~~> 1] Double
-- sliceExample = take [2, 1] t0
-- sliceExample = take [2, 1] myMatrix

-- Which will all fail if you go out of bounds
failing
@@ -98,16 +103,16 @@ failing

failing
sliceFail : Tensor ["j" ~~> 10, "k" ~~> 2] Double
sliceFail = take [10, 2] t0
sliceFail = take [10, 2] myMatrix

{---------------------------------------
**And most importantly, you can do all of this with *non-cubical* tensors.** These describe tensors whose shape isn't rectangular/cubical, but can be branching/recursive/higher-order.
That is, instead of binding an axis name to a number, we bind it to something called a "container", by using `~>` instead of `~~>`.
As a matter of fact, `~~>` behind the scenes desugars to `~>`, and we have been using this all along.
Let's see `t0` in this new form:
Let's see `myMatrix` in this new form:
---------------------------------------}
t0Again : Tensor ["j" ~> Vect 3, "k" ~> Vect 4] Double
t0Again = t0
myMatrixAgain : Tensor ["j" ~> Vect 3, "k" ~> Vect 4] Double
myMatrixAgain = myMatrix

{--------------------
Here `Vect` does not refer to `Vect` from `Data.Vect`, but rather the `Vect` container implemented [here](https://github.com/bgavran/TensorType/blob/main/src/Data/Container/Object/Instances.idr#L68).
@@ -139,6 +144,13 @@ Unlike `Vect`, this container allows us to store an arbitrary number of elements
treeExample2 : Tensor ["myTree" ~> BinTree] Double
treeExample2 = ># Node 5 (Leaf 100) (Leaf 4)


listEx : List' (Tensor ["myTree" ~> BinTree] Double)
listEx = ># [ treeExample1, treeExample2 ]

vectEx : Vect' 2 (Vect' 3 Double)
vectEx = ># [ ># [1, 2, 3], ># [4, 5, 6.000290940000203] ]

{--------------------
Perhaps surprisingly, all linear algebra operations follow smoothly. The example below is the _dot product of trees_. The fact that these trees don't have the same number of elements is irrelevant; what matters is that the container defining them (`BinTree`) is the same.
--------------------}
@@ -183,6 +195,13 @@ treeExample4 = >#
Leaf'
(Node (Leaf [178, -43, 63]) Leaf' Leaf')


treeExample5 : Tensor ["myTree" ~> BinTreeLeaf, "v" ~~> 2] Double
treeExample5 = ># Node' (Node' (Leaf [1, -1])
(Node' (Leaf [0.5, 1.2])
(Leaf [0.3, -0.2])))
(Leaf [-0.3, 1.2])

{--------------------
For instance, we can index into `treeExample1`:
60
@@ -218,5 +237,5 @@ Here is the in-order traversal of `treeExample1` from above.

You can also use `Utils.Traversals.inorder`.
--------------------}
traversalExample : Tensor ["myTree" ~> List] Double
traversalExample : Tensor ["flattenedTree" ~> List] Double
traversalExample = restructure (wrapIntoVector inorder) treeExample1
8 changes: 5 additions & 3 deletions src/Control/Monad/Sample/Instances.idr
@@ -27,9 +27,11 @@ public export
||| Computes the cumulative distribution, samples randomly, finds the right bin
public export
MonadSample IO where
sample @{ItIsSucc} (MkDist xs) = do
let dist : Tensor ["dist" ~~> i] Double := (softargmaxImpl {i="dist" ~~> i}) (># xs)
cumSum : Tensor ["dist" ~~> i] Double := cumulativeSum dist
sample {i = S j} (MkDist xs) = do
let dist : Tensor ["dist" ~~> S j] Double
dist = softargmaxImpl {i="dist" ~~> S j} (># xs)
cumSum : Tensor ["dist" ~~> S j] Double
cumSum = cumulativeSum dist
r <- randomRIO (0.0, 1.0)
case findBin (#> cumSum) r of
Nothing => pure FZ -- should never happen!
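
The doc comment on this instance describes sampling by cumulative sums and bin lookup. A minimal sketch of that idea on plain `Data.Vect` vectors, independent of this library's `Tensor` API: `cumSumSketch`, `findBinSketch` and `pickBin` are hypothetical names, the probabilities are assumed to be already normalised, and `r` is assumed to be drawn uniformly from [0, 1] elsewhere.

```idris
import Data.Vect
import Data.Fin

-- Inclusive running sum: each entry is the total probability mass so far
cumSumSketch : Vect n Double -> Vect n Double
cumSumSketch xs = go 0 xs
  where
    go : Double -> Vect m Double -> Vect m Double
    go _ [] = []
    go acc (y :: ys) = let acc' = acc + y in acc' :: go acc' ys

-- Index of the first bin whose cumulative mass is at least r, if any
findBinSketch : Vect n Double -> Double -> Maybe (Fin n)
findBinSketch [] _ = Nothing
findBinSketch (c :: cs) r =
  if r <= c then Just FZ else map FS (findBinSketch cs r)

-- Pick an index for a draw r; the non-empty length (S k) makes the
-- fallback FZ available, mirroring the `Nothing` branch above
pickBin : Vect (S k) Double -> Double -> Fin (S k)
pickBin probs r = case findBinSketch (cumSumSketch probs) r of
  Nothing => FZ
  Just i => i
```

The fallback is only reached if rounding leaves the final cumulative value slightly below `r`, which is why the non-empty length is needed in the type.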
8 changes: 7 additions & 1 deletion src/Data/CT/Category/Instances.idr
@@ -3,6 +3,7 @@ module Data.CT.Category.Instances
import Data.CT.Category.Definition
import Data.CT.Functor.Definition

import Data.ComMonoid
import Data.Container.Base
import Data.Container.Additive

@@ -34,4 +35,9 @@ AddDLens = MkCat AddCont (=%>)
||| Category of additive dependent charts
public export
AddDChart : Cat
AddDChart = MkCat AddCont (=&>)
AddDChart = MkCat AddCont (=&>)

||| Category of commutative monoids and commutative monoid homomorphisms
public export
ComMon : Cat
ComMon = MkCat ComMonoid ComMonoidHomo
4 changes: 2 additions & 2 deletions src/Data/CT/DependentAction/Instances.idr
@@ -34,7 +34,7 @@ namespace Cont
public export
DPairCont : DepAct DLens (FamDLens {c=DLens})
DPairCont = MkDepAct $ \c => MkFunctor
(DepHancockProduct c)
(DPairTensor c)
(\r => !% \(x ** p) => ((x ** (r x).fwd p) **
\(x', p') => (x', (r x).bwd p p')))

@@ -49,6 +49,6 @@ namespace AddCont
public export
DPairAddCont : DepAct AddDLens (FamAddDLens {c=AddDLens})
DPairAddCont = MkDepAct $ \c => MkFunctor
(DepHancockProduct c)
(DPair c)
(\r => !%+ \(x ** p) => ((x ** (r x).fwd p) **
\(x', p') => (x', (r x).bwd p p')))
12 changes: 6 additions & 6 deletions src/Data/CT/DependentPara/Instances.idr
@@ -194,7 +194,7 @@ namespace DependentParametricDependentLenses
public export
composePara : a =\\=> b -> b =\\=> c -> a =\\=> c
composePara (MkPara p f) (MkPara q g) = MkPara
(\x => DepHancockProduct (p x) (\ps => q (f.fwd (x ** ps))))
(\x => DPair (p x) (\ps => q (f.fwd (x ** ps))))
(!%+ \(x ** (ps ** qs)) =>
(g.fwd (f.fwd (x ** ps) ** qs) ** \cPos =>
let (bPos, qPos) = g.bwd (f.fwd (x ** ps) ** qs) cPos
@@ -211,12 +211,12 @@ namespace DependentParametricDependentLenses
||| a non-dependent (constant) parameter.
public export
data IsNotDependent : DParaAddDLens a b -> Type where
MkNonDep : (p : AddCont) -> (f : DepHancockProduct a (const p) =%> b) ->
MkNonDep : (p : AddCont) -> (f : DPair a (const p) =%> b) ->
IsNotDependent {a=a} {b=b} (MkPara (\_ => p) f)

public export
GetNonDep : (pf : DParaAddDLens a b) ->
IsNotDependent pf => (pc : AddCont ** DepHancockProduct a (const pc) =%> b)
IsNotDependent pf => (pc : AddCont ** DPair a (const pc) =%> b)
GetNonDep _ @{MkNonDep pc f} = (pc ** f)

public export
@@ -237,11 +237,11 @@ namespace DependentParametricDependentLenses
composeNTimes 1 f = f -- to get rid of the annoying Unit parameter
composeNTimes (S k) f = composePara f (composeNTimes k f)

||| Convert a morphism from product container to one from DepHancockProduct
||| This witnesses the isomorphism (a >< p) ≅ DepHancockProduct a (const p)
||| Convert a morphism from product container to one from DPair
||| This witnesses the isomorphism (a >< p) ≅ DPair a (const p)
public export
fromNonDepProduct : {0 a, p, b : AddCont} ->
(a >< p) =%> b -> DepHancockProduct a (const p) =%> b
(a >< p) =%> b -> DPair a (const p) =%> b
fromNonDepProduct f = !%+ \(x ** p') => (%!) f (x, p')
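
At the level of plain types, the analogous isomorphism says that an ordinary pair `(a, p)` is the same as a dependent pair whose second component ignores the first. A sketch of that analogy using only the Prelude; the names `toDPairSketch` and `fromDPairSketch` are hypothetical, and this is not container code from this repository.

```idris
-- Forget that the pair is non-dependent: the family is constantly p
toDPairSketch : (a, p) -> DPair a (const p)
toDPairSketch (x, y) = (x ** y)

-- Recover the ordinary pair; `const p x` reduces to `p`
fromDPairSketch : DPair a (const p) -> (a, p)
fromDPairSketch (x ** y) = (x, y)
```

The two functions are mutually inverse, which is the type-level shadow of the statement `(a >< p) ≅ DPair a (const p)` in the comment above.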

public export
41 changes: 2 additions & 39 deletions src/Data/CT/Functor/Instances.idr
@@ -87,6 +87,7 @@ namespace Type
TypeDFun fam = (x : a) -> fam x

namespace Cont
||| TODO probably name clash with other "Indexed container"
public export
IndexedCont : Cont -> Type
IndexedCont c = c.Shp -> Cont
@@ -180,42 +181,4 @@ trivialFam = \_ => ((_ : ()) !> Void)
-- Family that assigns Bool shapes with no positions
public export
boolFam : {c : Cont} -> IndexedCont c
boolFam = \_ => ((_ : Bool) !> Void)

--------------------------------------------------------------------------------
-- COMPARISON WITH CHARTS/LENSES
--
-- A chart `c =&> d` is a MORPHISM in Poly, involving both shapes AND positions.
-- A family `c.Shp -> Cont` is just a FUNCTION on shapes.
--
-- Charts give richer structure (relating positions), but families are simpler
-- and sufficient for forming dependent sums and products.
--
-- You CAN extract a family from certain charts:
--------------------------------------------------------------------------------

-- The "universe" container: shapes are types, positions are their elements
public export
ContUniverse : Cont
ContUniverse = (t : Type) !> t

-- From a chart to Universe, extract just the shape-level data as a family
-- (This loses the position-level information from the chart!)
public export
famFromChart : (c : Cont) -> (f : c =&> ContUniverse) -> IndexedCont c
famFromChart c f = \s => ((_ : f.fwd s) !> Void)
-- We get family shapes from f.fwd, but lose the f.bwd data
-- The chart's f.bwd : (s : c.Shp) -> c.Pos s -> f.fwd s
-- doesn't fit into ContFam's structure

--------------------------------------------------------------------------------
-- WHY Poly ISN'T LOCALLY CARTESIAN CLOSED
--
-- In a locally cartesian closed category, for any morphism f : A -> B,
-- the pullback functor f* : C/B -> C/A has a right adjoint Πf.
--
-- In Poly, this fails for general dependent lenses. The right adjoint
-- only exists for CARTESIAN morphisms (where the backward map is an iso).
--
-- Reference: von Glehn thesis, Section 4.3
--------------------------------------------------------------------------------
boolFam = \_ => ((_ : Bool) !> Void)