Skip to content

[DB-40533] Add support for new type VECTOR(<dim>,DOUBLE) #182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions pynuodb/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@

__all__ = ['Date', 'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks',
'TimestampFromTicks', 'DateToTicks', 'TimeToTicks',
'TimestampToTicks', 'Binary', 'STRING', 'BINARY', 'NUMBER',
'DATETIME', 'ROWID', 'TypeObjectFromNuodb']
'TimestampToTicks', 'Binary', 'Vector', 'STRING', 'BINARY', 'NUMBER',
'DATETIME', 'ROWID', 'VECTOR_DOUBLE', 'TypeObjectFromNuodb']

import sys
import decimal
from datetime import datetime as Timestamp, date as Date, time as Time
from datetime import timedelta as TimeDelta
from datetime import tzinfo # pylint: disable=unused-import

from pynuodb import protocol

try:
from typing import Tuple, Union # pylint: disable=unused-import
except ImportError:
Expand Down Expand Up @@ -279,10 +281,37 @@ def __cmp__(self, other):
return -1


class Vector(list):
    """A list subclass representing a SQL VECTOR(<dim>, <subtype>) value.

    The dedicated type makes it possible to detect the desired SQL type
    when binding parameters.  Apart from requiring the subtype as the first
    constructor argument, an instance can be used like a regular list.
    """
    DOUBLE = protocol.VECTOR_DOUBLE

    def __init__(self, subtype, *args, **kwargs):
        """Create a vector of the given subtype.

        :param subtype: the vector subtype; only Vector.DOUBLE is supported.
        :param args: forwarded to ``list.__init__`` (typically one iterable
            holding the initial elements; may be omitted for an empty vector).
        :param kwargs: forwarded to ``list.__init__``.
        :raises TypeError: if the subtype is missing or not supported.
        """
        if subtype != Vector.DOUBLE:
            if args:
                raise TypeError("Vector type only supported for subtype DOUBLE")
            # A single non-subtype argument most likely means the caller
            # passed the data (e.g. Vector([1, 2, 3])) and forgot the subtype.
            raise TypeError("Vector needs to be initialized with a subtype like Vector.DOUBLE as"
                            " first argument")

        self.subtype = subtype

        # forward the remaining arguments to the list __init__; with no
        # further arguments this creates an empty vector
        super(Vector, self).__init__(*args, **kwargs)

    def getSubtype(self):
        # type: () -> int
        """Returns the subtype of vector this instance holds data for"""
        return self.subtype


# Type objects exported by this module; the type code reported in
# cursor.description compares equal to one of these.
STRING = TypeObject(str)
BINARY = TypeObject(str)
NUMBER = TypeObject(int, decimal.Decimal)
DATETIME = TypeObject(Timestamp, Date, Time)
# Type object for VECTOR(<dim>, DOUBLE) columns ("vector double" type name)
VECTOR_DOUBLE = TypeObject(list)
ROWID = TypeObject()
NULL = TypeObject(None)

Expand All @@ -309,6 +338,7 @@ def __cmp__(self, other):
"timestamp without time zone": DATETIME,
"timestamp with time zone": DATETIME,
"time without time zone": DATETIME,
"vector double": VECTOR_DOUBLE,
# Old types used by NuoDB <2.0.3
"binarystring": BINARY,
"binaryvaryingstring": BINARY,
Expand Down
71 changes: 71 additions & 0 deletions pynuodb/encodedsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,39 @@ def putScaledCount2(self, value):
self.__output += data
return self

def putVectorDouble(self, value):
    # type: (datatype.Vector) -> EncodedSession
    """Append a Vector with subtype Vector.DOUBLE to the message.

    The value is encoded as the VECTOR type code, the VECTOR_DOUBLE
    subtype, the payload length in bytes in count notation, and then each
    element as an 8-byte double in little endian encoding.

    :type value: datatype.Vector
    :returns: this session, to allow chaining.
    :raises TypeError, ValueError: if an element cannot be converted to
        float; in that case nothing is appended to the message.
    """
    # Encode the payload before touching the output buffer, so that an
    # element that cannot be converted does not leave a truncated VECTOR
    # value in the message.
    payload = b''.join(struct.pack('<d', float(val)) for val in value)

    # length in bytes in count notation, i.e. first
    # number of bytes needed for the length, then the
    # encoded length
    lengthStr = crypt.toByteString(len(payload))

    self.__output.append(protocol.VECTOR)
    # subtype
    self.__output.append(protocol.VECTOR_DOUBLE)
    self.__output.append(len(lengthStr))
    self.__output += lengthStr

    # the actual vector: each value as double in little endian encoding
    self.__output += payload

    return self

