diff --git a/client/stmt.go b/client/stmt.go index cd64f3524..3c44c6d50 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -156,6 +156,9 @@ func (s *Stmt) write(args ...interface{}) error { case []byte: paramTypes[i] = []byte{mysql.MYSQL_TYPE_STRING} paramValues[i] = append(mysql.PutLengthEncodedInt(uint64(len(v))), v...) + case mysql.TypedBytes: + paramTypes[i] = []byte{v.Type} + paramValues[i] = append(mysql.PutLengthEncodedInt(uint64(len(v.Bytes))), v.Bytes...) case json.RawMessage: paramTypes[i] = []byte{mysql.MYSQL_TYPE_STRING} paramValues[i] = append(mysql.PutLengthEncodedInt(uint64(len(v))), v...) diff --git a/mysql/typed_bytes.go b/mysql/typed_bytes.go new file mode 100644 index 000000000..109d99cbf --- /dev/null +++ b/mysql/typed_bytes.go @@ -0,0 +1,8 @@ +package mysql + +// TypedBytes preserves the original MySQL type alongside the raw bytes +// for binary protocol parameters that are length-encoded. +type TypedBytes struct { + Type byte // Original MySQL type + Bytes []byte // Raw bytes +} diff --git a/server/stmt.go b/server/stmt.go index 553c9e695..76e86a296 100644 --- a/server/stmt.go +++ b/server/stmt.go @@ -300,7 +300,7 @@ func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) } if !isNull { - args[i] = v + args[i] = mysql.TypedBytes{Type: tp, Bytes: v} continue } else { args[i] = nil diff --git a/server/stmt_test.go b/server/stmt_test.go index 935597f68..63d3d2d35 100644 --- a/server/stmt_test.go +++ b/server/stmt_test.go @@ -97,3 +97,54 @@ func TestStmtPrepareWithPreparedStmt(t *testing.T) { require.NoError(t, err) require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, columnFields[0].Type) } + +func TestBindStmtArgsTypedBytes(t *testing.T) { + testcases := []struct { + name string + paramType byte + paramValue []byte + expectType byte + expectBytes []byte + }{ + { + name: "DATETIME", + paramType: mysql.MYSQL_TYPE_DATETIME, + paramValue: []byte{0x07, 0xe8, 0x07, 0x06, 0x0f, 0x0e, 0x1e, 0x2d}, + expectType: mysql.MYSQL_TYPE_DATETIME, + expectBytes: []byte{0xe8, 0x07, 0x06, 0x0f, 0x0e, 0x1e, 0x2d}, + }, + { + name: "VARCHAR", + paramType: mysql.MYSQL_TYPE_VARCHAR, + paramValue: []byte{0x05, 'h', 'e', 'l', 'l', 'o'}, + expectType: mysql.MYSQL_TYPE_VARCHAR, + expectBytes: []byte("hello"), + }, + { + name: "BLOB", + paramType: mysql.MYSQL_TYPE_BLOB, + paramValue: []byte{0x04, 0x00, 0x01, 0x02, 0x03}, + expectType: mysql.MYSQL_TYPE_BLOB, + expectBytes: []byte{0x00, 0x01, 0x02, 0x03}, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + c := &Conn{} + s := &Stmt{Args: make([]interface{}, 1)} + s.Params = 1 + + nullBitmap := []byte{0x00} + paramTypes := []byte{tc.paramType, 0x00} + + err := c.bindStmtArgs(s, nullBitmap, paramTypes, tc.paramValue) + require.NoError(t, err) + + tv, ok := s.Args[0].(mysql.TypedBytes) + require.True(t, ok, "expected TypedBytes, got %T", s.Args[0]) + require.Equal(t, tc.expectType, tv.Type) + require.Equal(t, tc.expectBytes, tv.Bytes) + }) + } +}