Skip to content

Commit

Permalink
Merge pull request #13 from lemastero/id-function
Browse files Browse the repository at this point in the history
Handle identity function
  • Loading branch information
lemastero authored May 4, 2024
2 parents 245c5ce + 9f31cec commit bed63b6
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 41 deletions.
26 changes: 20 additions & 6 deletions examples/adts.agda
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
module examples.adts where

-- simple sum type no arguments - sealed trait + case objects
-- simple product type no arguments - sealed trait + case objects

data Rgb : Set where
Red : Rgb
Green : Rgb
Blue : Rgb
{-# COMPILE AGDA2SCALA Rgb #-}

-- simple sum type with arguments - sealed trait + case class
data Bool : Set where
True : Bool
False : Bool
{-# COMPILE AGDA2SCALA Bool #-}

-- trivial function with single argument

idRgb : Rgb -> Rgb
idRgb x = x
{-# COMPILE AGDA2SCALA idRgb #-}

-- simple sum type - case class

data Color : Set where
Light : Rgb -> Color
Dark : Rgb -> Color
{-# COMPILE AGDA2SCALA Color #-}
record RgbPair : Set where
constructor mkRgbPair
field
fst : Rgb
snd : Bool
{-# COMPILE AGDA2SCALA RgbPair #-}
13 changes: 9 additions & 4 deletions examples/adts.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package adts
object adts {

sealed trait Rgb
case object Red extends Rgb
case object Green extends Rgb
case object Blue extends Rgb

sealed trait Color
case object Light extends Color
case object Dark extends Color
sealed trait Bool
case object True extends Bool
case object False extends Bool

def idRgb(x: Rgb): Rgb = x

final case class RgbPair(snd: Bool, fst: Rgb)
}
100 changes: 92 additions & 8 deletions src/Agda/Compiler/Scala/AgdaToScalaExpr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,113 @@ module Agda.Compiler.Scala.AgdaToScalaExpr (
import Agda.Compiler.Backend ( funCompiled, funClauses, Defn(..), RecordData(..))
import Agda.Syntax.Abstract.Name ( QName )
import Agda.Syntax.Common.Pretty ( prettyShow )
import Agda.Syntax.Common ( Arg(..), ArgName, Named(..) )
import Agda.Syntax.Common ( Arg(..), ArgName, Named(..), NamedName, WithOrigin(..), Ranged(..) )
import Agda.Syntax.Internal (
Clause(..), DeBruijnPattern, DBPatVar(..), Dom(..), unDom, PatternInfo(..), Pattern'(..),
Clause(..), DeBruijnPattern, DBPatVar(..), Dom(..), Dom'(..), unDom, PatternInfo(..), Pattern'(..),
qnameName, qnameModule, Telescope, Tele(..), Term(..), Type, Type''(..) )
import Agda.TypeChecking.Monad.Base ( Definition(..) )
import Agda.TypeChecking.Monad
import Agda.TypeChecking.CompiledClause ( CompiledClauses(..), CompiledClauses'(..) )
import Agda.TypeChecking.Telescope ( teleNamedArgs, teleArgs, teleArgNames )

import Agda.Compiler.Scala.ScalaExpr ( ScalaName, ScalaExpr(..) )
import Agda.Syntax.Common.Pretty ( prettyShow )

import Agda.Compiler.Scala.ScalaExpr ( ScalaName, ScalaType, FunBody, ScalaExpr(..), SeVar(..) )

compileDefn :: QName -> Defn -> ScalaExpr
compileDefn defName theDef = case theDef of
Datatype{dataCons = dataCons} ->
compileDataType defName dataCons
Function{funCompiled = funDef, funClauses = fc} ->
Unhandled "compileDefn Function" (show defName ++ "\n = \n" ++ show theDef)
compileFunction defName funDef fc
RecordDefn(RecordData{_recFields = recFields, _recTel = recTel}) ->
Unhandled "compileDefn RecordDefn" (show defName ++ "\n = \n" ++ show theDef)
compileRecord defName recFields recTel
other ->
Unhandled "compileDefn other" (show defName ++ "\n = \n" ++ show theDef)

compileRecord :: QName -> [Dom QName] -> Telescope -> ScalaExpr
compileRecord defName recFields recTel = SeProd (fromQName defName) (foldl varsFromTelescope [] recTel)

varsFromTelescope :: [SeVar] -> Dom Type -> [SeVar]
varsFromTelescope xs dt = SeVar (nameFromDom dt) (fromDom dt) : xs

compileDataType :: QName -> [QName] -> ScalaExpr
compileDataType defName fields = SeAdt (showName defName) (map showName fields)
compileDataType defName fields = SeSum (fromQName defName) (map fromQName fields)

compileFunction :: QName
-> Maybe CompiledClauses
-> [Clause]
-> ScalaExpr
compileFunction defName funDef fc =
SeFun
(fromQName defName)
[SeVar (compileFunctionArgument fc) (compileFunctionArgType fc)] -- TODO many function arguments
(compileFunctionResultType fc)
(compileFunctionBody funDef)

compileFunctionArgument :: [Clause] -> ScalaName
compileFunctionArgument [] = ""
compileFunctionArgument [fc] = fromDeBruijnPattern (namedThing (unArg (head (namedClausePats fc))))

Check warning on line 54 in src/Agda/Compiler/Scala/AgdaToScalaExpr.hs

View workflow job for this annotation

GitHub Actions / agda2scala

In the use of ‘head’
compileFunctionArgument xs = error "unsupported compileFunctionArgument" ++ (show xs) -- show xs

compileFunctionArgType :: [Clause] -> ScalaType
compileFunctionArgType [ Clause{clauseTel = ct} ] = fromTelescope ct
compileFunctionArgType xs = error "unsupported compileFunctionArgType" ++ (show xs)

fromTelescope :: Telescope -> ScalaName -- TODO PP probably parent should be different, use fold on telescope above
fromTelescope tel = case tel of
ExtendTel a _ -> fromDom a
other -> error ("unhandled fromType" ++ show other)

nameFromDom :: Dom Type -> ScalaName
nameFromDom dt = case (domName dt) of
Nothing -> error ("nameFromDom" ++ show dt)
Just a -> namedNameToStr a

namedNameToStr :: NamedName -> ScalaName
namedNameToStr n = rangedThing (woThing n)

fromDom :: Dom Type -> ScalaName
fromDom x = fromType (unDom x)

compileFunctionResultType :: [Clause] -> ScalaType
compileFunctionResultType [Clause{clauseType = ct}] = fromMaybeType ct
compileFunctionResultType other = error ("unhandled compileFunctionResultType" ++ show other)

fromMaybeType :: Maybe (Arg Type) -> ScalaName
fromMaybeType (Just argType) = fromArgType argType
fromMaybeType other = error ("unhandled fromMaybeType" ++ show other)

fromArgType :: Arg Type -> ScalaName
fromArgType arg = fromType (unArg arg)

fromType :: Type -> ScalaName
fromType t = case t of
a@(El _ ue) -> fromTerm ue
other -> error ("unhandled fromType" ++ show other)

Check warning on line 91 in src/Agda/Compiler/Scala/AgdaToScalaExpr.hs

View workflow job for this annotation

GitHub Actions / agda2scala

Pattern match is redundant

fromTerm :: Term -> ScalaName
fromTerm t = case t of
Def qname el -> fromQName qname
other -> error ("unhandled fromTerm" ++ show other)

fromDeBruijnPattern :: DeBruijnPattern -> ScalaName
fromDeBruijnPattern d = case d of
VarP a b -> (dbPatVarName b)
a@(ConP x y z) -> show a
other -> error ("unhandled fromDeBruijnPattern" ++ show other)

compileFunctionBody :: Maybe CompiledClauses -> FunBody
compileFunctionBody (Just funDef) = fromCompiledClauses funDef
compileFunctionBody funDef = error ("unhandled compileFunctionBody " ++ show funDef)

fromCompiledClauses :: CompiledClauses -> FunBody
fromCompiledClauses cc = case cc of
(Done (x:xs) term) -> fromArgName x
other -> error ("unhandled fromCompiledClauses " ++ show other)

fromArgName :: Arg ArgName -> FunBody
fromArgName = unArg

showName :: QName -> ScalaName
showName = prettyShow . qnameName
fromQName :: QName -> ScalaName
fromQName = prettyShow . qnameName
42 changes: 31 additions & 11 deletions src/Agda/Compiler/Scala/PrintScalaExpr.hs
Original file line number Diff line number Diff line change
@@ -1,30 +1,50 @@
{-# LANGUAGE OverloadedStrings #-}

module Agda.Compiler.Scala.PrintScalaExpr ( printScalaExpr
, printCaseObject
, printSealedTrait
, printPackage
, printCaseClass
, combineLines
) where

import Agda.Compiler.Scala.ScalaExpr ( ScalaName, ScalaExpr(..) )
import Data.List ( intercalate )
import Agda.Compiler.Scala.ScalaExpr ( ScalaName, ScalaExpr(..), SeVar(..))

printScalaExpr :: ScalaExpr -> String
printScalaExpr def = case def of
(SePackage pName defs) ->
(printPackage pName) <> defsSeparator
<> (
(printPackage pName) <> exprSeparator -- TODO this should be package + object
<> bracket (
blankLine -- between package declaration and first definition
<> combineLines (map printScalaExpr defs)
)
<> blankLine -- EOF
(SeAdt adtName adtCases) ->
(SeSum adtName adtCases) ->
(printSealedTrait adtName)
<> defsSeparator
<> unlines (map (printCaseObject adtName) adtCases)
(Unhandled name payload) -> "" -- for development comment out this and uncomment below
-- (Unhandled name payload) -> "TODO " ++ (show name) ++ " " ++ (show payload)
-- other -> "unsupported printScalaExpr " ++ (show other)
<> combineLines (map (printCaseObject adtName) adtCases)
<> defsSeparator
(SeFun fName args resType funBody) ->
"def" <> exprSeparator <> fName
<> "(" <> combineLines (map printVar args) <> ")"
<> ":" <> exprSeparator <> resType <> exprSeparator
<> "=" <> exprSeparator <> funBody
<> defsSeparator
(SeProd name args) -> printCaseClass name args
(Unhandled "" payload) -> ""
(Unhandled name payload) -> "TODO " ++ (show name) ++ " " ++ (show payload)
other -> "unsupported printScalaExpr " ++ (show other)

Check warning on line 35 in src/Agda/Compiler/Scala/PrintScalaExpr.hs

View workflow job for this annotation

GitHub Actions / agda2scala

Pattern match is redundant

printCaseClass :: ScalaName -> [SeVar] -> String
printCaseClass name args = "final case class" <> exprSeparator <> name <> "(" <> (printExpr args) <> ")"

printVar :: SeVar -> String
printVar (SeVar sName sType) = sName <> ":" <> exprSeparator <> sType

printExpr :: [SeVar] -> String
printExpr names = combineThem (map printVar names)

combineThem :: [String] -> String
combineThem xs = intercalate ", " xs

printSealedTrait :: ScalaName -> String
printSealedTrait adtName = "sealed trait" <> exprSeparator <> adtName
Expand All @@ -34,7 +54,7 @@ printCaseObject superName caseName =
"case object" <> exprSeparator <> caseName <> exprSeparator <> "extends" <> exprSeparator <> superName

printPackage :: ScalaName -> String
printPackage pName = "package" <> exprSeparator <> pName
printPackage pName = "object" <> exprSeparator <> pName

bracket :: String -> String
bracket str = "{\n" <> str <> "\n}"
Expand Down
14 changes: 12 additions & 2 deletions src/Agda/Compiler/Scala/ScalaExpr.hs
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
module Agda.Compiler.Scala.ScalaExpr (
ScalaName,
ScalaType,
ScalaExpr(..),
SeVar(..),
FunBody,
unHandled
) where

type ScalaName = String
type FunBody = String -- this should be some lambda expression
type ScalaType = String

{- Represent Scala language extracted from Agda compiler representation -}
data SeVar = SeVar ScalaName ScalaType
deriving ( Show )

{- Represent Scala language extracted from internal Agda compiler representation -}
data ScalaExpr
= SePackage ScalaName [ScalaExpr]
| SeAdt ScalaName [ScalaName]
| SeSum ScalaName [ScalaName]
| SeFun ScalaName [SeVar] ScalaType FunBody
| SeProd ScalaName [SeVar]
| Unhandled ScalaName String
deriving ( Show )

Expand Down
29 changes: 19 additions & 10 deletions test/PrintScalaExprTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ import Agda.Compiler.Scala.PrintScalaExpr (
, printCaseObject
, printPackage

Check warning on line 8 in test/PrintScalaExprTest.hs

View workflow job for this annotation

GitHub Actions / agda2scala

The import of ‘printPackage’
, combineLines
, printCaseClass
)
import Agda.Compiler.Scala.ScalaExpr ( ScalaExpr(..) )
import Agda.Compiler.Scala.ScalaExpr ( ScalaExpr(..), SeVar(..) )

testPrintCaseObject :: Test
testPrintCaseObject = TestCase
Expand All @@ -22,11 +23,11 @@ testPrintSealedTrait = TestCase
"sealed trait Color"
(printSealedTrait "Color"))

testPrintPackage :: Test
testPrintPackage = TestCase
(assertEqual "printPackage"
"package adts"
(printPackage "adts"))
--testPrintPackage :: Test
--testPrintPackage = TestCase
-- (assertEqual "printPackage"
-- "package adts"
-- (printPackage "adts"))

testCombineLines :: Test
testCombineLines = TestCase
Expand All @@ -37,19 +38,27 @@ testCombineLines = TestCase
testPrintScalaExpr :: Test
testPrintScalaExpr = TestCase
(assertEqual "printScalaExpr" (printScalaExpr $ SePackage "adts" moduleContent)
"package adts\n\nsealed trait Rgb\ncase object Red extends Rgb\ncase object Green extends Rgb\ncase object Blue extends Rgb\n\nsealed trait Color\ncase object Light extends Color\ncase object Dark extends Color\n"
"object adts {\n\nsealed trait Rgb\ncase object Red extends Rgb\ncase object Green extends Rgb\ncase object Blue extends Rgb\n\nsealed trait Color\ncase object Light extends Color\ncase object Dark extends Color\n}\n"
)
where
moduleContent = [rgbAdt, blank, blank, blank, colorAdt, blank, blank]
rgbAdt = SeAdt "Rgb" ["Red","Green","Blue"]
colorAdt = SeAdt "Color" ["Light","Dark"]
rgbAdt = SeSum "Rgb" ["Red","Green","Blue"]
colorAdt = SeSum "Color" ["Light","Dark"]
blank = Unhandled "" ""

testPrintCaseClass :: Test
testPrintCaseClass = TestCase
(assertEqual "printCaseClass"
"final case class RgbPair(snd: Bool, fst: Rgb)"
(printCaseClass "RgbPair" [SeVar "snd" "Bool", SeVar "fst" "Rgb"]))


printScalaTests :: Test
printScalaTests = TestList [
TestLabel "printCaseObject" testPrintCaseObject
, TestLabel "printSealedTrait" testPrintSealedTrait
, TestLabel "printPackage" testPrintPackage
-- , TestLabel "printPackage" testPrintPackage
, TestLabel "combineLines" testCombineLines
, TestLabel "printCaseClass" testPrintCaseClass
, TestLabel "printScalaExpr" testPrintScalaExpr
]

0 comments on commit bed63b6

Please sign in to comment.