def putVector(self, value):
    # type: (datatype.Vector) -> EncodedSession
    """Encode a Vector value into the message, dispatching on its subtype.

    :type value: datatype.Vector
    :raises DataError: if the subtype of the vector is not supported.
    """
    subtype = value.getSubtype()
    if subtype != datatype.Vector.DOUBLE:
        raise DataError("unsupported value for VECTOR subtype: %d" % (subtype))

    return self.putVectorDouble(value)

def putValue(self, value): # pylint: disable=too-many-return-statements
# type: (Any) -> EncodedSession
"""Call the supporting function based on the type of the value."""
Expand Down Expand Up @@ -854,6 +887,11 @@ def putValue(self, value): # pylint: disable=too-many-return-statements
if isinstance(value, bool):
return self.putBoolean(value)

# we don't want to autodetect lists as being VECTOR, so we
# only bind double if it is the explicit type
if isinstance(value, datatype.Vector):
return self.putVector(value)

# I find it pretty bogus that we pass str(value) here: why not value?
return self.putString(str(value))

Expand Down Expand Up @@ -1096,6 +1134,36 @@ def getUUID(self):

raise DataError('Not a UUID')

def getVector(self):
    # type: () -> datatype.Vector
    """Read the next vector off the session.

    :rtype: datatype.Vector
    :raises DataError: if the next value is not a VECTOR, if its subtype
        is unknown, or if the payload size is not a multiple of 8 bytes.
    """
    if self._getTypeCode() != protocol.VECTOR:
        raise DataError('Not a VECTOR')

    subtype = crypt.fromByteString(self._takeBytes(1))
    if subtype != protocol.VECTOR_DOUBLE:
        raise DataError("Unknown VECTOR type: %d" % (subtype))

    # VECTOR(<dim>, DOUBLE): the payload length in bytes is sent in count
    # notation, i.e. one byte holding the size of the length, then the
    # length itself
    lengthBytes = crypt.fromByteString(self._takeBytes(1))
    length = crypt.fromByteString(self._takeBytes(lengthBytes))

    if length % 8 != 0:
        raise DataError("Invalid size for VECTOR DOUBLE data: %d" % (length))

    dimension = length // 8

    # VECTOR DOUBLE stores the data as little endian
    return datatype.Vector(datatype.Vector.DOUBLE,
                           [struct.unpack('<d', self._takeBytes(8))[0]
                            for _ in range(dimension)])

