Skip to content

Commit 3463a8d

Browse files
committed
Parse DATETIME and DATE to time.Time
Issue #9
1 parent dc34e78 commit 3463a8d

File tree

6 files changed

+116
-90
lines changed

6 files changed

+116
-90
lines changed

connection.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"errors"
1515
"net"
1616
"strings"
17+
"time"
1718
)
1819

1920
type mysqlConn struct {
@@ -36,6 +37,7 @@ type config struct {
3637
addr string
3738
dbname string
3839
params map[string]string
40+
loc *time.Location
3941
}
4042

4143
// Handles parameters set in DSN
@@ -52,8 +54,8 @@ func (mc *mysqlConn) handleParams() (err error) {
5254
}
5355
}
5456

55-
// Timeout - already handled on connecting
56-
case "timeout":
57+
// Timeout or Location- already handled on connecting
58+
case "timeout", "loc":
5759
continue
5860

5961
// TLS-Encryption

driver.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) {
2525

2626
// New mysqlConn
2727
mc := new(mysqlConn)
28-
mc.cfg = parseDSN(dsn)
28+
mc.cfg, err = parseDSN(dsn)
29+
if err != nil {
30+
return nil, err
31+
}
2932

3033
// Connect to Server
3134
if _, ok := mc.cfg.params["timeout"]; ok { // with timeout

driver_test.go

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"os"
88
"sync"
99
"testing"
10+
"time"
1011
)
1112

1213
var (
@@ -351,26 +352,45 @@ func TestDateTime(t *testing.T) {
351352

352353
types := [...]string{"DATE", "DATETIME"}
353354
in := [...]string{"2012-06-14", "2011-11-20 21:27:37"}
354-
var out string
355+
var tIn, out time.Time
355356
var rows *sql.Rows
356357

357358
for i, v := range types {
358-
mustExec(t, db, "CREATE TABLE test (value "+v+") CHARACTER SET utf8 COLLATE utf8_unicode_ci")
359+
tIn, err = time.Parse(timeFormat[:len(in[i])], in[i])
360+
if err != nil {
361+
t.Error(err.Error())
362+
}
359363

360-
mustExec(t, db, ("INSERT INTO test VALUES (?)"), in[i])
364+
mustExec(t, db, "CREATE TABLE test (value "+v+")")
365+
mustExec(t, db, ("INSERT INTO test VALUES (?)"), tIn)
361366

362367
rows = mustQuery(t, db, ("SELECT value FROM test"))
363368
if rows.Next() {
364369
rows.Scan(&out)
365-
if in[i] != out {
366-
t.Errorf("%s: %s != %s", v, in[i], out)
370+
if tIn.String() != out.String() {
371+
t.Errorf("%s: %s != %s", v, tIn.String(), out.String())
372+
}
373+
if out.IsZero() {
374+
t.Errorf("%s: Unexpected Zero Time", v)
367375
}
368376
} else {
369377
t.Errorf("%s: no data", v)
370378
}
371379

372380
mustExec(t, db, "DROP TABLE IF EXISTS test")
373381
}
382+
383+
// Zero Time
384+
mustExec(t, db, "CREATE TABLE test (value DATETIME)")
385+
mustExec(t, db, ("INSERT INTO test VALUES (?)"), "0000-00-00 00:00:00")
386+
err = db.QueryRow("SELECT value FROM test").Scan(&out)
387+
if err != nil {
388+
t.Error(err.Error())
389+
}
390+
if !out.IsZero() {
391+
t.Error("Default Time not zero")
392+
}
393+
mustExec(t, db, "DROP TABLE IF EXISTS test")
374394
}
375395

376396
func TestNULL(t *testing.T) {

packets.go

Lines changed: 58 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,20 @@ func (rows *mysqlRows) readRow(dest []driver.Value) (err error) {
491491
pos += n
492492
if err == nil {
493493
if !isNull {
494-
continue
494+
switch rows.columns[i].fieldType {
495+
case fieldTypeTimestamp, fieldTypeDateTime:
496+
dest[i], err = time.Parse(timeFormat, string(dest[i].([]byte)))
497+
if err == nil {
498+
continue
499+
}
500+
case fieldTypeDate, fieldTypeNewDate:
501+
dest[i], err = time.Parse(timeFormat[:10], string(dest[i].([]byte)))
502+
if err == nil {
503+
continue
504+
}
505+
default:
506+
continue
507+
}
495508
} else {
496509
dest[i] = nil
497510
continue
@@ -802,8 +815,9 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
802815
}
803816
return // err
804817

805-
// Date YYYY-MM-DD
806-
case fieldTypeDate, fieldTypeNewDate:
818+
// Date YYYY-MM-DD, Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
819+
case fieldTypeDate, fieldTypeNewDate,
820+
fieldTypeTimestamp, fieldTypeDateTime:
807821
var num uint64
808822
var isNull bool
809823
num, isNull, n = readLengthEncodedInteger(data[pos:])
@@ -815,16 +829,50 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
815829
dest[i] = nil
816830
continue
817831
} else {
818-
dest[i] = []byte("0000-00-00")
832+
dest[i] = time.Time{}
819833
continue
820834
}
821-
} else {
822-
dest[i] = []byte(fmt.Sprintf("%04d-%02d-%02d",
823-
binary.LittleEndian.Uint16(data[pos:pos+2]),
824-
data[pos+2],
825-
data[pos+3]))
826-
pos += int(num)
835+
}
836+
837+
switch num {
838+
case 4:
839+
dest[i] = time.Date(
840+
int(binary.LittleEndian.Uint16(data[pos:pos+2])), // year
841+
time.Month(data[pos+2]), // month
842+
int(data[pos+3]), // day
843+
0, 0, 0, 0,
844+
rc.mc.cfg.loc,
845+
)
846+
pos += 4
827847
continue
848+
case 7:
849+
dest[i] = time.Date(
850+
int(binary.LittleEndian.Uint16(data[pos:pos+2])), // year
851+
time.Month(data[pos+2]), // month
852+
int(data[pos+3]), // day
853+
int(data[pos+4]), // hour
854+
int(data[pos+5]), // minutes
855+
int(data[pos+6]), // seconds
856+
0,
857+
rc.mc.cfg.loc,
858+
)
859+
pos += 7
860+
continue
861+
case 11:
862+
dest[i] = time.Date(
863+
int(binary.LittleEndian.Uint16(data[pos:pos+2])), // year
864+
time.Month(data[pos+2]), // month
865+
int(data[pos+3]), // day
866+
int(data[pos+4]), // hour
867+
int(data[pos+5]), // minutes
868+
int(data[pos+6]), // seconds
869+
int(binary.LittleEndian.Uint32(data[pos+7:pos+11]))*1000, // nanoseconds
870+
rc.mc.cfg.loc,
871+
)
872+
pos += 11
873+
continue
874+
default:
875+
return fmt.Errorf("Invalid DATETIME-packet length %d", num)
828876
}
829877

830878
// Time [-][H]HH:MM:SS[.fractal]
@@ -876,63 +924,6 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
876924
return fmt.Errorf("Invalid TIME-packet length %d", num)
877925
}
878926

879-
// Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
880-
case fieldTypeTimestamp, fieldTypeDateTime:
881-
var num uint64
882-
var isNull bool
883-
num, isNull, n = readLengthEncodedInteger(data[pos:])
884-
885-
pos += n
886-
887-
if num == 0 {
888-
if isNull {
889-
dest[i] = nil
890-
continue
891-
} else {
892-
dest[i] = []byte("0000-00-00 00:00:00")
893-
continue
894-
}
895-
}
896-
897-
switch num {
898-
case 4:
899-
dest[i] = []byte(fmt.Sprintf(
900-
"%04d-%02d-%02d 00:00:00",
901-
binary.LittleEndian.Uint16(data[pos:pos+2]),
902-
data[pos+2],
903-
data[pos+3],
904-
))
905-
pos += 4
906-
continue
907-
case 7:
908-
dest[i] = []byte(fmt.Sprintf(
909-
"%04d-%02d-%02d %02d:%02d:%02d",
910-
binary.LittleEndian.Uint16(data[pos:pos+2]),
911-
data[pos+2],
912-
data[pos+3],
913-
data[pos+4],
914-
data[pos+5],
915-
data[pos+6],
916-
))
917-
pos += 7
918-
continue
919-
case 11:
920-
dest[i] = []byte(fmt.Sprintf(
921-
"%04d-%02d-%02d %02d:%02d:%02d.%06d",
922-
binary.LittleEndian.Uint16(data[pos:pos+2]),
923-
data[pos+2],
924-
data[pos+3],
925-
data[pos+4],
926-
data[pos+5],
927-
data[pos+6],
928-
binary.LittleEndian.Uint32(data[pos+7:pos+11]),
929-
))
930-
pos += 11
931-
continue
932-
default:
933-
return fmt.Errorf("Invalid DATETIME-packet length %d", num)
934-
}
935-
936927
// Please report if this happens!
937928
default:
938929
return fmt.Errorf("Unknown FieldType %d", rc.columns[i].fieldType)

utils.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"os"
1717
"regexp"
1818
"strings"
19+
"time"
1920
)
2021

2122
// Logger
@@ -36,8 +37,8 @@ func init() {
3637
// Data Source Name Parser
3738
var dsnPattern *regexp.Regexp
3839

39-
func parseDSN(dsn string) *config {
40-
cfg := new(config)
40+
func parseDSN(dsn string) (cfg *config, err error) {
41+
cfg = new(config)
4142
cfg.params = make(map[string]string)
4243

4344
matches := dsnPattern.FindStringSubmatch(dsn)
@@ -76,7 +77,9 @@ func parseDSN(dsn string) *config {
7677
cfg.addr = "127.0.0.1:3306"
7778
}
7879

79-
return cfg
80+
cfg.loc, err = time.LoadLocation(cfg.params["loc"])
81+
82+
return
8083
}
8184

8285
// Encrypt password using 4.1+ method

utils_test.go

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,39 @@ package mysql
1212
import (
1313
"fmt"
1414
"testing"
15+
"time"
1516
)
1617

1718
var testDSNs = []struct {
1819
in string
1920
out string
21+
loc *time.Location
2022
}{
21-
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value]}"},
22-
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8]}"},
23-
{"user:password@tcp(localhost:5555)/dbname?charset=utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8]}"},
24-
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8]}"},
25-
{"user:password@/dbname", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[]}"},
26-
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[]}"},
27-
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[]}"},
28-
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[]}"},
29-
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[]}"},
23+
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p}", time.UTC},
24+
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p}", time.UTC},
25+
{"user:password@tcp(localhost:5555)/dbname?charset=utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p}", time.UTC},
26+
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p}", time.UTC},
27+
{"user:password@/dbname?loc=UTC", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[loc:UTC] loc:%p}", time.UTC},
28+
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[loc:Local] loc:%p}", time.Local},
29+
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p}", time.UTC},
30+
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p}", time.UTC},
31+
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p}", time.UTC},
3032
}
3133

3234
func TestDSNParser(t *testing.T) {
3335
var cfg *config
36+
var err error
3437
var res string
3538

3639
for i, tst := range testDSNs {
37-
cfg = parseDSN(tst.in)
40+
cfg, err = parseDSN(tst.in)
41+
if err != nil {
42+
t.Error(err.Error())
43+
}
44+
3845
res = fmt.Sprintf("%+v", cfg)
39-
if res != tst.out {
40-
t.Errorf("%d. parseDSN(%q) => %q, want %q", i, tst.in, res, tst.out)
46+
if res != fmt.Sprintf(tst.out, tst.loc) {
47+
t.Errorf("%d. parseDSN(%q) => %q, want %q", i, tst.in, res, fmt.Sprintf(tst.out, tst.loc))
4148
}
4249
}
4350
}

0 commit comments

Comments
 (0)