diff --git a/src/main/java/org/lmdbjava/Dbi.java b/src/main/java/org/lmdbjava/Dbi.java index ad1bb5a..4a4cc35 100644 --- a/src/main/java/org/lmdbjava/Dbi.java +++ b/src/main/java/org/lmdbjava/Dbi.java @@ -81,10 +81,10 @@ public final class Dbi { if (nativeCb) { this.ccb = (keyA, keyB) -> { - final T compKeyA = proxy.allocate(); - final T compKeyB = proxy.allocate(); - proxy.out(compKeyA, keyA, keyA.address()); - proxy.out(compKeyB, keyB, keyB.address()); + T compKeyA = proxy.allocate(); + T compKeyB = proxy.allocate(); + compKeyA = proxy.out(compKeyA, keyA, keyA.address()); + compKeyB = proxy.out(compKeyB, keyB, keyB.address()); final int result = this.comparator.compare(compKeyA, compKeyB); proxy.deallocate(compKeyA); proxy.deallocate(compKeyB); diff --git a/src/test/java/org/lmdbjava/ComparatorTest.java b/src/test/java/org/lmdbjava/ComparatorTest.java index 3e265ce..3c7e7a4 100644 --- a/src/test/java/org/lmdbjava/ComparatorTest.java +++ b/src/test/java/org/lmdbjava/ComparatorTest.java @@ -67,11 +67,12 @@ public static Object[] data() { final ComparatorRunner string = new StringRunner(); final ComparatorRunner db = new DirectBufferRunner(); final ComparatorRunner ba = new ByteArrayRunner(); + final ComparatorRunner baUnsigned = new UnsignedByteArrayRunner(); final ComparatorRunner bb = new ByteBufferRunner(); final ComparatorRunner netty = new NettyRunner(); final ComparatorRunner gub = new GuavaUnsignedBytes(); final ComparatorRunner gsb = new GuavaSignedBytes(); - return new Object[] {string, db, ba, bb, netty, gub, gsb}; + return new Object[] {string, db, ba, baUnsigned, bb, netty, gub, gsb}; } private static byte[] buffer(final int... bytes) { @@ -140,6 +141,16 @@ public int compare(final byte[] o1, final byte[] o2) { } } + /** Tests {@link ByteArrayProxy} (unsigned). */ + private static final class UnsignedByteArrayRunner implements ComparatorRunner { + + @Override + public int compare(final byte[] o1, final byte[] o2) { + final Comparator c = PROXY_BA.getUnsignedComparator(); + return c.compare(o1, o2); + } + } + /** Tests {@link ByteBufferProxy}. */ private static final class ByteBufferRunner implements ComparatorRunner { diff --git a/src/test/java/org/lmdbjava/DbiTest.java b/src/test/java/org/lmdbjava/DbiTest.java index 1fa80f6..75c2393 100644 --- a/src/test/java/org/lmdbjava/DbiTest.java +++ b/src/test/java/org/lmdbjava/DbiTest.java @@ -46,9 +46,7 @@ import static org.lmdbjava.KeyRange.atMost; import static org.lmdbjava.PutFlags.MDB_NODUPDATA; import static org.lmdbjava.PutFlags.MDB_NOOVERWRITE; -import static org.lmdbjava.TestUtils.DB_1; -import static org.lmdbjava.TestUtils.ba; -import static org.lmdbjava.TestUtils.bb; +import static org.lmdbjava.TestUtils.*; import java.io.File; import java.io.IOException; @@ -63,8 +61,7 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.BiConsumer; -import org.agrona.concurrent.UnsafeBuffer; +import java.util.function.*; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -82,6 +79,7 @@ public final class DbiTest { @Rule public final TemporaryFolder tmp = new TemporaryFolder(); private Env env; + private Env envBa; @After public void after() { @@ -97,6 +95,13 @@ public void before() throws IOException { .setMaxReaders(2) .setMaxDbs(2) .open(path, MDB_NOSUBDIR); + final File pathBa = tmp.newFile(); + envBa = + create(PROXY_BA) + .setMapSize(MEBIBYTES.toBytes(64)) + .setMaxReaders(2) + .setMaxDbs(2) + .open(pathBa, MDB_NOSUBDIR); } @Test(expected = ConstantDerivedException.class) @@ -117,20 +122,41 @@ public void customComparator() { } return lexical * -1; }; - final Dbi db = env.openDbi(DB_1, reverseOrder, true, MDB_CREATE); - try (Txn txn = env.txnWrite()) { - assertThat(db.put(txn, bb(2), bb(3)), is(true)); - assertThat(db.put(txn, bb(4), bb(6)), is(true)); - assertThat(db.put(txn, bb(6), bb(7)), is(true)); - assertThat(db.put(txn, bb(8), bb(7)), is(true)); + doCustomComparator(env, reverseOrder, TestUtils::bb, ByteBuffer::getInt); + } + + @Test + public void customComparatorByteArray() { + final Comparator reverseOrder = + (o1, o2) -> { + final int lexical = PROXY_BA.getComparator().compare(o1, o2); + if (lexical == 0) { + return 0; + } + return lexical * -1; + }; + doCustomComparator(envBa, reverseOrder, TestUtils::ba, TestUtils::fromBa); + } + + private void doCustomComparator( + Env env, + Comparator comparator, + IntFunction serializer, + ToIntFunction deserializer) { + final Dbi db = env.openDbi(DB_1, comparator, true, MDB_CREATE); + try (Txn txn = env.txnWrite()) { + assertThat(db.put(txn, serializer.apply(2), serializer.apply(3)), is(true)); + assertThat(db.put(txn, serializer.apply(4), serializer.apply(6)), is(true)); + assertThat(db.put(txn, serializer.apply(6), serializer.apply(7)), is(true)); + assertThat(db.put(txn, serializer.apply(8), serializer.apply(7)), is(true)); txn.commit(); } - try (Txn txn = env.txnRead(); - CursorIterable ci = db.iterate(txn, atMost(bb(4)))) { - final Iterator> iter = ci.iterator(); - assertThat(iter.next().key().getInt(), is(8)); - assertThat(iter.next().key().getInt(), is(6)); - assertThat(iter.next().key().getInt(), is(4)); + try (Txn txn = env.txnRead(); + CursorIterable ci = db.iterate(txn, atMost(serializer.apply(4)))) { + final Iterator> iter = ci.iterator(); + assertThat(deserializer.applyAsInt(iter.next().key()), is(8)); + assertThat(deserializer.applyAsInt(iter.next().key()), is(6)); + assertThat(deserializer.applyAsInt(iter.next().key()), is(4)); } } @@ -143,9 +169,24 @@ public void dbOpenMaxDatabases() { @Test public void dbiWithComparatorThreadSafety() { + doDbiWithComparatorThreadSafety( + env, PROXY_OPTIMAL::getComparator, TestUtils::bb, ByteBuffer::getInt); + } + + @Test + public void dbiWithComparatorThreadSafetyByteArray() { + doDbiWithComparatorThreadSafety( + envBa, PROXY_BA::getComparator, TestUtils::ba, TestUtils::fromBa); + } + + public void doDbiWithComparatorThreadSafety( + Env env, + Function> comparator, + IntFunction serializer, + ToIntFunction deserializer) { final DbiFlags[] flags = new DbiFlags[] {MDB_CREATE, MDB_INTEGERKEY}; - final Comparator c = PROXY_OPTIMAL.getComparator(flags); - final Dbi db = env.openDbi(DB_1, c, true, flags); + final Comparator c = comparator.apply(flags); + final Dbi db = env.openDbi(DB_1, c, true, flags); final List keys = range(0, 1_000).boxed().collect(toList()); @@ -155,25 +196,25 @@ public void dbiWithComparatorThreadSafety() { pool.submit( () -> { while (proceed.get()) { - try (Txn txn = env.txnRead()) { - db.get(txn, bb(50)); + try (Txn txn = env.txnRead()) { + db.get(txn, serializer.apply(50)); } } }); for (final Integer key : keys) { - try (Txn txn = env.txnWrite()) { - db.put(txn, bb(key), bb(3)); + try (Txn txn = env.txnWrite()) { + db.put(txn, serializer.apply(key), serializer.apply(3)); txn.commit(); } } - try (Txn txn = env.txnRead(); - CursorIterable ci = db.iterate(txn)) { - final Iterator> iter = ci.iterator(); + try (Txn txn = env.txnRead(); + CursorIterable ci = db.iterate(txn)) { + final Iterator> iter = ci.iterator(); final List result = new ArrayList<>(); while (iter.hasNext()) { - result.add(iter.next().key().getInt()); + result.add(deserializer.applyAsInt(iter.next().key())); } assertThat(result, Matchers.contains(keys.toArray(new Integer[0]))); @@ -339,7 +380,7 @@ public void putCommitGetByteArray() throws IOException { try (Txn txn = envBa.txnWrite()) { final byte[] found = db.get(txn, ba(5)); assertNotNull(found); - assertThat(new UnsafeBuffer(txn.val()).getInt(0), is(5)); + assertThat(fromBa(txn.val()), is(5)); } } } diff --git a/src/test/java/org/lmdbjava/TestUtils.java b/src/test/java/org/lmdbjava/TestUtils.java index 42dcf05..c020326 100644 --- a/src/test/java/org/lmdbjava/TestUtils.java +++ b/src/test/java/org/lmdbjava/TestUtils.java @@ -36,9 +36,13 @@ final class TestUtils { private TestUtils() {} static byte[] ba(final int value) { - final MutableDirectBuffer b = new UnsafeBuffer(new byte[4]); - b.putInt(0, value); - return b.byteArray(); + byte[] bytes = new byte[4]; + ByteBuffer.wrap(bytes).putInt(value); + return bytes; + } + + static int fromBa(final byte[] ba) { + return ByteBuffer.wrap(ba).getInt(); } static ByteBuffer bb(final int value) {