def getScaledCount2(self):
# type: () -> decimal.Decimal
"""Read a scaled and signed decimal from the session.
Expand Down Expand Up @@ -1171,6 +1239,9 @@ def getValue(self):
if code == protocol.UUID:
return self.getUUID()

if code == protocol.VECTOR:
return self.getVector()

if code == protocol.SCALEDCOUNT2:
return self.getScaledCount2()

Expand Down
5 changes: 4 additions & 1 deletion pynuodb/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
BLOBLEN4 = 193
CLOBLEN0 = 194
CLOBLEN4 = 198
SCALEDCOUNT1 = 199
VECTOR = 199
UUID = 200
SCALEDDATELEN0 = 200
SCALEDDATELEN1 = 201
Expand All @@ -66,6 +66,9 @@
DEBUGBARRIER = 240
SCALEDTIMESTAMPNOTZ = 241

# subtypes of the VECTOR type
VECTOR_DOUBLE = 0

# Protocol Messages
FAILURE = 0
OPENDATABASE = 3
Expand Down
12 changes: 11 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,20 @@ def database(ap, db, te):
'user': db[1],
'password': db[2],
'options': {'schema': 'test'}} # type: DATABASE_FIXTURE
system_information = {'effective_version': 0}

try:
while True:
try:
conn = pynuodb.connect(**connect_args)
cursor = conn.cursor()
try:
cursor.execute("select GETEFFECTIVEPLATFORMVERSION() from system.dual")
row = cursor.fetchone()
system_information['effective_version'] = row[0]
finally:
cursor.close()

break
except pynuodb.session.SessionException:
pass
Expand All @@ -315,4 +325,4 @@ def database(ap, db, te):

_log.info("Database %s is available", db[0])

return connect_args
return {'connect_args': connect_args, 'system_information': system_information}
6 changes: 4 additions & 2 deletions tests/nuodb_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,19 @@ class NuoBase(object):
driver = pynuodb # type: Any

connect_args = ()
system_information = ()
host = None

lower_func = 'lower' # For stored procedure test

@pytest.fixture(autouse=True)
def _setup(self, database):
# Preserve the options we'll need to create a connection to the DB
self.connect_args = database
self.connect_args = database['connect_args']
self.system_information = database['system_information']

# Verify the database is up and has a running TE
dbname = database['database']
dbname = self.connect_args['database']
(ret, out) = nuocmd(['--show-json', 'get', 'processes',
'--db-name', dbname], logout=False)
assert ret == 0, "DB not running: %s" % (out)
Expand Down
83 changes: 83 additions & 0 deletions tests/nuodb_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import decimal
import datetime

import pynuodb

from . import nuodb_base
from .mock_tzs import localize

Expand Down Expand Up @@ -126,3 +129,83 @@ def test_null_type(self):
assert len(row) == 1
assert cursor.description[0][1] == null_type
assert row[0] is None

def test_vector_type(self):
    """Round-trip VECTOR(<dim>, DOUBLE) values through the driver.

    Covers result-set metadata, fetching values inserted as string
    literals, binding pynuodb.Vector parameters, and binding a plain list
    (which goes through the implicit string conversion).
    """
    # Imported locally so this test does not depend on a module-level import
    import pytest

    # only activate this test if tested against version 8 or above; report
    # it as skipped rather than silently passing, and do so before opening
    # a connection that would otherwise be left unused
    if self.system_information['effective_version'] < 1835008:
        pytest.skip("VECTOR type not supported by this database version")

    con = self._connect()
    cursor = con.cursor()

    cursor.execute("CREATE TEMPORARY TABLE tmp ("
                   " vec3 VECTOR(3, DOUBLE),"
                   " vec5 VECTOR(5, DOUBLE))")

    cursor.execute("INSERT INTO tmp VALUES ("
                   " '[1.1,2.2,33.33]',"
                   " '[-1,2,-3,4,-5]')")

    cursor.execute("SELECT * FROM tmp")

    # check metadata
    [name, type_code, _, _, precision, scale, _] = cursor.description[0]
    assert name == "VEC3"
    assert type_code == pynuodb.VECTOR_DOUBLE
    assert precision == 3
    assert scale == 0

    [name, type_code, _, _, precision, scale, _] = cursor.description[1]
    assert name == "VEC5"
    assert type_code == pynuodb.VECTOR_DOUBLE
    assert precision == 5
    assert scale == 0

    # check content
    row = cursor.fetchone()
    assert len(row) == 2
    assert row[0] == [1.1, 2.2, 33.33]
    assert row[1] == [-1, 2, -3, 4, -5]
    assert cursor.fetchone() is None

    # check this is actually a Vector type, not just a list
    assert isinstance(row[0], pynuodb.Vector)
    assert row[0].getSubtype() == pynuodb.Vector.DOUBLE
    assert isinstance(row[1], pynuodb.Vector)
    assert row[1].getSubtype() == pynuodb.Vector.DOUBLE

    # check prepared parameters
    parameters = [pynuodb.Vector(pynuodb.Vector.DOUBLE, [11.11, -2.2, 3333.333]),
                  pynuodb.Vector(pynuodb.Vector.DOUBLE, [-1.23, 2.345, -0.34, 4, -5678.9])]
    cursor.execute("TRUNCATE TABLE tmp")
    cursor.execute("INSERT INTO tmp VALUES (?, ?)", parameters)

    cursor.execute("SELECT * FROM tmp")

    # check content
    row = cursor.fetchone()
    assert len(row) == 2
    assert row[0] == parameters[0]
    assert row[1] == parameters[1]
    assert cursor.fetchone() is None

    # check that the inserted values are interpreted correctly by the database
    cursor.execute("SELECT CAST(vec3 AS STRING) || ' - ' || CAST(vec5 AS STRING) AS strRep"
                   " FROM tmp")

    row = cursor.fetchone()
    assert len(row) == 1
    assert row[0] == "[11.11,-2.2,3333.333] - [-1.23,2.345,-0.34,4,-5678.9]"
    assert cursor.fetchone() is None

    # currently binding a list also works - this is done via implicit string
    # conversion of the passed argument in default bind case
    parameters = [[11.11, -2.2, 3333.333]]
    cursor.execute("SELECT VEC3 = ? FROM tmp", parameters)

    # check content
    row = cursor.fetchone()
    assert len(row) == 1
    assert row[0] is True
    assert cursor.fetchone() is None