Skip to content

Commit 1920c42

Browse files
Correct support for booleans in data frames.
1 parent 6dcf2b4 commit 1920c42

File tree

5 files changed

+67
-0
lines changed

5 files changed

+67
-0
lines changed

doc/src/release_notes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ Common Changes
6161
and :meth:`Connection.fetch_df_batches()`:
6262

6363
- Added support for CLOB, BLOB and RAW data types
64+
- Fixed support for BOOLEAN data type
6465
- Fixed bug when NUMBER data is fetched that does not have a precision or
6566
scale specified and :attr:`defaults.fetch_decimals` is set to *True*.
6667
- More efficient processing when a significant amount of data is duplicated

src/oracledb/impl/base/converters.pyx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ cdef int convert_oracle_data_to_arrow(OracleMetadata from_metadata,
232232
arrow_array.append_double(data.buffer.as_double)
233233
elif arrow_type == NANOARROW_TYPE_FLOAT:
234234
arrow_array.append_float(data.buffer.as_float)
235+
elif arrow_type == NANOARROW_TYPE_BOOL:
236+
arrow_array.append_int64(data.buffer.as_bool)
235237
elif arrow_type in (
236238
NANOARROW_TYPE_BINARY,
237239
NANOARROW_TYPE_STRING,

src/oracledb/interchange/nanoarrow_bridge.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ cdef extern from "nanoarrow/nanoarrow.c":
8989
ArrowErrorCode ArrowArrayViewSetArray(ArrowArrayView* array_view,
9090
const ArrowArray* array,
9191
ArrowError* error)
92+
int8_t ArrowBitGet(const uint8_t* bits, int64_t i)
9293
void ArrowSchemaInit(ArrowSchema* schema)
9394
ArrowErrorCode ArrowSchemaInitFromType(ArrowSchema* schema, ArrowType type)
9495
ArrowErrorCode ArrowSchemaSetTypeDateTime(ArrowSchema* schema,
@@ -277,6 +278,7 @@ cdef class OracleArrowArray:
277278
int32_t *as_int32
278279
double *as_double
279280
float *as_float
281+
int8_t as_bool
280282
int64_t index
281283
uint8_t *ptr
282284
void* temp
@@ -295,6 +297,10 @@ cdef class OracleArrowArray:
295297
data_buffer = ArrowArrayBuffer(array.arrow_array, 1)
296298
as_float = <float*> data_buffer.data
297299
self.append_double(as_float[index])
300+
elif array.arrow_type == NANOARROW_TYPE_BOOL:
301+
data_buffer = ArrowArrayBuffer(array.arrow_array, 1)
302+
as_bool = ArrowBitGet(data_buffer.data, index)
303+
self.append_int64(as_bool)
298304
elif array.arrow_type == NANOARROW_TYPE_DECIMAL128:
299305
data_buffer = ArrowArrayBuffer(array.arrow_array, 1)
300306
ArrowDecimalInit(&decimal, 128, self.precision, self.scale)

tests/test_8000_dataframe.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"""
2828
import datetime
2929
import decimal
30+
import unittest
3031

3132
import oracledb
3233

@@ -580,6 +581,36 @@ def test_8025(self):
580581
fetched_data = self.__get_data_from_df(fetched_df)
581582
self.assertEqual(fetched_data, data)
582583

584+
@unittest.skipUnless(
585+
test_env.get_client_version() >= (23, 1), "unsupported client"
586+
)
587+
@unittest.skipUnless(
588+
test_env.get_server_version() >= (23, 1), "unsupported server"
589+
)
590+
def test_8026(self):
591+
"8026 - fetch boolean"
592+
data = [(True,), (False,), (False,), (True,), (True,)]
593+
self.__check_interop()
594+
ora_df = self.conn.fetch_df_all(
595+
"""
596+
select true
597+
union all
598+
select false
599+
union all
600+
select false
601+
union all
602+
select true
603+
union all
604+
select true
605+
"""
606+
)
607+
fetched_tab = pyarrow.Table.from_arrays(
608+
ora_df.column_arrays(), names=ora_df.column_names()
609+
)
610+
fetched_df = fetched_tab.to_pandas()
611+
fetched_data = self.__get_data_from_df(fetched_df)
612+
self.assertEqual(fetched_data, data)
613+
583614

584615
if __name__ == "__main__":
585616
test_env.run_test_cases()

tests/test_8100_dataframe_async.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,33 @@ async def test_8121(self):
536536
fetched_data = self.__get_data_from_df(fetched_df)
537537
self.assertEqual(fetched_data, data)
538538

539+
@unittest.skipUnless(
540+
test_env.get_server_version() >= (23, 1), "unsupported server"
541+
)
542+
async def test_8122(self):
543+
"8122 - fetch boolean"
544+
data = [(True,), (False,), (False,), (True,), (True,)]
545+
self.__check_interop()
546+
ora_df = await self.conn.fetch_df_all(
547+
"""
548+
select true
549+
union all
550+
select false
551+
union all
552+
select false
553+
union all
554+
select true
555+
union all
556+
select true
557+
"""
558+
)
559+
fetched_tab = pyarrow.Table.from_arrays(
560+
ora_df.column_arrays(), names=ora_df.column_names()
561+
)
562+
fetched_df = fetched_tab.to_pandas()
563+
fetched_data = self.__get_data_from_df(fetched_df)
564+
self.assertEqual(fetched_data, data)
565+
539566

540567
if __name__ == "__main__":
541568
test_env.run_test_cases()

0 commit comments

Comments
 (0)