Skip to content

Commit 417fd2c

Browse files
committed
Added method binary_copy_to_table
Signed-off-by: chandr-andr (Kiselev Aleksandr) <chandr@chandr.net>
1 parent 013db21 commit 417fd2c

File tree

11 files changed

+958
-14
lines changed

11 files changed

+958
-14
lines changed

Cargo.lock

Lines changed: 544 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,22 @@ uuid = { version = "1.7.0", features = ["v4"] }
2828
serde_json = "1.0.113"
2929
futures-util = "0.3.30"
3030
macaddr = "1.0.1"
31+
postgres-types = { git = "https://github.com/chandr-andr/rust-postgres.git", branch = "psqlpy", features = [
32+
"derive",
33+
] }
3134
tokio-postgres = { git = "https://github.com/chandr-andr/rust-postgres.git", branch = "psqlpy", features = [
3235
"with-serde_json-1",
3336
"array-impls",
3437
"with-chrono-0_4",
3538
"with-uuid-1",
3639
] }
37-
postgres-types = { git = "https://github.com/chandr-andr/rust-postgres.git", branch = "psqlpy", features = [
38-
"derive",
39-
] }
4040
postgres-protocol = { git = "https://github.com/chandr-andr/rust-postgres.git", branch = "psqlpy" }
4141
postgres-openssl = { git = "https://github.com/chandr-andr/rust-postgres.git", branch = "psqlpy" }
4242
rust_decimal = { git = "https://github.com/chandr-andr/rust-decimal.git", branch = "psqlpy", features = [
4343
"db-postgres",
4444
"db-tokio-postgres",
4545
] }
4646
openssl = { version = "0.10.64", features = ["vendored"] }
47+
pgpq = "0.9.0"
48+
arrow-schema = "52.1.0"
49+
arrow = { version = "52.1.0", features = ["pyarrow"] }

python/psqlpy/_internal/__init__.pyi

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import types
22
from enum import Enum
3+
from io import BytesIO
34
from ipaddress import IPv4Address, IPv6Address
45
from typing import Any, Callable, List, Optional, Sequence, TypeVar, Union
56

6-
from typing_extensions import Self
7+
from typing_extensions import Buffer, Self
78

