Statically typed Vector and Matrix algebra
2012-01-12
I took a course on Machine Learning a while back, just for fun. Stanford University seems to have really put some effort into free courses on the Internet, since Artificial Intelligence and Introduction to Databases were also time well spent.
The exercises in Octave reminded me of Matlab and the pain I had to endure while studying at Tampere University of Technology. It would be so much easier if the editor actually complained when trying to multiply matrices of incompatible dimensions, for example. Could Scala perhaps be used to provide some static type safety for matrix operations?
Well, you guessed it.
Creating a linear algebra library wouldn't be the most exciting project (well, not this time, anyway...) so I decided to make a thin wrapper for Scalala. I wasn't aiming for a full-featured library but rather a proof of concept, so I only implemented a few operations.
I'm not smart enough to come up with the required type algebra, so I shamelessly copied the hard parts from here. Hopefully some day I'm going to understand all those lines that went almost unmodified through my clipboard...
Haskell is another great language and can provide more or less similar type safety. Since I don't speak Haskell that well, I'll let you read about a Haskell implementation here.
I made my code available on GitHub. Feel free to use it as you will. The name net.lahteenmaki.scalam is due to my total lack of imagination, sorry about that.
Here's a demo. First of all, only a single import is needed to use all the functionality:
scala> import net.lahteenmaki.scalam._
import net.lahteenmaki.scalam._
Create some regular vectors containing integers:
scala> val v2 = Vector(1,2)
v2: net.lahteenmaki.scalam.RowVector[Int,D2] = 1 2
scala> val v3 = Vector(1,2,3)
v3: net.lahteenmaki.scalam.RowVector[Int,D3] = 1 2 3
or doubles. Actually anything that has a scalala.scalar.Scalar[T] will do:
scala> Vector(1.0,2.0)
res1: net.lahteenmaki.scalam.RowVector[Double,D2] = 1.00000 2.00000
Trying to create a vector with differing element types gives a compiler error:
scala> Vector(1,2.0)
<console>:11: error: T is not a scalar value
       Vector(1,2.0)
             ^
Transposing a row vector creates a column vector of the same dimension:
scala> v2.T
res3: net.lahteenmaki.scalam.ColumnVector[Int,D2] =
1
2
I included some implicits to create vectors from tuples:
scala> (1,2).T
res4: net.lahteenmaki.scalam.ColumnVector[Int,D2] =
1
2
There's nothing special about scalar multiplication, except that the element type changes the same way it does in Scalala:
scala> v2*2
res5: net.lahteenmaki.scalam.RowVector[Int,D2] = 2 4
scala> v2*2.0
res6: net.lahteenmaki.scalam.RowVector[Double,D2] = 2.00000 4.00000
Addition should retain the dimension and only be allowed between vectors of the same dimension:
scala> v2 + v2
res7: net.lahteenmaki.scalam.RowVector[Int,D2] = 2 4
scala> Vector(1,2) + Vector(1.0,2.0)
res8: net.lahteenmaki.scalam.RowVector[Double,Succ[Succ[D0]]] = 2.00000 4.00000
scala> v2 + v3
<console>:13: error: overloaded method value + with alternatives:
  [B](other: net.lahteenmaki.scalam.RowVector[B,D2])
     (implicit o: v2.BinOp[B,scalala.operators.OpAdd])
     net.lahteenmaki.scalam.RowVector[B,D2]
  <and>
  [B](other: net.lahteenmaki.scalam.Matrix[B,D1,D2])
     (implicit o: v2.BinOp[B,scalala.operators.OpAdd])
     net.lahteenmaki.scalam.Matrix[B,D1,D2]
 cannot be applied to (net.lahteenmaki.scalam.RowVector[Int,D3])
       v2 + v3
          ^
Yes, we did get a compile time error. Splendid.
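The idea behind that compile error can be shown without Scalala at all. Here is a minimal sketch, assuming Peano-style dimension types named like the D0 and Succ that show up in the output above. RowVec and its members are hypothetical names made up for this illustration; this is not the actual scalam code, which also has to deal with element types and Scalala's operators.

```scala
object AdditionSketch {
  // Type-level naturals: D0 is zero and Succ[N] is N + 1.
  sealed trait Nat
  sealed trait D0 extends Nat
  sealed trait Succ[N <: Nat] extends Nat
  type D1 = Succ[D0]
  type D2 = Succ[D1]
  type D3 = Succ[D2]

  // Hypothetical row vector: N is a phantom type that records the length.
  // The private constructor keeps N in sync with the actual data.
  final class RowVec[N <: Nat] private (val values: List[Int]) {
    // The argument must carry the same N, so lengths are checked at compile time.
    def +(other: RowVec[N]): RowVec[N] =
      new RowVec[N](values.zip(other.values).map { case (a, b) => a + b })
  }

  object RowVec {
    def apply(a: Int, b: Int): RowVec[D2] = new RowVec[D2](List(a, b))
    def apply(a: Int, b: Int, c: Int): RowVec[D3] = new RowVec[D3](List(a, b, c))
  }

  val v2: RowVec[D2] = RowVec(1, 2)
  val v3: RowVec[D3] = RowVec(1, 2, 3)
  val sum: RowVec[D2] = v2 + v2 // fine: both sides are RowVec[D2]
  // val bad = v2 + v3          // rejected: a RowVec[D3] is not a RowVec[D2]
}
```

Since the dimension only exists as a type parameter, the check costs nothing at runtime.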
Vector multiplication is also only defined for compatible sizes:
scala> v2 * v2.T
res10: net.lahteenmaki.scalam.Matrix[Int,D1,D1] = 5
scala> v2 * v2
<console>:12: error: Could not find a way to values of type
net.lahteenmaki.scalam.RowVector[Int,D2] and scalala.operators.OpMulMatrixBy
       v2 * v2
          ^
scala> v2 * v3
<console>:13: error: Could not find a way to values of type
net.lahteenmaki.scalam.RowVector[Int,D3] and scalala.operators.OpMulMatrixBy
       v2 * v3
          ^
Again, the compiler won't let me multiply a row vector with another one. Nice.
How about concatenating vectors?
scala> v2 ++ v3
res13: net.lahteenmaki.scalam.RowVector[Int,Add[D2,D3]] = 1 2 1 2 3
scala> val v: RowVector[Int,D5] = v2 ++ v3
v: net.lahteenmaki.scalam.RowVector[Int,D5] = 1 2 1 2 3
scala> v2 ++ v2.T
<console>:12: error: type mismatch;
 found   : net.lahteenmaki.scalam.ColumnVector[Int,D2]
 required: net.lahteenmaki.scalam.Matrix[Int,D1,?]
       v2 ++ v2.T
                ^
The compiler can deduce the dimension of the result, and won't let me concatenate a row vector with a column vector. Just what I wanted.
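For the curious, here is one way type-level addition can be done. Note that this sketch uses a different technique than the Add[...] type seen in the output above: the sum is computed by implicit search, which yields a plain Succ chain. Sum, Sum.Aux and RowVec are hypothetical names for this illustration only.

```scala
object ConcatSketch {
  sealed trait Nat
  sealed trait D0 extends Nat
  sealed trait Succ[N <: Nat] extends Nat
  type D1 = Succ[D0]
  type D2 = Succ[D1]
  type D3 = Succ[D2]
  type D4 = Succ[D3]
  type D5 = Succ[D4]

  // Evidence that A + B = Out, constructed by the compiler via implicit search.
  trait Sum[A <: Nat, B <: Nat] { type Out <: Nat }
  object Sum {
    type Aux[A <: Nat, B <: Nat, C <: Nat] = Sum[A, B] { type Out = C }

    // 0 + B = B
    implicit def zero[B <: Nat]: Aux[D0, B, B] =
      new Sum[D0, B] { type Out = B }

    // (A + 1) + B = (A + B) + 1
    implicit def succ[A <: Nat, B <: Nat, C <: Nat](
        implicit s: Aux[A, B, C]): Aux[Succ[A], B, Succ[C]] =
      new Sum[Succ[A], B] { type Out = Succ[C] }
  }

  final class RowVec[N <: Nat] private (val values: List[Int]) {
    // The result dimension R is whatever the compiler can prove N + M to be.
    def ++[M <: Nat, R <: Nat](other: RowVec[M])(
        implicit sum: Sum.Aux[N, M, R]): RowVec[R] =
      new RowVec[R](values ++ other.values)
  }

  object RowVec {
    def apply(a: Int, b: Int): RowVec[D2] = new RowVec[D2](List(a, b))
    def apply(a: Int, b: Int, c: Int): RowVec[D3] = new RowVec[D3](List(a, b, c))
  }

  val v2: RowVec[D2] = RowVec(1, 2)
  val v3: RowVec[D3] = RowVec(1, 2, 3)
  // Type checks only because the compiler can derive D2 + D3 = D5.
  val v5: RowVec[D5] = v2 ++ v3
}
```

The price of this approach is one implicit lookup per Succ, so the compiler does work proportional to the size of the left operand.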
Then the classic over-indexing case:
scala> v2[D1]
res15: Int = 1
scala> v2[D2]
res16: Int = 2
scala> v2[D3]
<console>:12: error: Cannot prove that
D3#Compare[D2]#Match[True,True,False,Bool] =:= True.
       v2[D3]
         ^
Spectacular. The compiler won't let me get an element n+1 from an n-dimensional vector.
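A bounds check like this boils down to asking the compiler for proof that the index fits. The sketch below does it with an implicit evidence value instead of the Compare/Match type members seen in the error message above, so it is an alternative technique, not the one scalam uses. InRange, ToInt, RowVec and the accessor name at are all hypothetical.

```scala
object IndexSketch {
  sealed trait Nat
  sealed trait D0 extends Nat
  sealed trait Succ[N <: Nat] extends Nat
  type D1 = Succ[D0]
  type D2 = Succ[D1]
  type D3 = Succ[D2]

  // Turns a type-level number into a runtime Int.
  trait ToInt[N <: Nat] { def value: Int }
  object ToInt {
    implicit val zero: ToInt[D0] = new ToInt[D0] { val value = 0 }
    implicit def succ[N <: Nat](implicit n: ToInt[N]): ToInt[Succ[N]] =
      new ToInt[Succ[N]] { val value = n.value + 1 }
  }

  // Evidence that 1 <= I <= N. Instances exist only for valid pairs.
  final class InRange[I <: Nat, N <: Nat]
  object InRange {
    // Index 1 is valid for any non-empty vector.
    implicit def base[N <: Nat]: InRange[D1, Succ[N]] =
      new InRange[D1, Succ[N]]
    // If index I + 1 is valid for N, then I + 2 is valid for N + 1.
    implicit def step[I <: Nat, N <: Nat](
        implicit ev: InRange[Succ[I], N]): InRange[Succ[Succ[I]], Succ[N]] =
      new InRange[Succ[Succ[I]], Succ[N]]
  }

  final class RowVec[N <: Nat] private (val values: List[Int]) {
    // One-based access, compiled only if I is provably within 1..N.
    def at[I <: Nat](implicit ok: InRange[I, N], i: ToInt[I]): Int =
      values(i.value - 1)
  }

  object RowVec {
    def apply(a: Int, b: Int): RowVec[D2] = new RowVec[D2](List(a, b))
  }

  val v2: RowVec[D2] = RowVec(1, 2)
  val elem1: Int = v2.at[D1]
  val elem2: Int = v2.at[D2]
  // v2.at[D3] does not compile: there is no InRange[D3, D2] to be found
}
```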
The same operations can be implemented for matrices, along with some helper methods for constructing simple matrices:
scala> val m22 = Matrix.ones[Int,D2]
m22: net.lahteenmaki.scalam.Matrix[Int,D2,D2] =
1 1
1 1
scala> val m23 = Matrix.ones[Int,D2,D3]
m23: net.lahteenmaki.scalam.Matrix[Int,D2,D3] =
1 1 1
1 1 1
scala> Matrix.zeros[Double,D2]
res18: net.lahteenmaki.scalam.Matrix[Double,D2,D2] =
0.00000 0.00000
0.00000 0.00000
scala> Matrix.rand[D5,D5]
res19: net.lahteenmaki.scalam.Matrix[Int,D5,D5] =
8 6 10 2 2
3 2 11 1 15
10 1 18 9 5
11 5 8 10 18
0 17 2 12 24
scala> m22.T
res20: net.lahteenmaki.scalam.Matrix[Int,D2,D2] =
1 1
1 1
scala> m22 + m22
res21: net.lahteenmaki.scalam.Matrix[Int,D2,D2] =
2 2
2 2
scala> m22 + m23
<console>:13: error: type mismatch;
 found   : net.lahteenmaki.scalam.Matrix[Int,D2,D3]
 required: net.lahteenmaki.scalam.Matrix[?,D2,D2]
       m22 + m23
             ^
scala> m22 * 5.5
res23: net.lahteenmaki.scalam.Matrix[Double,D2,D2] =
5.50000 5.50000
5.50000 5.50000
scala> m22 * m23
res24: net.lahteenmaki.scalam.Matrix[Int,D2,D3] =
2 2 2
2 2 2
scala> m22 * v2
<console>:13: error: Could not find a way to values of type
net.lahteenmaki.scalam.RowVector[Int,D2] and scalala.operators.OpMulMatrixBy
       m22 * v2
           ^
scala> v3 * Matrix.rand[D1,D5]
<console>:12: error: Could not find a way to values of type
net.lahteenmaki.scalam.Matrix[Int,D1,D5] and scalala.operators.OpMulMatrixBy
       v3 * Matrix.rand[D1,D5]
          ^
scala> m23 * m22
<console>:13: error: Could not find a way to values of type
net.lahteenmaki.scalam.Matrix[Int,D2,D2] and scalala.operators.OpMulMatrixBy
       m23 * m22
           ^
scala> m23[D1,D1]
res28: Int = 1
scala> m23[D2,D3]
res29: Int = 1
scala> m23[D3,D3]
<console>:12: error: Cannot prove that
D3#Compare[D2]#Match[True,True,False,Bool] =:= True.
       m23[D3,D3]
          ^
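The multiplication rule itself (an R x C matrix times a C x K matrix gives an R x K matrix) fits directly into a method signature. Below is a minimal sketch under the same assumptions as before: Peano-style dimension types, and a hypothetical Mat class backed by plain lists. The real library delegates the arithmetic to Scalala instead.

```scala
object MatrixSketch {
  sealed trait Nat
  sealed trait D0 extends Nat
  sealed trait Succ[N <: Nat] extends Nat
  type D1 = Succ[D0]
  type D2 = Succ[D1]
  type D3 = Succ[D2]

  // Turns a type-level number into a runtime Int.
  trait ToInt[N <: Nat] { def value: Int }
  object ToInt {
    implicit val zero: ToInt[D0] = new ToInt[D0] { val value = 0 }
    implicit def succ[N <: Nat](implicit n: ToInt[N]): ToInt[Succ[N]] =
      new ToInt[Succ[N]] { val value = n.value + 1 }
  }

  // R rows and C columns, both tracked only in the type.
  final class Mat[R <: Nat, C <: Nat] private (val rows: List[List[Int]]) {
    // The argument's row count must equal this matrix's column count C.
    def *[K <: Nat](other: Mat[C, K]): Mat[R, K] = {
      val cols = other.rows.transpose
      new Mat[R, K](rows.map { r =>
        cols.map(c => r.zip(c).map { case (a, b) => a * b }.sum)
      })
    }
  }

  object Mat {
    // The runtime sizes are derived from the type-level dimensions.
    def ones[R <: Nat, C <: Nat](implicit r: ToInt[R], c: ToInt[C]): Mat[R, C] =
      new Mat[R, C](List.fill(r.value, c.value)(1))
  }

  val m22: Mat[D2, D2] = Mat.ones[D2, D2]
  val m23: Mat[D2, D3] = Mat.ones[D2, D3]
  val product: Mat[D2, D3] = m22 * m23 // 2x2 times 2x3 gives 2x3
  // val bad = m23 * m22               // rejected: m22 has 2 rows, m23 has 3 columns
}
```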
Everything works for small vectors and matrices, but how about bigger ones? I actually only declared dimensions from D1 to D22, but one could always declare more, or probably generate them:
scala> val v7 = Vector(1,2,3,4,5,6,7)
v7: net.lahteenmaki.scalam.RowVector[Int,D7] = 1 2 3 4 5 6 7
scala> val v21 = v7 ++ v7 ++ v7
v21: net.lahteenmaki.scalam.RowVector[Int,Add[Add[D7,D7],D7]] =
1 2 3 4 5 6 7 1 2 3 4 5 6 7 1 2 3 4 5 6 7
scala> val v23 = v21 ++ Vector(22,23)
v23: net.lahteenmaki.scalam.RowVector[Int,Add[Add[Add[D7,D7],D7],D2]] =
1 2 3 4 5 6 7 1 2 3 4 5 6 7 1 2 3 4 5 6 7 22 23
scala> v23[D23]
<console>:14: error: not found: type D23
       v23[D23]
           ^
<console>:14: error: Cannot prove that
(Add[Add[Add[D7,D7],D7],D2],)#Match[True,True,False,Bool] =:= True.
       v23[D23]
          ^
scala> type D23 = Succ[D22]
defined type alias D23
scala> v23[D23]
res32: Int = 23
So, this is nice. Almost too good to be true?
There are some issues, of course. You probably noticed already in the beginning that the produced error messages aren't exactly helpful for an average programmer. This might be improved if Scala introduced more features like @implicitNotFound that could be used to provide the compiler with custom error messages.
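To show what @implicitNotFound does, here is a small sketch. It assumes the bounds check is expressed as an implicit evidence class, which is not how the "Cannot prove that ..." errors above are produced, so it illustrates the annotation rather than fixing scalam. InRange and checked are hypothetical names; scala.annotation.implicitNotFound is the standard library annotation.

```scala
import scala.annotation.implicitNotFound

object ErrorMessageSketch {
  sealed trait Nat
  sealed trait D0 extends Nat
  sealed trait Succ[N <: Nat] extends Nat
  type D1 = Succ[D0]
  type D2 = Succ[D1]
  type D3 = Succ[D2]

  // When no InRange[I, N] can be found, the compiler reports this message
  // with the actual types substituted for ${I} and ${N}.
  @implicitNotFound("Index ${I} is out of range for dimension ${N}")
  final class InRange[I <: Nat, N <: Nat]

  object InRange {
    // Index 1 is valid for any non-empty dimension.
    implicit def base[N <: Nat]: InRange[D1, Succ[N]] =
      new InRange[D1, Succ[N]]
    // If index I + 1 is valid for N, then I + 2 is valid for N + 1.
    implicit def step[I <: Nat, N <: Nat](
        implicit ev: InRange[Succ[I], N]): InRange[Succ[Succ[I]], Succ[N]] =
      new InRange[Succ[Succ[I]], Succ[N]]
  }

  // Compiles only if the evidence exists for the given index and dimension.
  def checked[I <: Nat, N <: Nat](implicit ev: InRange[I, N]): Boolean = true

  val ok: Boolean = checked[D2, D2]
  // checked[D3, D2] fails to compile and reports the annotation's message
}
```

The types in the message are still printed as the compiler sees them, so long Succ chains would remain hard on the eyes.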
Also, in cases where the dimension changes, the compiler does not reduce the resulting dimension to a simple form, but instead gives out the cryptic Add[Add[...]] signatures, which need to be manually annotated with "readable" signatures if needed. This might be just an issue with my implementation, though, I don't know.
Perhaps the biggest problem might turn out to be performance. Compiling Scala is already a heavy job, and handling types for a 10000x10000 matrix might just be beyond any possible compiler optimizations.