Skip to content

Commit

Permalink
sync-diff-inspector: support base64 encoded password (#688)
Browse files Browse the repository at this point in the history
close #687
  • Loading branch information
lichunzhu authored Nov 28, 2022
1 parent f7e6507 commit 7245a6e
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 29 deletions.
23 changes: 16 additions & 7 deletions sync_diff_inspector/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ import (
"database/sql"
"encoding/json"
"fmt"
"net"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
Expand Down Expand Up @@ -171,19 +173,26 @@ func (d *DataSource) RegisterTLS() error {
return errors.Trace(err)
}

func (d *DataSource) GetDSN() (dbDSN string) {
func (d *DataSource) ToDriverConfig() *mysql.Config {
cfg := mysql.NewConfig()
cfg.Params = make(map[string]string)

cfg.User = d.User
cfg.Passwd = d.Password.Plain()
cfg.Net = "tcp"
cfg.Addr = net.JoinHostPort(d.Host, strconv.Itoa(d.Port))
cfg.Params["charset"] = "utf8mb4"
cfg.InterpolateParams = true
cfg.Params["time_zone"] = fmt.Sprintf("'%s'", UnifiedTimeZone)
if len(d.Snapshot) > 0 && !d.IsAutoSnapshot() {
log.Info("create connection with snapshot", zap.String("snapshot", d.Snapshot))
dbDSN = fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4&interpolateParams=true&time_zone=%%27%s%%27&tidb_snapshot=%s", d.User, d.Password.Plain(), d.Host, d.Port, url.QueryEscape(UnifiedTimeZone), d.Snapshot)
} else {
dbDSN = fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4&interpolateParams=true&time_zone=%%27%s%%27", d.User, d.Password.Plain(), d.Host, d.Port, url.QueryEscape(UnifiedTimeZone))
cfg.Params["tidb_snapshot"] = d.Snapshot
}

if d.Security != nil && len(d.Security.TLSName) > 0 {
dbDSN += "&tls=" + d.Security.TLSName
cfg.TLSConfig = d.Security.TLSName
}

return dbDSN
return cfg
}

type TaskConfig struct {
Expand Down
63 changes: 46 additions & 17 deletions sync_diff_inspector/source/common/conn.go
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,60 @@ package common

import (
"database/sql"
"regexp"
"encoding/base64"

"github.com/go-sql-driver/mysql"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
tmysql "github.com/pingcap/tidb/parser/mysql"
)

// CreateDB creates sql.DB used for select data
func CreateDB(dsn string, num int) (db *sql.DB, err error) {
db, err = sql.Open("mysql", dsn)
func tryConnectMySQL(cfg *mysql.Config) (*sql.DB, error) {
failpoint.Inject("MustMySQLPassword", func(val failpoint.Value) {
pwd := val.(string)
if cfg.Passwd != pwd {
failpoint.Return(nil, &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"})
}
failpoint.Return(nil, nil)
})
c, err := mysql.NewConnector(cfg)
if err != nil {
return nil, errors.Trace(err)
}
db := sql.OpenDB(c)
if err = db.Ping(); err != nil {
_ = db.Close()
return nil, errors.Trace(err)
}
return db, nil
}

err = db.Ping()
if err != nil {
reg, regErr := regexp.Compile(":.*@tcp")
if reg == nil || regErr != nil {
return nil, errors.Errorf("create db connections (failed to replace password for dsn) error %v", regErr)
// ConnectMySQL creates sql.DB used for select data
func ConnectMySQL(cfg *mysql.Config, num int) (db *sql.DB, err error) {
defer func() {
if err == nil && db != nil {
// SetMaxOpenConns and SetMaxIdleConns for connection to avoid error like
// `dial tcp 10.26.2.1:3306: connect: cannot assign requested address`
db.SetMaxOpenConns(num)
db.SetMaxIdleConns(num)
}
return nil, errors.Errorf("create db connections %s error %v", reg.ReplaceAllString(dsn, ":?@tcp"), err)
}()
// Try plain password first.
db, firstErr := tryConnectMySQL(cfg)
if firstErr == nil {
return db, nil
}

// SetMaxOpenConns and SetMaxIdleConns for connection to avoid error like
// `dial tcp 10.26.2.1:3306: connect: cannot assign requested address`
db.SetMaxOpenConns(num)
db.SetMaxIdleConns(num)

return db, nil
// If access is denied and password is encoded by base64, try the decoded string as well.
if mysqlErr, ok := errors.Cause(firstErr).(*mysql.MySQLError); ok && mysqlErr.Number == tmysql.ErrAccessDenied {
// If password is encoded by base64, try the decoded string as well.
if password, decodeErr := base64.StdEncoding.DecodeString(cfg.Passwd); decodeErr == nil && string(password) != cfg.Passwd {
cfg.Passwd = string(password)
db2, err := tryConnectMySQL(cfg)
if err == nil {
return db2, nil
}
}
}
// If we can't connect successfully, return the first error.
return nil, errors.Trace(firstErr)
}
48 changes: 48 additions & 0 deletions sync_diff_inspector/source/common/conn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package common

import (
"encoding/base64"
"fmt"
"testing"

"github.com/pingcap/failpoint"
"github.com/pingcap/tidb-tools/sync_diff_inspector/config"
"github.com/pingcap/tidb-tools/sync_diff_inspector/utils"
"github.com/stretchr/testify/require"
)

func TestConnect(t *testing.T) {
plainPsw := "dQAUoDiyb1ucWZk7"

require.NoError(t, failpoint.Enable(
"github.com/pingcap/tidb-tools/sync_diff_inspector/source/common/MustMySQLPassword",
fmt.Sprintf("return(\"%s\")", plainPsw)))
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb-tools/sync_diff_inspector/source/common/MustMySQLPassword"))
}()

dataSource := &config.DataSource{
Host: "127.0.0.1",
Port: 4000,
User: "root",
Password: utils.SecretString(plainPsw),
}
_, err := ConnectMySQL(dataSource.ToDriverConfig(), 2)
require.NoError(t, err)
dataSource.Password = utils.SecretString(base64.StdEncoding.EncodeToString([]byte(plainPsw)))
_, err = ConnectMySQL(dataSource.ToDriverConfig(), 2)
require.NoError(t, err)
}
11 changes: 6 additions & 5 deletions sync_diff_inspector/source/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strings"
"time"

"github.com/go-sql-driver/mysql"
"github.com/pingcap/errors"
"github.com/pingcap/log"
"github.com/pingcap/tidb-tools/pkg/dbutil"
Expand Down Expand Up @@ -233,8 +234,8 @@ func buildSourceFromCfg(ctx context.Context, tableDiffs []*common.TableDiff, con
return NewMySQLSources(ctx, tableDiffs, dbs, connCount, f)
}

func getAutoSnapshotPosition(dsn string) (string, string, error) {
tmpConn, err := common.CreateDB(dsn, 2)
func getAutoSnapshotPosition(cfg *mysql.Config) (string, string, error) {
tmpConn, err := common.ConnectMySQL(cfg, 2)
if err != nil {
return "", "", errors.Annotatef(err, "connecting to auto-position tidb_snapshot failed")
}
Expand All @@ -257,7 +258,7 @@ func initDBConn(ctx context.Context, cfg *config.Config) error {
if !cfg.Task.SourceInstances[0].IsAutoSnapshot() {
return errors.Errorf("'auto' snapshot should be set on both target and source")
}
primaryTs, secondaryTs, err := getAutoSnapshotPosition(cfg.Task.TargetInstance.GetDSN())
primaryTs, secondaryTs, err := getAutoSnapshotPosition(cfg.Task.TargetInstance.ToDriverConfig())
if err != nil {
return err
}
Expand All @@ -266,7 +267,7 @@ func initDBConn(ctx context.Context, cfg *config.Config) error {
}
// we had `cfg.SplitThreadCount` producers and `cfg.CheckThreadCount` consumer to use db connections maybe and `cfg.CheckThreadCount` splitter to split buckets.
// so the connection count need to be cfg.SplitThreadCount + cfg.CheckThreadCount + cfg.CheckThreadCount.
targetConn, err := common.CreateDB(cfg.Task.TargetInstance.GetDSN(), cfg.SplitThreadCount+2*cfg.CheckThreadCount)
targetConn, err := common.ConnectMySQL(cfg.Task.TargetInstance.ToDriverConfig(), cfg.SplitThreadCount+2*cfg.CheckThreadCount)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -280,7 +281,7 @@ func initDBConn(ctx context.Context, cfg *config.Config) error {
return errors.Errorf("'auto' snapshot should be set on both target and source")
}
// connect source db with target db time_zone
conn, err := common.CreateDB(source.GetDSN(), cfg.SplitThreadCount+2*cfg.CheckThreadCount)
conn, err := common.ConnectMySQL(source.ToDriverConfig(), cfg.SplitThreadCount+2*cfg.CheckThreadCount)
if err != nil {
return errors.Trace(err)
}
Expand Down

0 comments on commit 7245a6e

Please sign in to comment.