Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Quafadas committed Dec 26, 2024
1 parent f455b65 commit 78b998e
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 7 deletions.
9 changes: 6 additions & 3 deletions build.mill
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ object vecxt extends CrossPlatform {
object vecxtensions extends CrossPlatform {
override def moduleDeps: Seq[CrossPlatform] = Seq(vecxt)
trait Shared extends CrossPlatformScalaModule with Common {

// common `core` settings here
trait SharedTests extends CommonTests {
// common `core` test settings here
Expand All @@ -110,18 +111,20 @@ object vecxtensions extends CrossPlatform {
object jvm extends Shared {
override def javacOptions: T[Seq[String]] = super.javacOptions() ++ vecIncubatorFlag
def forkArgs = super.forkArgs() ++ vecIncubatorFlag
def ivyDeps = super.ivyDeps() ++ Agg(

def ivyDeps = Agg(
ivy"org.typelevel::spire::0.18.0"
)


object test extends ScalaTests with SharedTests {
def forkArgs = super.forkArgs() ++ vecIncubatorFlag
}
}
object js extends Shared with CommonJS {
override def ivyDeps: Target[Agg[Dep]] = super.ivyDeps() ++ Agg(
ivy"com.lihaoyi::scalatags::0.13.1",
ivy"com.raquo::laminar::17.1.0"
ivy"com.raquo::laminar::17.1.0",
ivy"org.typelevel::spire::0.18.0"
)
// js specific settings here
object test extends ScalaJSTests with SharedTests {
Expand Down
8 changes: 4 additions & 4 deletions vecxt/src/dimMatCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ import vecxt.matrix.*

import narr.*

protected[vecxt] object dimMatCheck:
object dimMatCheck:
inline def apply[A](a: Matrix[A], b: Matrix[A])(using inline doCheck: BoundsCheck) =
inline if doCheck then if a.cols != b.rows then throw MatrixDimensionMismatch(a.rows, a.cols, b.rows, b.cols)
end dimMatCheck

protected[vecxt] object sameDimMatCheck:
object sameDimMatCheck:
inline def apply[A, B](a: Matrix[A], b: Matrix[B])(using inline doCheck: BoundsCheck) =
inline if doCheck then
if !(a.cols == b.cols && a.rows == b.rows) then throw MatrixDimensionMismatch(a.rows, a.cols, b.rows, b.cols)
end sameDimMatCheck

protected[vecxt] object indexCheckMat:
object indexCheckMat:
inline def apply[A](a: Matrix[A], dim: RowCol)(using inline doCheck: BoundsCheck) =
inline if doCheck then
if !(dim._1 >= 0 && dim._2 >= 0 && dim._1 <= a.rows && dim._2 <= a.cols) then
Expand All @@ -26,7 +26,7 @@ protected[vecxt] object indexCheckMat:
)
end indexCheckMat

protected[vecxt] object dimMatInstantiateCheck:
object dimMatInstantiateCheck:
inline def apply[A](raw: NArray[A], dim: RowCol)(using inline doCheck: BoundsCheck) =
inline if doCheck then
if dim._1 * dim._2 != raw.size
Expand Down
50 changes: 50 additions & 0 deletions vecxtensions/js-jvm/src/matmul.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package vecxtensions

import spire.implicits.*
import spire.algebra.Ring
import scala.reflect.ClassTag
import vecxt.*
import vecxt.all.*
import vecxt.BoundsCheck.BoundsCheck
import narr.*

object SpireExt:

extension [A: ClassTag: Ring](m1: Matrix[A])
inline def @@@(
m2: Matrix[A]
)(using inline boundsCheck: BoundsCheck): Matrix[A] =
dimMatCheck(m1, m2)
val (r1, c1) = m1.shape
val (r2, c2) = m2.shape

val nar = NArray.ofSize[A](r1 * c2)
val res = Matrix(nar, (r1, c2))

for i <- 0 until r1 do
for j <- 0 until c2 do
res((i, j)) = (0 until c1)
.map { k =>
val i1 = m1((i: Row, k: Col))
val i2 = m2((k: Row, j: Col))
i1 * i2
}
.reduce(_ + _)
end for
res
end @@@

inline def showMat: String =
val (r, c) = m1.shape
val sb = new StringBuilder
for i <- 0 until r do
for j <- 0 until c do
sb.append(m1((i: Row, j: Col))(using BoundsCheck.DoBoundsCheck.no))
sb.append(" ")
end for
sb.append("\n")
end for
sb.toString
end showMat
end extension
end SpireExt
80 changes: 80 additions & 0 deletions vecxtensions/test/jvm/src/matrix.test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package vecxtensions

import vecxt.all.*
import spire.math.Complex
import munit.FunSuite
import vecxt.BoundsCheck
import spire.implicits.*
import narr.*

import vecxtensions.SpireExt.*

class MatrixExtensionSuite extends FunSuite:

import BoundsCheck.DoBoundsCheck.yes

def assertVecEquals[A](v1: NArray[A], v2: NArray[A])(implicit loc: munit.Location): Unit =
var i: Int = 0;
while i < v1.length do
munit.Assertions.assertEquals(v1(i), v2(i), clue = s"at index $i")
i += 1
end while
end assertVecEquals

test("Higher kinded matmul") {
val mat1 = Matrix.fromRows(
NArray(
NArray(1L, 2L, 3L),
NArray(4L, 5L, 6L)
)
)

val mat2 = Matrix.fromRows(
NArray(
NArray(7L, 8L),
NArray(9L, 10L),
NArray(11L, 12L)
)
)

val result = Matrix.fromRows(
NArray(
NArray(58L, 64L),
NArray(139L, 154L)
)
)

val mult = mat1 @@@ mat2
assertVecEquals(mult.raw, result.raw)

}

test("Spire matmul") {

val mat1 = Matrix.fromRows[Complex[Double]](
NArray(
NArray[Complex[Double]](Complex(1.0, -1.0), Complex(0.0, 2.0), Complex(-2.0, 1.0)),
NArray[Complex[Double]](Complex(0.0, -3.0), Complex(3.0, -2.0), Complex(-1.0, -1.0))
)
)

val mat2 = Matrix.fromRows[Complex[Double]](
NArray(
NArray[Complex[Double]](Complex(0.0, -2.0), Complex(1.0, -4.0)),
NArray[Complex[Double]](Complex(-1.0, 3.0), Complex(2.0, -3.0)),
NArray[Complex[Double]](Complex(-2.0, 1.0), Complex(-4.0, 1.0))
)
)

val result = Matrix.fromRows[Complex[Double]](
NArray(
NArray[Complex[Double]](Complex(-5.0, -8.0), Complex(10.0, -7.0)),
NArray[Complex[Double]](Complex(0.0, 12.0), Complex(-7.0, -13.0))
)
)

val mult: Matrix[Complex[Double]] = mat1 @@@ mat2
assertVecEquals(mult.raw, result.raw)
}

end MatrixExtensionSuite

0 comments on commit 78b998e

Please sign in to comment.