Skip to content

Commit

Permalink
Merge pull request #20 from future-architect/feature/add-support-time
Browse files Browse the repository at this point in the history
add time.Time parameter support
  • Loading branch information
ma91n authored Jun 10, 2022
2 parents 1d94616 + 4772ca4 commit cac4323
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 81 deletions.
50 changes: 28 additions & 22 deletions e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import (
"fmt"
"os"
"testing"
"time"

"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"gotest.tools/v3/assert"
"gotest.tools/v3/assert/cmp"
)

type Person struct {
Expand All @@ -18,6 +21,8 @@ type Person struct {
Email string `db:"email"`
NullString sql.NullString `db:"null_string"`
NullInt sql.NullInt64 `db:"null_int"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt sql.NullTime `db:"updated_at"`
}

func TestSelect(t *testing.T) {
Expand Down Expand Up @@ -66,9 +71,7 @@ func TestSelect(t *testing.T) {
t.Errorf("select: failed: %v", err)
}

if !match(people, expected) {
t.Errorf("\nexpected:\n%v\nbut got\n%v\n", expected, people)
}
assert.Check(t, cmp.DeepEqual(people, expected))

}

Expand Down Expand Up @@ -105,9 +108,7 @@ func TestUpdate(t *testing.T) {
Email: "[email protected]",
},
}
if !match(people, expected) {
t.Errorf("expected:\n%v\nbut got\n%v\n", expected, people)
}
assert.Check(t, cmp.DeepEqual(people, expected))
}

func TestInsertAndDelete(t *testing.T) {
Expand All @@ -126,14 +127,21 @@ func TestInsertAndDelete(t *testing.T) {
Email: "[email protected]",
NullString: sql.NullString{String: "value", Valid: true},
NullInt: sql.NullInt64{Int64: 11, Valid: false}, // NULL 登録
}
_, err := tw.Exec(ctx, `INSERT INTO persons (employee_no, dept_no, first_name, last_name, email, null_string, null_int) VALUES(/*EmpNo*/1, /*deptNo*/1, /*firstName*/"Tim", /*lastName*/"Cook", /*email*/"[email protected]", /*null_string*/'null', /*null_int*/1)`, &params)
CreatedAt: time.Date(2022, 6, 10, 17, 0, 0, 0, time.UTC),
UpdatedAt: sql.NullTime{Time: time.Date(2022, 6, 10, 18, 0, 0, 0, time.UTC), Valid: true},
}
_, err := tw.Exec(ctx, `
INSERT INTO persons
(employee_no, dept_no, first_name, last_name, email, null_string, null_int, created_at, updated_at)
VALUES
(/*EmpNo*/1, /*deptNo*/1, /*firstName*/"Tim", /*lastName*/"Cook", /*email*/"[email protected]", /*null_string*/'null', /*null_int*/1, /*created_at*/'2022-06-01 10:00:00', /*updated_at*/'2022-06-02 10:00:00')`,
&params)
if err != nil {
t.Fatalf("exec: failed: %v", err)
}

var people []Person
err = tw.Select(ctx, &people, `SELECT first_name, last_name, email, null_string, null_int FROM persons WHERE dept_no = /*deptNo*/0`, &params)
err = tw.Select(ctx, &people, `SELECT first_name, last_name, email, null_string, null_int, created_at, updated_at FROM persons WHERE dept_no = /*deptNo*/0`, &params)
if err != nil {
t.Fatalf("select: failed: %v", err)
}
Expand All @@ -145,11 +153,11 @@ func TestInsertAndDelete(t *testing.T) {
Email: "[email protected]",
NullString: sql.NullString{String: "value", Valid: true},
NullInt: sql.NullInt64{Int64: 0, Valid: false}, // NULL 確認
CreatedAt: time.Date(2022, 6, 10, 17, 0, 0, 0, time.UTC),
UpdatedAt: sql.NullTime{Time: time.Date(2022, 6, 10, 18, 0, 0, 0, time.UTC), Valid: true},
},
}
if !match(people, expected) {
t.Errorf("expected:\n%v\nbut got\n%v\n", expected, people)
}
assert.Check(t, cmp.DeepEqual(people, expected))

_, err = tw.Exec(ctx, `DELETE FROM persons WHERE employee_no = /*EmpNo*/2`, &params)
if err != nil {
Expand All @@ -163,9 +171,7 @@ func TestInsertAndDelete(t *testing.T) {
}

expected = []Person{}
if !match(people, expected) {
t.Errorf("expected:\n%v\nbut got\n%v\n", expected, people)
}
assert.Check(t, cmp.DeepEqual(people, expected))
}

func TestTxCommit(t *testing.T) {
Expand All @@ -179,8 +185,8 @@ func TestTxCommit(t *testing.T) {
// insert test data
const insertSQL = `
INSERT INTO persons
(employee_no, dept_no, first_name, last_name, email) VALUES
(11, 111, 'Clegg', 'George', '[email protected]')
(employee_no, dept_no, first_name, last_name, email, created_at) VALUES
(11, 111, 'Clegg', 'George', '[email protected]', CURRENT_TIMESTAMP)
;
`
if _, err := tw.Exec(ctx, insertSQL, nil); err != nil {
Expand Down Expand Up @@ -250,8 +256,8 @@ func TestTxRollback(t *testing.T) {
// insert test data
const insertSQL = `
INSERT INTO persons
(employee_no, dept_no, first_name, last_name, email) VALUES
(12, 121, 'Chmmg', 'Dudley', '[email protected]')
(employee_no, dept_no, first_name, last_name, email, created_at) VALUES
(12, 121, 'Chmmg', 'Dudley', '[email protected]', CURRENT_TIMESTAMP)
;
`
if _, err := tw.Exec(ctx, insertSQL, nil); err != nil {
Expand Down Expand Up @@ -321,9 +327,9 @@ func TestTxBlock(t *testing.T) {
// insert test data
const insertSQL = `
INSERT INTO persons
(employee_no, dept_no, first_name, last_name, email) VALUES
(13, 131, 'Darling', 'Wat', '[email protected]'),
(14, 141, 'Hallows', 'Jessie', '[email protected]')
(employee_no, dept_no, first_name, last_name, email, created_at) VALUES
(13, 131, 'Darling', 'Wat', '[email protected]', CURRENT_TIMESTAMP),
(14, 141, 'Hallows', 'Jessie', '[email protected]', CURRENT_TIMESTAMP)
;`
if _, err := tw.Exec(ctx, insertSQL, nil); err != nil {
t.Fatal(err)
Expand Down
120 changes: 67 additions & 53 deletions eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strings"
"time"
"unicode"

"gitlab.com/osaki-lab/tagscanner/runtimescan"
Expand Down Expand Up @@ -184,8 +185,8 @@ func encode(dest map[string]interface{}, src interface{}) error {
return err
}

// tagscanner does not support sql.NullXXX type.
encodeSQLNullTyp(src, dest, tags)
// tagscanner does not support nest struct type.
encodeNestStructTyp(src, dest, tags)

return nil
}
Expand All @@ -200,66 +201,79 @@ func convertToMapStringAny(mp reflect.Value, dest map[string]interface{}) bool {
return true
}

func encodeSQLNullTyp(src interface{}, dest map[string]interface{}, tags []string) {
const targetPkg = "database/sql"
func encodeNestStructTyp(src interface{}, dest map[string]interface{}, tags []string) {
srcFieldTyps := reflect.ValueOf(src).Type().Elem()
srcFieldValues := reflect.ValueOf(src).Elem()
for i := 0; i < srcFieldTyps.NumField(); i++ {
srcFieldTyp := srcFieldTyps.Field(i)
if srcFieldTyp.Type.PkgPath() != targetPkg {
continue
}
tagValue := getTagValue(srcFieldTyp.Tag, tags)
if tagValue == "" {
continue
}
switch v := srcFieldValues.Field(i).Interface().(type) {
case sql.NullBool:
if v.Valid {
dest[tagValue] = v.Bool
} else {
dest[tagValue] = nil
}
case sql.NullByte:
// not support
continue
case sql.NullFloat64:
if v.Valid {
dest[tagValue] = v.Float64
} else {
dest[tagValue] = nil
}
case sql.NullInt16:
if v.Valid {
dest[tagValue] = v.Int16
} else {
dest[tagValue] = nil
}
case sql.NullInt32:
if v.Valid {
dest[tagValue] = v.Int32
} else {
dest[tagValue] = nil
}
case sql.NullInt64:
if v.Valid {
dest[tagValue] = v.Int64
} else {
dest[tagValue] = nil
}
case sql.NullString:
if v.Valid {
dest[tagValue] = v.String
} else {
dest[tagValue] = nil
}
case sql.NullTime:
if v.Valid {
dest[tagValue] = v.Time
} else {
dest[tagValue] = nil
}
srcFieldValue := srcFieldValues.Field(i)
switch srcFieldTyp.Type.PkgPath() {
case "database/sql":
encodeSQLNullTyp(srcFieldValue, dest, tagValue)
case "time":
encodeTimeTyp(srcFieldValue, dest, tagValue)
}
}
}

func encodeSQLNullTyp(srcFieldValue reflect.Value, dest map[string]interface{}, tagValue string) {
switch v := srcFieldValue.Interface().(type) {
case sql.NullBool:
if v.Valid {
dest[tagValue] = v.Bool
} else {
dest[tagValue] = nil
}
case sql.NullByte:
// not support
return
case sql.NullFloat64:
if v.Valid {
dest[tagValue] = v.Float64
} else {
dest[tagValue] = nil
}
case sql.NullInt16:
if v.Valid {
dest[tagValue] = v.Int16
} else {
dest[tagValue] = nil
}
case sql.NullInt32:
if v.Valid {
dest[tagValue] = v.Int32
} else {
dest[tagValue] = nil
}
case sql.NullInt64:
if v.Valid {
dest[tagValue] = v.Int64
} else {
dest[tagValue] = nil
}
case sql.NullString:
if v.Valid {
dest[tagValue] = v.String
} else {
dest[tagValue] = nil
}
case sql.NullTime:
if v.Valid {
dest[tagValue] = v.Time
} else {
dest[tagValue] = nil
}
}
}

func encodeTimeTyp(srcFieldValue reflect.Value, dest map[string]interface{}, tagValue string) {
switch v := srcFieldValue.Interface().(type) {
case time.Time:
dest[tagValue] = v
}
}

Expand Down
14 changes: 13 additions & 1 deletion eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ type Info struct {
Table [][]interface{} `twowaysql:"table"`
NullString sql.NullString `twowaysql:"null_string"`
NullInt sql.NullInt64 `twowaysql:"null_int"`
CreatedAt time.Time `twowaysql:"created_at"`
UpdatedAt sql.NullTime `twowaysql:"updated_at"`
}

func TestEval(t *testing.T) {
Expand Down Expand Up @@ -1119,7 +1121,7 @@ func TestEvalWithMap(t *testing.T) {
}
}

func TestEval_SQLNullTyp(t *testing.T) {
func TestEval_NestStructTyp(t *testing.T) {
type SQLTypInfo struct {
NullBool sql.NullBool `db:"null_bool"`
NullFloat64 sql.NullFloat64 `db:"null_float_64"`
Expand All @@ -1128,6 +1130,7 @@ func TestEval_SQLNullTyp(t *testing.T) {
NullInt64 sql.NullInt64 `db:"null_int_64"`
NullString sql.NullString `db:"null_string"`
NullTime sql.NullTime `db:"null_time"`
Time time.Time `db:"time"`
}

tests := []struct {
Expand Down Expand Up @@ -1203,6 +1206,15 @@ func TestEval_SQLNullTyp(t *testing.T) {
wantQuery: `SELECT * FROM person WHERE value = ?/*null_time*/`,
wantParams: []interface{}{time.Date(2022, 7, 1, 12, 30, 30, 0, time.UTC)},
},
{
name: "bind time.Time",
input: `SELECT * FROM person WHERE value = /*time*/'2022-01-01 10:00:00'`,
inputParams: SQLTypInfo{
Time: time.Date(2022, 7, 1, 12, 30, 30, 0, time.UTC),
},
wantQuery: `SELECT * FROM person WHERE value = ?/*time*/`,
wantParams: []interface{}{time.Date(2022, 7, 1, 12, 30, 30, 0, time.UTC)},
},
{
name: "bind initial",
input: `SELECT * FROM person WHERE value = /*null_string*/'hoge'`,
Expand Down
12 changes: 7 additions & 5 deletions postgres/init/init.sql
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ CREATE TABLE persons (
last_name VARCHAR(100),
email VARCHAR(100),
null_string VARCHAR(100),
null_int INT
null_int INT,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone
);

INSERT INTO persons(employee_no, dept_no, first_name, last_name, email) VALUES
(1, 10, 'Evan', 'MacMans', '[email protected]'),
(2, 11, 'Malvina', 'FitzSimons', '[email protected]'),
(3, 12, 'Jimmie', 'Bruce', '[email protected]')
INSERT INTO persons(employee_no, dept_no, first_name, last_name, email, created_at) VALUES
(1, 10, 'Evan', 'MacMans', '[email protected]', CURRENT_TIMESTAMP),
(2, 11, 'Malvina', 'FitzSimons', '[email protected]', CURRENT_TIMESTAMP),
(3, 12, 'Jimmie', 'Bruce', '[email protected]', CURRENT_TIMESTAMP)
;

0 comments on commit cac4323

Please sign in to comment.