89
_CustomClass = TypeVar(
910
"_CustomClass",
@@ -809,6 +810,30 @@ class Transaction:
809810
await cursor.close()
810811
```
811812
"""
813+
async def binary_copy_to_table(
814+
self: Self,
815+
source: Union[bytes, bytearray, Buffer, BytesIO],
816+
table_name: str,
817+
columns: Optional[Sequence[str]] = None,
818+
schema_name: Optional[str] = None,
819+
) -> int:
820+
"""Perform binary copy to PostgreSQL.
821+
822+
Execute `COPY table_name (<columns>) FROM STDIN (FORMAT binary)`
823+
and start sending bytes to PostgreSQL.
824+
825+
IMPORTANT! User is responsible for the bytes passed to the database.
826+
If bytes are incorrect user will get error from the database.
827+
828+
### Parameters:
829+
- `source`: source of bytes.
830+
- `table_name`: name of the table.
831+
- `columns`: sequence of str columns.
832+
- `schema_name`: name of the schema.
833+
834+
### Returns:
835+
number of inserted rows;
836+
"""
812837

813838
class Connection:
814839
"""Connection from Database Connection Pool.
@@ -1053,6 +1078,30 @@ class Connection:
10531078
It necessary to commit all transactions and close all cursor
10541079
made by this connection. Otherwise, it won't have any practical usage.
10551080
"""
1081+
async def binary_copy_to_table(
1082+
self: Self,
1083+
source: Union[bytes, bytearray, Buffer, BytesIO],
1084+
table_name: str,
1085+
columns: Optional[Sequence[str]] = None,
1086+
schema_name: Optional[str] = None,
1087+
) -> int:
1088+
"""Perform binary copy to PostgreSQL.
1089+
1090+
Execute `COPY table_name (<columns>) FROM STDIN (FORMAT binary)`
1091+
and start sending bytes to PostgreSQL.
1092+
1093+
IMPORTANT! User is responsible for the bytes passed to the database.
1094+
If bytes are incorrect user will get error from the database.
1095+
1096+
### Parameters:
1097+
- `source`: source of bytes.
1098+
- `table_name`: name of the table.
1099+
- `columns`: sequence of str columns.
1100+
- `schema_name`: name of the schema.
1101+
1102+
### Returns:
1103+
number of inserted rows;
1104+
"""
10561105

10571106
class ConnectionPoolStatus:
10581107
max_size: int

python/tests/test_connection.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from __future__ import annotations
22

3+
import os
34
import typing
5+
from io import BytesIO
46

57
import pytest
8+
from pgpq import ArrowToPostgresBinaryEncoder
9+
from pyarrow import parquet
610
from tests.helpers import count_rows_in_test_table
711

812
from psqlpy import ConnectionPool, Cursor, QueryResult, Transaction
@@ -180,3 +184,55 @@ async def test_closed_connection_error(
180184

181185
with pytest.raises(expected_exception=ConnectionClosedError):
182186
await connection.execute("SELECT 1")
187+
188+
189+
async def test_binary_copy_to_table(
190+
psql_pool: ConnectionPool,
191+
) -> None:
192+
"""Test binary copy in connection."""
193+
table_name: typing.Final = "cars"
194+
await psql_pool.execute(f"DROP TABLE IF EXISTS {table_name}")
195+
await psql_pool.execute(
196+
"""
197+
CREATE TABLE IF NOT EXISTS cars (
198+
model VARCHAR,
199+
mpg FLOAT8,
200+
cyl INTEGER,
201+
disp FLOAT8,
202+
hp INTEGER,
203+
drat FLOAT8,
204+
wt FLOAT8,
205+
qsec FLOAT8,
206+
vs INTEGER,
207+
am INTEGER,
208+
gear INTEGER,
209+
carb INTEGER
210+
);
211+
""",
212+
)
213+
214+
arrow_table = parquet.read_table(
215+
f"{os.path.dirname(os.path.abspath(__file__))}/test_data/MTcars.parquet", # noqa: PTH120, PTH100
216+
)
217+
encoder = ArrowToPostgresBinaryEncoder(arrow_table.schema)
218+
buf = BytesIO()
219+
buf.write(encoder.write_header())
220+
for batch in arrow_table.to_batches():
221+
buf.write(encoder.write_batch(batch))
222+
buf.write(encoder.finish())
223+
buf.seek(0)
224+
225+
async with psql_pool.acquire() as connection:
226+
inserted_rows = await connection.binary_copy_to_table(
227+
source=buf,
228+
table_name=table_name,
229+
)
230+
231+
expected_inserted_row: typing.Final = 32
232+
233+
assert inserted_rows == expected_inserted_row
234+
235+
real_table_rows: typing.Final = await psql_pool.execute(
236+
f"SELECT COUNT(*) AS rows_count FROM {table_name}",
237+
)
238+
assert real_table_rows.result()[0]["rows_count"] == expected_inserted_row

python/tests/test_data/MTcars.parquet

2.86 KB
Binary file not shown.

python/tests/test_transaction.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from __future__ import annotations
22

3+
import os
34
import typing
5+
from io import BytesIO
46

57
import pytest
8+
from pgpq import ArrowToPostgresBinaryEncoder
9+
from pyarrow import parquet
610
from tests.helpers import count_rows_in_test_table
711

812
from psqlpy import ConnectionPool, Cursor, IsolationLevel, ReadVariant
@@ -334,3 +338,55 @@ async def test_transaction_send_underlying_connection_to_pool_manually(
334338
await transaction.commit()
335339
assert not psql_pool.status().available
336340
assert psql_pool.status().available == 1
341+
342+
343+
async def test_binary_copy_to_table(
344+
psql_pool: ConnectionPool,
345+
) -> None:
346+
"""Test binary copy in transaction."""
347+
table_name: typing.Final = "cars"
348+
await psql_pool.execute(f"DROP TABLE IF EXISTS {table_name}")
349+
await psql_pool.execute(
350+
"""
351+
CREATE TABLE IF NOT EXISTS cars (
352+
model VARCHAR,
353+
mpg FLOAT8,
354+
cyl INTEGER,
355+
disp FLOAT8,
356+
hp INTEGER,
357+
drat FLOAT8,
358+
wt FLOAT8,
359+
qsec FLOAT8,
360+
vs INTEGER,
361+
am INTEGER,
362+
gear INTEGER,
363+
carb INTEGER
364+
);
365+
""",
366+
)
367+
368+
arrow_table = parquet.read_table(
369+
f"{os.path.dirname(os.path.abspath(__file__))}/test_data/MTcars.parquet", # noqa: PTH120, PTH100
370+
)
371+
encoder = ArrowToPostgresBinaryEncoder(arrow_table.schema)
372+
buf = BytesIO()
373+
buf.write(encoder.write_header())
374+
for batch in arrow_table.to_batches():
375+
buf.write(encoder.write_batch(batch))
376+
buf.write(encoder.finish())
377+
buf.seek(0)
378+
379+
async with psql_pool.acquire() as connection:
380+
inserted_rows = await connection.binary_copy_to_table(
381+
source=buf,
382+
table_name=table_name,
383+
)
384+
385+
expected_inserted_row: typing.Final = 32
386+
387+
assert inserted_rows == expected_inserted_row
388+
389+
real_table_rows: typing.Final = await psql_pool.execute(
390+
f"SELECT COUNT(*) AS rows_count FROM {table_name}",
391+
)
392+
assert real_table_rows.result()[0]["rows_count"] == expected_inserted_row

src/driver/common_options.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,22 @@ impl KeepaliveConfig {
119119
}
120120
}
121121
}
122+
123+
#[pyclass]
124+
#[derive(Clone, Copy)]
125+
pub enum CopyCommandFormat {
126+
TEXT,
127+
CSV,
128+
BINARY,
129+
}
130+
131+
impl CopyCommandFormat {
132+
#[must_use]
133+
pub fn to_internal(&self) -> String {
134+
match self {
135+
CopyCommandFormat::TEXT => "text".into(),
136+
CopyCommandFormat::CSV => "csv".into(),
137+
CopyCommandFormat::BINARY => "binary".into(),
138+
}
139+
}
140+
}

0 commit comments

Comments
 (0)