Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bulk copy for decimal types #234

Merged
merged 7 commits into from
Sep 14, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 80 additions & 39 deletions bulkcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@ package mssql

import (
"bytes"
_ "database/sql/driver"
"encoding/binary"
"fmt"
"golang.org/x/net/context" // use the "x/net/context" for backwards compatibility.
"math"
"reflect"
"strings"
"time"

"strconv"

"golang.org/x/net/context" // use the "x/net/context" for backwards compatibility.
)

type MssqlBulk struct {
Expand Down Expand Up @@ -140,7 +142,7 @@ func (b *MssqlBulk) sendBulkCommand() (err error) {
// AddRow immediately writes the row to the destination table.
// The arguments are the row values in the order they were specified.
func (b *MssqlBulk) AddRow(row []interface{}) (err error) {
if b.headerSent == false {
if !b.headerSent {
err = b.sendBulkCommand()
if err != nil {
return
Expand Down Expand Up @@ -372,7 +374,7 @@ func (b *MssqlBulk) makeParam(val DataValue, col columnStruct) (res Param, err e
case int64:
floatvalue = float64(val)
default:
err = fmt.Errorf("mssql: invalid type for float column", val)
err = fmt.Errorf("mssql: invalid type for float column: %s", val)
return
}

Expand All @@ -391,7 +393,7 @@ func (b *MssqlBulk) makeParam(val DataValue, col columnStruct) (res Param, err e
case []byte:
res.buffer = val
default:
err = fmt.Errorf("mssql: invalid type for nvarchar column", val)
err = fmt.Errorf("mssql: invalid type for nvarchar column: %s", val)
return
}
res.ti.Size = len(res.buffer)
Expand All @@ -403,14 +405,14 @@ func (b *MssqlBulk) makeParam(val DataValue, col columnStruct) (res Param, err e
case []byte:
res.buffer = val
default:
err = fmt.Errorf("mssql: invalid type for varchar column", val)
err = fmt.Errorf("mssql: invalid type for varchar column: %s", val)
return
}
res.ti.Size = len(res.buffer)

case typeBit, typeBitN:
if reflect.TypeOf(val).Kind() != reflect.Bool {
err = fmt.Errorf("mssql: invalid type for bit column", val)
err = fmt.Errorf("mssql: invalid type for bit column: %s", val)
return
}
res.ti.TypeId = typeBitN
Expand Down Expand Up @@ -460,7 +462,7 @@ func (b *MssqlBulk) makeParam(val DataValue, col columnStruct) (res Param, err e
res.buffer = buf

default:
err = fmt.Errorf("mssql: invalid type for datetime2 column", val)
err = fmt.Errorf("mssql: invalid type for datetime2 column: %s", val)
return
}
case typeDateN:
Expand All @@ -474,7 +476,7 @@ func (b *MssqlBulk) makeParam(val DataValue, col columnStruct) (res Param, err e
res.buffer[1] = byte(days >> 8)
res.buffer[2] = byte(days >> 16)
default:
err = fmt.Errorf("mssql: invalid type for date column", val)
err = fmt.Errorf("mssql: invalid type for date column: %s", val)
return
}
case typeDateTime, typeDateTimeN, typeDateTim4:
Expand Down Expand Up @@ -511,50 +513,89 @@ func (b *MssqlBulk) makeParam(val DataValue, col columnStruct) (res Param, err e
}

default:
err = fmt.Errorf("mssql: invalid type for datetime column", val)
err = fmt.Errorf("mssql: invalid type for datetime column: %s", val)
}

/*
case typeDecimal, typeDecimalN:
switch val := val.(type) {
case float64:
dec, err := Float64ToDecimal(val)
if err != nil {
return res, err
}
dec.scale = col.ti.Scale
dec.prec = col.ti.Prec
//res.buffer = make([]byte, 3)
res.buffer = dec.Bytes()
res.ti.Size = len(res.buffer)
default:
err = fmt.Errorf("mssql: invalid type for decimal column", val)
return
// case typeMoney, typeMoney4, typeMoneyN:
case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
var value float64
switch v := val.(type) {
case int:
value = float64(v)
case int8:
value = float64(v)
case int16:
value = float64(v)
case int32:
value = float64(v)
case int64:
value = float64(v)
case float32:
value = float64(v)
case float64:
value = v
case string:
if value, err = strconv.ParseFloat(v, 64); err != nil {
return res, fmt.Errorf("bulk: unable to convert string to float: %v", err)
}
case typeMoney, typeMoney4, typeMoneyN:
if col.ti.Size == 4 {
res.ti.Size = 4
res.buffer = make([]byte, 4)
} else if col.ti.Size == 8 {
res.ti.Size = 8
res.buffer = make([]byte, 8)
default:
return res, fmt.Errorf("unknown value for decimal: %#v", v)
}

} else {
err = fmt.Errorf("mssql: invalid size of money column")
}
*/
perc := col.ti.Prec
scale := col.ti.Scale
var dec Decimal
dec, err = Float64ToDecimalScale(value, scale)
if err != nil {
return res, err
}
dec.prec = perc

var length byte
switch {
case perc <= 9:
length = 4
case perc <= 19:
length = 8
case perc <= 28:
length = 12
default:
length = 16
}

buf := make([]byte, length+1)
// first byte length written by typeInfo.writer
res.ti.Size = int(length) + 1
// second byte sign
if value < 0 {
buf[0] = 0
} else {
buf[0] = 1
}

ub := dec.UnscaledBytes()
l := len(ub)
if l > int(length) {
err = fmt.Errorf("decimal out of range: %s", dec)
return res, err
}
// reverse the bytes
for i, j := 1, l-1; j >= 0; i, j = i+1, j-1 {
buf[i] = ub[j]
}
res.buffer = buf
case typeBigVarBin:
switch val := val.(type) {
case []byte:
res.ti.Size = len(val)
res.buffer = val
default:
err = fmt.Errorf("mssql: invalid type for Binary column", val)
err = fmt.Errorf("mssql: invalid type for Binary column: %s", val)
return
}

default:
err = fmt.Errorf("mssql: type %x not implemented!", col.ti.TypeId)
err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId)
}
return

Expand Down
19 changes: 13 additions & 6 deletions bulkcopy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ func TestBulkcopy(t *testing.T) {
{"test_bigint", 9223372036854775807},
{"test_bigintn", nil},
{"test_geom", geom},
//{"test_smallmoney", nil},
//{"test_money", nil},
//{"test_decimal_18_0", nil},
//{"test_decimal_9_2", nil},
//{"test_decimal_18_0", nil},
// {"test_smallmoney", 1234.56},
// {"test_money", 1234.56},
{"test_decimal_18_0", 1234.0001},
{"test_decimal_9_2", 1234.560001},
{"test_decimal_20_0", 1234.0001},
{"test_numeric_30_10", 1234567.1234567},
}

columns := make([]string, len(testValues))
Expand Down Expand Up @@ -116,7 +117,7 @@ func TestBulkcopy(t *testing.T) {
}
for i, c := range testValues {
if !compareValue(container[i], c.val) {
t.Errorf("columns %s : %s != %s\n", c.colname, container[i], c.val)
t.Errorf("columns %s : %s != %v\n", c.colname, container[i], c.val)
}
}
}
Expand All @@ -134,6 +135,11 @@ func compareValue(a interface{}, expected interface{}) bool {
case int64:
return int64(expected) == a
case float64:
if got, ok := a.([]uint8); ok {
var nf sql.NullFloat64
nf.Scan(got)
a = nf.Float64
}
return math.Abs(expected-a.(float64)) < 0.0001
default:
return reflect.DeepEqual(expected, a)
Expand Down Expand Up @@ -178,6 +184,7 @@ func setupTable(conn *sql.DB, tableName string) {
[test_decimal_18_0] [decimal](18, 0) NULL,
[test_decimal_9_2] [decimal](9, 2) NULL,
[test_decimal_20_0] [decimal](20, 0) NULL,
[test_numeric_30_10] [decimal](30, 10) NULL,
CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED
(
[id] ASC
Expand Down
22 changes: 19 additions & 3 deletions decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ func (d Decimal) ToFloat64() float64 {
return val
}

const autoScale = 100

func Float64ToDecimal(f float64) (Decimal, error) {
return Float64ToDecimalScale(f, autoScale)
}

func Float64ToDecimalScale(f float64, scale uint8) (Decimal, error) {
var dec Decimal
if math.IsNaN(f) {
return dec, errors.New("NaN")
Expand All @@ -49,10 +55,10 @@ func Float64ToDecimal(f float64) (Decimal, error) {
}
dec.prec = 20
var integer float64
for dec.scale = 0; dec.scale <= 20; dec.scale++ {
for dec.scale = 0; dec.scale <= scale; dec.scale++ {
integer = f * scaletblflt64[dec.scale]
_, frac := math.Modf(integer)
if frac == 0 {
if frac == 0 && scale == autoScale {
break
}
}
Expand All @@ -73,7 +79,7 @@ func init() {
}
}

func (d Decimal) Bytes() []byte {
func (d Decimal) BigInt() big.Int {
bytes := make([]byte, 16)
binary.BigEndian.PutUint32(bytes[0:4], d.integer[3])
binary.BigEndian.PutUint32(bytes[4:8], d.integer[2])
Expand All @@ -84,9 +90,19 @@ func (d Decimal) Bytes() []byte {
if !d.positive {
x.Neg(&x)
}
return x
}

func (d Decimal) Bytes() []byte {
x := d.BigInt()
return scaleBytes(x.String(), d.scale)
}

func (d Decimal) UnscaledBytes() []byte {
x := d.BigInt()
return x.Bytes()
}

func scaleBytes(s string, scale uint8) []byte {
z := make([]byte, 0, len(s)+1)
if s[0] == '-' || s[0] == '+' {
Expand Down
3 changes: 2 additions & 1 deletion mssql_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ func (c *MssqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.
}

func (c *MssqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
return c.prepareCopyIn(query)
}

return c.prepareContext(ctx, query)
}

Expand Down
16 changes: 9 additions & 7 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import (
"fmt"
"io"
"math"
"reflect"
"strconv"
"time"
"reflect"
)

// fixed-length data types
Expand Down Expand Up @@ -252,7 +252,7 @@ func decodeDateTime(buf []byte) time.Time {
0, 0, secs, ns, time.UTC)
}

func readFixedType(ti *typeInfo, r *tdsBuffer) (interface{}) {
func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} {
r.ReadFull(ti.Buffer)
buf := ti.Buffer
switch ti.TypeId {
Expand Down Expand Up @@ -286,7 +286,7 @@ func readFixedType(ti *typeInfo, r *tdsBuffer) (interface{}) {
panic("shoulnd't get here")
}

func readByteLenType(ti *typeInfo, r *tdsBuffer) (interface{}) {
func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} {
size := r.byte()
if size == 0 {
return nil
Expand Down Expand Up @@ -381,7 +381,7 @@ func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
return
}

func readShortLenType(ti *typeInfo, r *tdsBuffer) (interface{}) {
func readShortLenType(ti *typeInfo, r *tdsBuffer) interface{} {
size := r.uint16()
if size == 0xffff {
return nil
Expand Down Expand Up @@ -424,7 +424,7 @@ func writeShortLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
return
}

func readLongLenType(ti *typeInfo, r *tdsBuffer) (interface{}) {
func readLongLenType(ti *typeInfo, r *tdsBuffer) interface{} {
// information about this format can be found here:
// http://msdn.microsoft.com/en-us/library/dd304783.aspx
// and here:
Expand Down Expand Up @@ -485,7 +485,7 @@ func writeLongLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {

// reads variant value
// http://msdn.microsoft.com/en-us/library/dd303302.aspx
func readVariantType(ti *typeInfo, r *tdsBuffer) (interface{}) {
func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} {
size := r.int32()
if size == 0 {
return nil
Expand Down Expand Up @@ -577,7 +577,7 @@ func readVariantType(ti *typeInfo, r *tdsBuffer) (interface{}) {

// partially length prefixed stream
// http://msdn.microsoft.com/en-us/library/dd340469.aspx
func readPLPType(ti *typeInfo, r *tdsBuffer) (interface{}) {
func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} {
size := r.uint64()
var buf *bytes.Buffer
switch size {
Expand Down Expand Up @@ -1001,6 +1001,8 @@ func makeDecl(ti typeInfo) string {
}
case typeDecimal, typeDecimalN:
return fmt.Sprintf("decimal(%d, %d)", ti.Prec, ti.Scale)
case typeNumeric, typeNumericN:
return fmt.Sprintf("numeric(%d, %d)", ti.Prec, ti.Scale)
case typeMoney4:
return "smallmoney"
case typeMoney:
Expand Down