From 78b998e1823a196dab5fbcff547568a06242fbad Mon Sep 17 00:00:00 2001 From: Simon Parten Date: Thu, 26 Dec 2024 22:09:10 +0100 Subject: [PATCH] . --- build.mill | 9 ++- vecxt/src/dimMatCheck.scala | 8 +-- vecxtensions/js-jvm/src/matmul.scala | 50 +++++++++++++ vecxtensions/test/jvm/src/matrix.test.scala | 80 +++++++++++++++++++++ 4 files changed, 140 insertions(+), 7 deletions(-) create mode 100644 vecxtensions/js-jvm/src/matmul.scala create mode 100644 vecxtensions/test/jvm/src/matrix.test.scala diff --git a/build.mill b/build.mill index 0dbfec2..edd1c9e 100644 --- a/build.mill +++ b/build.mill @@ -102,6 +102,7 @@ object vecxt extends CrossPlatform { object vecxtensions extends CrossPlatform { override def moduleDeps: Seq[CrossPlatform] = Seq(vecxt) trait Shared extends CrossPlatformScalaModule with Common { + // common `core` settings here trait SharedTests extends CommonTests { // common `core` test settings here @@ -110,10 +111,11 @@ object vecxtensions extends CrossPlatform { object jvm extends Shared { override def javacOptions: T[Seq[String]] = super.javacOptions() ++ vecIncubatorFlag def forkArgs = super.forkArgs() ++ vecIncubatorFlag - def ivyDeps = super.ivyDeps() ++ Agg( - + def ivyDeps = Agg( + ivy"org.typelevel::spire::0.18.0" ) + object test extends ScalaTests with SharedTests { def forkArgs = super.forkArgs() ++ vecIncubatorFlag } @@ -121,7 +123,8 @@ object vecxtensions extends CrossPlatform { object js extends Shared with CommonJS { override def ivyDeps: Target[Agg[Dep]] = super.ivyDeps() ++ Agg( ivy"com.lihaoyi::scalatags::0.13.1", - ivy"com.raquo::laminar::17.1.0" + ivy"com.raquo::laminar::17.1.0", + ivy"org.typelevel::spire::0.18.0" ) // js specific settings here object test extends ScalaJSTests with SharedTests { diff --git a/vecxt/src/dimMatCheck.scala b/vecxt/src/dimMatCheck.scala index be99c5b..e787de7 100644 --- a/vecxt/src/dimMatCheck.scala +++ b/vecxt/src/dimMatCheck.scala @@ -6,18 +6,18 @@ import vecxt.matrix.* import narr.* -protected[vecxt] object dimMatCheck: +object dimMatCheck: inline def apply[A](a: Matrix[A], b: Matrix[A])(using inline doCheck: BoundsCheck) = inline if doCheck then if a.cols != b.rows then throw MatrixDimensionMismatch(a.rows, a.cols, b.rows, b.cols) end dimMatCheck -protected[vecxt] object sameDimMatCheck: +object sameDimMatCheck: inline def apply[A, B](a: Matrix[A], b: Matrix[B])(using inline doCheck: BoundsCheck) = inline if doCheck then if !(a.cols == b.cols && a.rows == b.rows) then throw MatrixDimensionMismatch(a.rows, a.cols, b.rows, b.cols) end sameDimMatCheck -protected[vecxt] object indexCheckMat: +object indexCheckMat: inline def apply[A](a: Matrix[A], dim: RowCol)(using inline doCheck: BoundsCheck) = inline if doCheck then if !(dim._1 >= 0 && dim._2 >= 0 && dim._1 <= a.rows && dim._2 <= a.cols) then @@ -26,7 +26,7 @@ protected[vecxt] object indexCheckMat: ) end indexCheckMat -protected[vecxt] object dimMatInstantiateCheck: +object dimMatInstantiateCheck: inline def apply[A](raw: NArray[A], dim: RowCol)(using inline doCheck: BoundsCheck) = inline if doCheck then if dim._1 * dim._2 != raw.size diff --git a/vecxtensions/js-jvm/src/matmul.scala b/vecxtensions/js-jvm/src/matmul.scala new file mode 100644 index 0000000..44f600b --- /dev/null +++ b/vecxtensions/js-jvm/src/matmul.scala @@ -0,0 +1,50 @@ +package vecxtensions + +import spire.implicits.* +import spire.algebra.Ring +import scala.reflect.ClassTag +import vecxt.* +import vecxt.all.* +import vecxt.BoundsCheck.BoundsCheck +import narr.* + +object SpireExt: + + extension [A: ClassTag: Ring](m1: Matrix[A]) + inline def @@@( + m2: Matrix[A] + )(using inline boundsCheck: BoundsCheck): Matrix[A] = + dimMatCheck(m1, m2) + val (r1, c1) = m1.shape + val (r2, c2) = m2.shape + + val nar = NArray.ofSize[A](r1 * c2) + val res = Matrix(nar, (r1, c2)) + + for i <- 0 until r1 do + for j <- 0 until c2 do + res((i, j)) = (0 until c1) + .map { k => + val i1 = m1((i: Row, k: Col)) + val i2 = m2((k: Row, j: Col)) + i1 * i2 + } + .reduce(_ + _) + end for + res + end @@@ + + inline def showMat: String = + val (r, c) = m1.shape + val sb = new StringBuilder + for i <- 0 until r do + for j <- 0 until c do + sb.append(m1((i: Row, j: Col))(using BoundsCheck.DoBoundsCheck.no)) + sb.append(" ") + end for + sb.append("\n") + end for + sb.toString + end showMat + end extension +end SpireExt diff --git a/vecxtensions/test/jvm/src/matrix.test.scala b/vecxtensions/test/jvm/src/matrix.test.scala new file mode 100644 index 0000000..ef0b3fc --- /dev/null +++ b/vecxtensions/test/jvm/src/matrix.test.scala @@ -0,0 +1,80 @@ +package vecxtensions + +import vecxt.all.* +import spire.math.Complex +import munit.FunSuite +import vecxt.BoundsCheck +import spire.implicits.* +import narr.* + +import vecxtensions.SpireExt.* + +class MatrixExtensionSuite extends FunSuite: + + import BoundsCheck.DoBoundsCheck.yes + + def assertVecEquals[A](v1: NArray[A], v2: NArray[A])(implicit loc: munit.Location): Unit = + var i: Int = 0; + while i < v1.length do + munit.Assertions.assertEquals(v1(i), v2(i), clue = s"at index $i") + i += 1 + end while + end assertVecEquals + + test("Higher kinded matmul") { + val mat1 = Matrix.fromRows( + NArray( + NArray(1L, 2L, 3L), + NArray(4L, 5L, 6L) + ) + ) + + val mat2 = Matrix.fromRows( + NArray( + NArray(7L, 8L), + NArray(9L, 10L), + NArray(11L, 12L) + ) + ) + + val result = Matrix.fromRows( + NArray( + NArray(58L, 64L), + NArray(139L, 154L) + ) + ) + + val mult = mat1 @@@ mat2 + assertVecEquals(mult.raw, result.raw) + + } + + test("Spire matmul") { + + val mat1 = Matrix.fromRows[Complex[Double]]( + NArray( + NArray[Complex[Double]](Complex(1.0, -1.0), Complex(0.0, 2.0), Complex(-2.0, 1.0)), + NArray[Complex[Double]](Complex(0.0, -3.0), Complex(3.0, -2.0), Complex(-1.0, -1.0)) + ) + ) + + val mat2 = Matrix.fromRows[Complex[Double]]( + NArray( + NArray[Complex[Double]](Complex(0.0, -2.0), Complex(1.0, -4.0)), + NArray[Complex[Double]](Complex(-1.0, 3.0), Complex(2.0, -3.0)), + NArray[Complex[Double]](Complex(-2.0, 1.0), Complex(-4.0, 1.0)) + ) + ) + + val result = Matrix.fromRows[Complex[Double]]( + NArray( + NArray[Complex[Double]](Complex(-5.0, -8.0), Complex(10.0, -7.0)), + NArray[Complex[Double]](Complex(0.0, 12.0), Complex(-7.0, -13.0)) + ) + ) + + val mult: Matrix[Complex[Double]] = mat1 @@@ mat2 + assertVecEquals(mult.raw, result.raw) + } + +end MatrixExtensionSuite