-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
140 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
package vecxtensions | ||
|
||
import spire.implicits.* | ||
import spire.algebra.Ring | ||
import scala.reflect.ClassTag | ||
import vecxt.* | ||
import vecxt.all.* | ||
import vecxt.BoundsCheck.BoundsCheck | ||
import narr.* | ||
|
||
object SpireExt: | ||
|
||
extension [A: ClassTag: Ring](m1: Matrix[A]) | ||
inline def @@@( | ||
m2: Matrix[A] | ||
)(using inline boundsCheck: BoundsCheck): Matrix[A] = | ||
dimMatCheck(m1, m2) | ||
val (r1, c1) = m1.shape | ||
val (r2, c2) = m2.shape | ||
|
||
val nar = NArray.ofSize[A](r1 * c2) | ||
val res = Matrix(nar, (r1, c2)) | ||
|
||
for i <- 0 until r1 do | ||
for j <- 0 until c2 do | ||
res((i, j)) = (0 until c1) | ||
.map { k => | ||
val i1 = m1((i: Row, k: Col)) | ||
val i2 = m2((k: Row, j: Col)) | ||
i1 * i2 | ||
} | ||
.reduce(_ + _) | ||
end for | ||
res | ||
end @@@ | ||
|
||
inline def showMat: String = | ||
val (r, c) = m1.shape | ||
val sb = new StringBuilder | ||
for i <- 0 until r do | ||
for j <- 0 until c do | ||
sb.append(m1((i: Row, j: Col))(using BoundsCheck.DoBoundsCheck.no)) | ||
sb.append(" ") | ||
end for | ||
sb.append("\n") | ||
end for | ||
sb.toString | ||
end showMat | ||
end extension | ||
end SpireExt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package vecxtensions | ||
|
||
import vecxt.all.* | ||
import spire.math.Complex | ||
import munit.FunSuite | ||
import vecxt.BoundsCheck | ||
import spire.implicits.* | ||
import narr.* | ||
|
||
import vecxtensions.SpireExt.* | ||
|
||
class MatrixExtensionSuite extends FunSuite: | ||
|
||
import BoundsCheck.DoBoundsCheck.yes | ||
|
||
def assertVecEquals[A](v1: NArray[A], v2: NArray[A])(implicit loc: munit.Location): Unit = | ||
var i: Int = 0; | ||
while i < v1.length do | ||
munit.Assertions.assertEquals(v1(i), v2(i), clue = s"at index $i") | ||
i += 1 | ||
end while | ||
end assertVecEquals | ||
|
||
test("Higher kinded matmul") { | ||
val mat1 = Matrix.fromRows( | ||
NArray( | ||
NArray(1L, 2L, 3L), | ||
NArray(4L, 5L, 6L) | ||
) | ||
) | ||
|
||
val mat2 = Matrix.fromRows( | ||
NArray( | ||
NArray(7L, 8L), | ||
NArray(9L, 10L), | ||
NArray(11L, 12L) | ||
) | ||
) | ||
|
||
val result = Matrix.fromRows( | ||
NArray( | ||
NArray(58L, 64L), | ||
NArray(139L, 154L) | ||
) | ||
) | ||
|
||
val mult = mat1 @@@ mat2 | ||
assertVecEquals(mult.raw, result.raw) | ||
|
||
} | ||
|
||
test("Spire matmul") { | ||
|
||
val mat1 = Matrix.fromRows[Complex[Double]]( | ||
NArray( | ||
NArray[Complex[Double]](Complex(1.0, -1.0), Complex(0.0, 2.0), Complex(-2.0, 1.0)), | ||
NArray[Complex[Double]](Complex(0.0, -3.0), Complex(3.0, -2.0), Complex(-1.0, -1.0)) | ||
) | ||
) | ||
|
||
val mat2 = Matrix.fromRows[Complex[Double]]( | ||
NArray( | ||
NArray[Complex[Double]](Complex(0.0, -2.0), Complex(1.0, -4.0)), | ||
NArray[Complex[Double]](Complex(-1.0, 3.0), Complex(2.0, -3.0)), | ||
NArray[Complex[Double]](Complex(-2.0, 1.0), Complex(-4.0, 1.0)) | ||
) | ||
) | ||
|
||
val result = Matrix.fromRows[Complex[Double]]( | ||
NArray( | ||
NArray[Complex[Double]](Complex(-5.0, -8.0), Complex(10.0, -7.0)), | ||
NArray[Complex[Double]](Complex(0.0, 12.0), Complex(-7.0, -13.0)) | ||
) | ||
) | ||
|
||
val mult: Matrix[Complex[Double]] = mat1 @@@ mat2 | ||
assertVecEquals(mult.raw, result.raw) | ||
} | ||
|
||
end MatrixExtensionSuite |