diff --git a/main/filesystem-invariants-tests/src/test/java/org/cryptomator/filesystem/invariants/ConcurrencyTests.java b/main/filesystem-invariants-tests/src/test/java/org/cryptomator/filesystem/invariants/ConcurrencyTests.java index f1bb9282b..4389c6fa5 100644 --- a/main/filesystem-invariants-tests/src/test/java/org/cryptomator/filesystem/invariants/ConcurrencyTests.java +++ b/main/filesystem-invariants-tests/src/test/java/org/cryptomator/filesystem/invariants/ConcurrencyTests.java @@ -2,18 +2,14 @@ package org.cryptomator.filesystem.invariants; import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertTrue; import static org.junit.Assume.assumeThat; import java.nio.ByteBuffer; -import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.SynchronousQueue; +import org.cryptomator.common.RunnableThrowingException; import org.cryptomator.filesystem.File; import org.cryptomator.filesystem.FileSystem; import org.cryptomator.filesystem.ReadableFile; @@ -45,7 +41,7 @@ public class ConcurrencyTests { public final ExpectedException thrown = ExpectedException.none(); @Theory - public void testConcurrentPartialReadsDontInterfere(FileSystemFactory fileSystemFactory, WayToObtainAFile wayToObtainAnExistingFile) throws InterruptedException, ExecutionException { + public void testConcurrentPartialReadsDontInterfere(FileSystemFactory fileSystemFactory, WayToObtainAFile wayToObtainAnExistingFile) throws ExecutionException { assumeThat(wayToObtainAnExistingFile.returnedFilesExist(), is(true)); FileSystem fileSystem = fileSystemFactory.create(); @@ -54,51 +50,87 @@ public class ConcurrencyTests { byte[] expectedData2 = new byte[] {44, 1, -3, 4}; File file = wayToObtainAnExistingFile.fileWithNameAndContent(fileSystem, FILE_NAME, originalData); - // control flag to make sure thread timing is synchronized correctly - AtomicInteger state = new AtomicInteger(); + TasksInThreadRunner thread1 = new TasksInThreadRunner(); + TasksInThreadRunner thread2 = new TasksInThreadRunner(); - // set position, then wait before read: + Holder readableFile1 = new Holder<>(); + Holder readableFile2 = new Holder<>(); byte[] actualData1 = new byte[3]; - Callable readTask1 = () -> { - try (ReadableFile readable = file.openReadable()) { - ByteBuffer buf = ByteBuffer.wrap(actualData1); - readable.position(3); - assertTrue("readTask1 must be the first to set its position", state.compareAndSet(0, 1)); - Thread.sleep(20); - assertTrue("readTask1 must be the last to actually read data", state.compareAndSet(3, 4)); - readable.read(buf); - return null; - } - }; - - // wait, then set position and read: byte[] actualData2 = new byte[4]; - Callable readTask2 = () -> { - try (ReadableFile readable = file.openReadable()) { - ByteBuffer buf = ByteBuffer.wrap(actualData2); - Thread.sleep(10); - assertTrue("readTask2 must be second to set its position", state.compareAndSet(1, 2)); - readable.position(1); - readable.read(buf); - assertTrue("readTask2 must be first to finish reading", state.compareAndSet(2, 3)); - return null; - } - }; - // start both read tasks at the same time: - ThreadPoolExecutor executor = new ThreadPoolExecutor(2, 2, 0, TimeUnit.SECONDS, new LinkedBlockingQueue()); - executor.prestartAllCoreThreads(); - Future task1Completed = executor.submit(readTask1); - Future task2Completed = executor.submit(readTask2); - task1Completed.get(); - task2Completed.get(); - executor.shutdown(); + thread1.runAndWaitFor(() -> readableFile1.value = file.openReadable()); + thread2.runAndWaitFor(() -> readableFile2.value = file.openReadable()); + thread1.runAndWaitFor(() -> readableFile1.value.position(3)); + thread2.runAndWaitFor(() -> { + readableFile2.value.position(1); + readableFile2.value.read(ByteBuffer.wrap(actualData2)); + }); + thread1.runAndWaitFor(() -> readableFile1.value.read(ByteBuffer.wrap(actualData1))); + thread1.runAndWaitFor(readableFile1.value::close); + thread2.runAndWaitFor(readableFile2.value::close); + + thread1.shutdown(); + thread2.shutdown(); assertArrayEquals(expectedData1, actualData1); assertArrayEquals(expectedData2, actualData2); } - private static class TaskRunningThread extends Thread { + private static class Holder { + + T value; + + } + + private static class TasksInThreadRunner { + + private final Runnable TERMINATION_HINT = () -> { + }; + + private final SynchronousQueue handoverQueue = new SynchronousQueue<>(); + private final Thread thread = new Thread(() -> { + Runnable task; + while (true) { + try { + task = handoverQueue.take(); + } catch (InterruptedException e) { + return; + } + if (task == TERMINATION_HINT) { + break; + } + task.run(); + } + }); + + public TasksInThreadRunner() { + thread.start(); + } + + public void runAndWaitFor(RunnableThrowingException task) throws ExecutionException { + CompletableFuture future = new CompletableFuture<>(); + try { + handoverQueue.put(() -> { + try { + task.run(); + future.complete(null); + } catch (Throwable e) { + future.completeExceptionally(e); + } + }); + future.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + public void shutdown() { + try { + handoverQueue.put(TERMINATION_HINT); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } }