diff --git a/src/main/java/net/helenus/core/PostCommitFunction.java b/src/main/java/net/helenus/core/PostCommitFunction.java index 0c823cf..a7e152c 100644 --- a/src/main/java/net/helenus/core/PostCommitFunction.java +++ b/src/main/java/net/helenus/core/PostCommitFunction.java @@ -4,18 +4,17 @@ import java.util.List; import java.util.Objects; public class PostCommitFunction implements java.util.function.Function { + public static final PostCommitFunction NULL_ABORT = new PostCommitFunction<>(null, null, false); + public static final PostCommitFunction NULL_COMMIT = new PostCommitFunction<>(null, null, true); - private final UnitOfWork uow; private final List commitThunks; private final List abortThunks; private boolean committed; PostCommitFunction( - UnitOfWork uow, List postCommit, List abortThunks, boolean committed) { - this.uow = uow; this.commitThunks = postCommit; this.abortThunks = abortThunks; this.committed = committed; diff --git a/src/main/java/net/helenus/core/UnitOfWork.java b/src/main/java/net/helenus/core/UnitOfWork.java index 454a318..dfcdf37 100644 --- a/src/main/java/net/helenus/core/UnitOfWork.java +++ b/src/main/java/net/helenus/core/UnitOfWork.java @@ -22,17 +22,32 @@ import com.google.common.collect.HashBasedTable; import com.google.common.collect.Table; import com.google.common.collect.TreeTraverser; import java.io.Serializable; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import java.util.stream.Collectors; import javax.cache.Cache; import javax.cache.CacheManager; +import javax.cache.configuration.CacheEntryListenerConfiguration; +import javax.cache.configuration.Configuration; import javax.cache.integration.CacheLoader; import javax.cache.integration.CacheLoaderException; +import javax.cache.integration.CompletionListener; +import javax.cache.processor.EntryProcessor; +import javax.cache.processor.EntryProcessorException; +import javax.cache.processor.EntryProcessorResult; + import net.helenus.core.cache.CacheUtil; import net.helenus.core.cache.Facet; import net.helenus.core.cache.MapCache; @@ -48,14 +63,12 @@ import org.slf4j.LoggerFactory; /** Encapsulates the concept of a "transaction" as a unit-of-work. */ public class UnitOfWork implements AutoCloseable { - private static final Logger LOG = LoggerFactory.getLogger(UnitOfWork.class); - private static final Pattern classNameRegex = Pattern.compile("^(?:\\w+\\.)+(?:(\\w+)|(\\w+)\\$.*)$"); public final UnitOfWork parent; private final List nested = new ArrayList<>(); private final Table>> cache = HashBasedTable.create(); - private final MapCache statementCache; + private final EvictTrackingMapCache statementCache; protected final HelenusSession session; protected String purpose; protected List nestedPurposes = new ArrayList(); @@ -65,7 +78,7 @@ public class UnitOfWork implements AutoCloseable { protected int databaseLookups = 0; protected final Stopwatch elapsedTime; protected Map databaseTime = new HashMap<>(); - protected double cacheLookupTimeMSecs = 0.0; + protected double cacheLookupTimeMSecs = 0.0d; private List commitThunks = new ArrayList(); private List abortThunks = new ArrayList(); private List> asyncOperationFutures = new ArrayList>(); @@ -74,17 +87,6 @@ public class UnitOfWork implements AutoCloseable { private long committedAt = 0L; private BatchOperation batch; - private String extractClassNameFromStackFrame(String classNameOnStack) { - String name = null; - Matcher m = classNameRegex.matcher(classNameOnStack); - if (m.find()) { - name = (m.group(1) != null) ? m.group(1) : ((m.group(2) != null) ? m.group(2) : name); - } else { - name = classNameOnStack; - } - return name; - } - public UnitOfWork(HelenusSession session) { this(session, null); } @@ -97,7 +99,7 @@ public class UnitOfWork implements AutoCloseable { parent.addNestedUnitOfWork(this); } this.session = session; - CacheLoader cacheLoader = null; + CacheLoader cacheLoader = null; if (parent != null) { cacheLoader = new CacheLoader() { @@ -121,7 +123,7 @@ public class UnitOfWork implements AutoCloseable { }; } this.elapsedTime = Stopwatch.createUnstarted(); - this.statementCache = new MapCache(null, "UOW(" + hashCode() + ")", cacheLoader, true); + this.statementCache = new EvictTrackingMapCache(null, "UOW(" + hashCode() + ")", cacheLoader, true); } public void addDatabaseTime(String name, Stopwatch amount) { @@ -363,7 +365,7 @@ public class UnitOfWork implements AutoCloseable { throws HelenusException, TimeoutException { if (isDone()) { - return new PostCommitFunction(this, null, null, false); + return PostCommitFunction.NULL_ABORT; } // Only the outer-most UOW batches statements for commit time, execute them. @@ -399,7 +401,7 @@ public class UnitOfWork implements AutoCloseable { } } - return new PostCommitFunction(this, null, null, false); + return PostCommitFunction.NULL_ABORT; } else { committed = true; aborted = false; @@ -453,11 +455,11 @@ public class UnitOfWork implements AutoCloseable { LOG.info(logTimers("committed")); } - return new PostCommitFunction(this, null, null, true); + return PostCommitFunction.NULL_COMMIT; } else { - // Merge cache and statistics into parent if there is one. parent.statementCache.putAll(statementCache.unwrap(Map.class)); + parent.statementCache.removeAll(statementCache.getDeletions()); parent.mergeCache(cache); parent.addBatched(batch); if (purpose != null) { @@ -483,15 +485,15 @@ public class UnitOfWork implements AutoCloseable { // Constructor ctor = clazz.getConstructor(conflictExceptionClass); // T object = ctor.newInstance(new Object[] { String message }); // } - return new PostCommitFunction(this, commitThunks, abortThunks, true); + return new PostCommitFunction(commitThunks, abortThunks, true); } - private void addBatched(BatchOperation batch) { - if (batch != null) { + private void addBatched(BatchOperation batchArg) { + if (batchArg != null) { if (this.batch == null) { - this.batch = batch; + this.batch = batchArg; } else { - this.batch.addAll(batch); + this.batch.addAll(batchArg); } } } @@ -582,4 +584,221 @@ public class UnitOfWork implements AutoCloseable { public long committedAt() { return committedAt; } + + private static class EvictTrackingMapCache implements Cache { + private final Set deletes; + private final Cache delegate; + + public EvictTrackingMapCache(CacheManager manager, String name, CacheLoader cacheLoader, + boolean isReadThrough) { + deletes = Collections.synchronizedSet(new HashSet<>()); + delegate = new MapCache<>(manager, name, cacheLoader, isReadThrough); + } + + /** Non-interface method; should only be called by UnitOfWork when merging to an enclosing UnitOfWork. */ + public Set getDeletions() { + return new HashSet<>(deletes); + } + + @Override + public V get(K key) { + if (deletes.contains(key)) { + return null; + } + + return delegate.get(key); + } + + @Override + public Map getAll(Set keys) { + Set clonedKeys = new HashSet<>(keys); + clonedKeys.removeAll(deletes); + return delegate.getAll(clonedKeys); + } + + @Override + public boolean containsKey(K key) { + if (deletes.contains(key)) { + return false; + } + + return delegate.containsKey(key); + } + + @Override + public void loadAll(Set keys, boolean replaceExistingValues, CompletionListener listener) { + Set clonedKeys = new HashSet<>(keys); + clonedKeys.removeAll(deletes); + delegate.loadAll(clonedKeys, replaceExistingValues, listener); + } + + @Override + public void put(K key, V value) { + if (deletes.contains(key)) { + deletes.remove(key); + } + + delegate.put(key, value); + } + + @Override + public V getAndPut(K key, V value) { + if (deletes.contains(key)) { + deletes.remove(key); + } + + return delegate.getAndPut(key, value); + } + + @Override + public void putAll(Map map) { + deletes.removeAll(map.keySet()); + delegate.putAll(map); + } + + @Override + public synchronized boolean putIfAbsent(K key, V value) { + if (!delegate.containsKey(key) && deletes.contains(key)) { + deletes.remove(key); + } + + return delegate.putIfAbsent(key, value); + } + + @Override + public boolean remove(K key) { + boolean removed = delegate.remove(key); + deletes.add(key); + return removed; + } + + @Override + public boolean remove(K key, V value) { + boolean removed = delegate.remove(key, value); + if (removed) { + deletes.add(key); + } + + return removed; + } + + @Override + public V getAndRemove(K key) { + V value = delegate.getAndRemove(key); + deletes.add(key); + return value; + } + + @Override + public void removeAll(Set keys) { + Set cloneKeys = new HashSet<>(keys); + delegate.removeAll(cloneKeys); + deletes.addAll(cloneKeys); + } + + @Override + @SuppressWarnings("unchecked") + public synchronized void removeAll() { + Map impl = delegate.unwrap(Map.class); + Set keys = impl.keySet(); + delegate.removeAll(); + deletes.addAll(keys); + } + + @Override + public void clear() { + delegate.clear(); + deletes.clear(); + } + + @Override + public boolean replace(K key, V oldValue, V newValue) { + if (deletes.contains(key)) { + return false; + } + + return delegate.replace(key, oldValue, newValue); + } + + @Override + public boolean replace(K key, V value) { + if (deletes.contains(key)) { + return false; + } + + return delegate.replace(key, value); + } + + @Override + public V getAndReplace(K key, V value) { + if (deletes.contains(key)) { + return null; + } + + return delegate.getAndReplace(key, value); + } + + @Override + public > C getConfiguration(Class clazz) { + return delegate.getConfiguration(clazz); + } + + @Override + public T invoke(K key, EntryProcessor processor, Object... arguments) + throws EntryProcessorException { + if (deletes.contains(key)) { + return null; + } + + return delegate.invoke(key, processor, arguments); + } + + @Override + public Map> invokeAll(Set keys, EntryProcessor processor, + Object... arguments) { + Set clonedKeys = new HashSet<>(keys); + clonedKeys.removeAll(deletes); + return delegate.invokeAll(clonedKeys, processor, arguments); + } + + @Override + public String getName() { + return delegate.getName(); + } + + @Override + public CacheManager getCacheManager() { + return delegate.getCacheManager(); + } + + @Override + public void close() { + delegate.close(); + } + + @Override + public boolean isClosed() { + return delegate.isClosed(); + } + + @Override + public T unwrap(Class clazz) { + return delegate.unwrap(clazz); + } + + @Override + public void registerCacheEntryListener(CacheEntryListenerConfiguration cacheEntryListenerConfiguration) { + delegate.registerCacheEntryListener(cacheEntryListenerConfiguration); + } + + @Override + public void deregisterCacheEntryListener(CacheEntryListenerConfiguration cacheEntryListenerConfiguration) { + delegate.deregisterCacheEntryListener(cacheEntryListenerConfiguration); + } + + @Override + public Iterator> iterator() { + return delegate.iterator(); + } + } } diff --git a/src/main/java/net/helenus/core/cache/MapCache.java b/src/main/java/net/helenus/core/cache/MapCache.java index 2cef0b6..72c0326 100644 --- a/src/main/java/net/helenus/core/cache/MapCache.java +++ b/src/main/java/net/helenus/core/cache/MapCache.java @@ -1,6 +1,5 @@ package net.helenus.core.cache; - import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -22,13 +21,13 @@ import javax.cache.processor.MutableEntry; public class MapCache implements Cache { private final CacheManager manager; private final String name; - private Map map = new ConcurrentHashMap(); - private Set cacheEntryRemovedListeners = new HashSet<>(); + private Map map = new ConcurrentHashMap<>(); + private Set> cacheEntryRemovedListeners = new HashSet<>(); private CacheLoader cacheLoader = null; private boolean isReadThrough = false; - private Configuration configuration = new MapConfiguration(); private static class MapConfiguration implements Configuration { + private static final long serialVersionUID = 6093947542772516209L; @Override public Class getKeyType() { @@ -54,6 +53,7 @@ public class MapCache implements Cache { this.isReadThrough = isReadThrough; } + /** {@inheritDoc} */ @Override public V get(K key) { @@ -77,23 +77,23 @@ public class MapCache implements Cache { Map result = null; synchronized (map) { result = new HashMap(keys.size()); - for (K key : keys) { + Iterator it = keys.iterator(); + while (it.hasNext()) { + K key = it.next(); V value = map.get(key); if (value != null) { result.put(key, value); - keys.remove(key); + it.remove(); } } - if (isReadThrough && cacheLoader != null) { - for (K key : keys) { - Map loadedValues = cacheLoader.loadAll(keys); - for (Map.Entry entry : loadedValues.entrySet()) { - V v = entry.getValue(); - if (v != null) { - K k = entry.getKey(); - map.put(k, v); - result.put(k, v); - } + if (keys.size() != 0 && isReadThrough && cacheLoader != null) { + Map loadedValues = cacheLoader.loadAll(keys); + for (Map.Entry entry : loadedValues.entrySet()) { + V v = entry.getValue(); + if (v != null) { + K k = entry.getKey(); + map.put(k, v); + result.put(k, v); } } } @@ -152,11 +152,10 @@ public class MapCache implements Cache { V result = null; synchronized (map) { result = map.get(key); - if (value == null && isReadThrough && cacheLoader != null) { + if (result == null && isReadThrough && cacheLoader != null) { V loadedValue = cacheLoader.load(key); if (loadedValue != null) { - map.put(key, value); - value = loadedValue; + result = loadedValue; } } map.put(key, value); @@ -266,11 +265,13 @@ public class MapCache implements Cache { @Override public void removeAll(Set keys) { synchronized (map) { - for (K key : keys) { + Iterator it = keys.iterator(); + while (it.hasNext()) { + K key = it.next(); if (map.containsKey(key)) { map.remove(key); } else { - keys.remove(key); + it.remove(); } } } @@ -306,6 +307,7 @@ public class MapCache implements Cache { @Override public T invoke(K key, EntryProcessor entryProcessor, Object... arguments) throws EntryProcessorException { + // TODO return null; } @@ -386,6 +388,7 @@ public class MapCache implements Cache { /** {@inheritDoc} */ @Override + @SuppressWarnings("unchecked") public T unwrap(Class clazz) { return (T) map; } diff --git a/src/test/java/net/helenus/test/integration/core/unitofwork/UnitOfWorkTest.java b/src/test/java/net/helenus/test/integration/core/unitofwork/UnitOfWorkTest.java index b08a3ae..e34e49f 100644 --- a/src/test/java/net/helenus/test/integration/core/unitofwork/UnitOfWorkTest.java +++ b/src/test/java/net/helenus/test/integration/core/unitofwork/UnitOfWorkTest.java @@ -17,12 +17,12 @@ package net.helenus.test.integration.core.unitofwork; import static net.helenus.core.Query.eq; -import ca.exprofesso.guava.jcache.GuavaCachingProvider; -import com.datastax.driver.core.ConsistencyLevel; -import com.datastax.driver.core.utils.UUIDs; + import java.io.Serializable; import java.util.Date; +import java.util.Map; import java.util.UUID; +import java.util.Set; import javax.cache.CacheManager; import javax.cache.Caching; import javax.cache.configuration.MutableConfiguration; @@ -42,6 +42,11 @@ import net.helenus.test.integration.build.AbstractEmbeddedCassandraTest; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; +import ca.exprofesso.guava.jcache.GuavaCachingProvider; +import com.datastax.driver.core.ConsistencyLevel; +import com.datastax.driver.core.utils.UUIDs; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; @Table @Cacheable @@ -550,4 +555,163 @@ public class UnitOfWorkTest extends AbstractEmbeddedCassandraTest { Assert.assertTrue(w1.equals(w4)); Assert.assertTrue(w4.equals(w1)); } + + @Test + public void getAllLoadAllTest() throws Exception { + String tableName = MappingUtil.getTableName(Widget.class, false).toString(); + UUID uuid1 = UUIDs.random(); + UUID uuid2 = UUIDs.random(); + UUID uuid3 = UUIDs.random(); + String k1 = tableName + "." + uuid1.toString(); + String k2 = tableName + "." + uuid2.toString(); + String k3 = tableName + "." + uuid3.toString(); + Set allKeys = ImmutableSet.of(k1, k2, k3); + + try (UnitOfWork uow1 = session.begin()) { + Widget w1 = session.insert(widget).value(widget::id, uuid1).sync(uow1); + Widget w2 = session.insert(widget).value(widget::id, uuid2).sync(uow1); + uow1.getCache().put(k1, w1); + uow1.getCache().put(k2, w2); + + Map results = uow1.getCache().getAll(allKeys); + Assert.assertEquals(2, results.entrySet().size()); + Assert.assertEquals(results, ImmutableMap.of(k1, w1, k2, w2)); + + // getAll tests + try (UnitOfWork uow2 = session.begin(uow1)) { + results = uow2.getCache().getAll(allKeys); + Assert.assertEquals(2, results.entrySet().size()); + Assert.assertEquals(results, ImmutableMap.of(k1, w1, k2, w2)); + + Widget w3 = session.insert(widget).value(widget::id, uuid3).sync(uow2); + uow2.getCache().put(k3, w3); + results = uow2.getCache().getAll(allKeys); + Assert.assertEquals(3, results.entrySet().size()); + Assert.assertEquals(results, ImmutableMap.of(k1, w1, k2, w2, k3, w3)); + + boolean removed = uow2.getCache().remove(k2); + Assert.assertTrue(removed); + removed = uow2.getCache().remove(k2); + Assert.assertFalse(removed); + results = uow2.getCache().getAll(allKeys); + Assert.assertEquals(2, results.size()); + Assert.assertEquals(results, ImmutableMap.of(k1, w1, k3, w3)); + + // Propagate changes to parent UOW for below tests. + uow2.commit(); + } + + // loadAll tests + try (UnitOfWork uow3 = session.begin(uow1)) { + uow3.getCache().loadAll(allKeys, false, null); + Assert.assertTrue(uow3.getCache().containsKey(k1)); + Assert.assertTrue(uow3.getCache().containsKey(k3)); + Assert.assertFalse(uow3.getCache().containsKey(k2)); + Assert.assertEquals(w1, uow3.getCache().get(k1)); + } + + try (UnitOfWork uow4 = session.begin(uow1)) { + UUID uuid3Updated = UUIDs.random(); + Widget w3Updated = session.insert(widget).value(widget::id, uuid3Updated).sync(uow4); + + // Insert a value for a known key, and load the cache without replacing existing values + uow4.getCache().put(k3, w3Updated); + Assert.assertEquals(w3Updated, uow4.getCache().get(k3)); + uow4.getCache().loadAll(allKeys, false, null); + Assert.assertEquals(w3Updated, uow4.getCache().get(k3)); + + // Insert a value for a known key, and load the cache by replacing existing values + UnitOfWork uow5 = session.begin(uow1); + uow5.getCache().put(k3, w3Updated); + Assert.assertEquals(w3Updated, uow5.getCache().get(k3)); + uow5.getCache().loadAll(allKeys, true, null); + Assert.assertNotNull(uow5.getCache().get(k3)); + Assert.assertNotEquals(w3Updated, uow5.getCache().get(k3)); + } + } + } + + @Test + public void getAndPutTest() throws Exception { + String tableName = MappingUtil.getTableName(Widget.class, false).toString(); + UUID uuid1 = UUIDs.random(); + UUID uuid2 = UUIDs.random(); + String k1 = tableName + "." + uuid1.toString(); + + try (UnitOfWork uow1 = session.begin()) { + Widget w1 = session.insert(widget).value(widget::id, uuid1).sync(uow1); + uow1.getCache().put(k1, w1); + try (UnitOfWork uow2 = session.begin(uow1)) { + Widget w2 = session.insert(widget).value(widget::id, uuid2).sync(uow2); + Widget value = (Widget) uow2.getCache().getAndPut(k1, w2); + Assert.assertEquals(w1, value); + value = (Widget) uow2.getCache().get(k1); + Assert.assertEquals(w2, value); + } + } + } + + @Test + public void removeAllTest() throws Exception { + String tableName = MappingUtil.getTableName(Widget.class, false).toString(); + UUID uuid1 = UUIDs.random(); + UUID uuid2 = UUIDs.random(); + String k1 = tableName + "." + uuid1.toString(); + String k2 = tableName + "." + uuid2.toString(); + Set keys = ImmutableSet.of(k1, k2, "noValue"); + + try (UnitOfWork uow = session.begin()) { + Widget w1 = session.insert(widget).value(widget::id, uuid1).sync(uow); + Widget w2 = session.insert(widget).value(widget::id, uuid2).sync(uow); + uow.getCache().put(k1, w1); + uow.getCache().put(k2, w2); + uow.getCache().removeAll(keys); + } + } + + @Test + public void testDeleteInNestedUOW() throws Exception { + String tableName = MappingUtil.getTableName(Widget.class, false).toString(); + UUID uuid1 = UUIDs.random(); + UUID uuid2 = UUIDs.random(); + String k1 = tableName + "." + uuid1.toString(); + String k2 = tableName + "." + uuid2.toString(); + + try (UnitOfWork uow1 = session.begin()) { + Widget w1 = session.insert(widget).value(widget::id, uuid1) + .value(widget::name, RandomString.make(10)) + .sync(uow1); + Widget w2 = session.insert(widget).value(widget::id, uuid2) + .value(widget::name, RandomString.make(20)) + .sync(uow1); + uow1.getCache().put(k1, w1); + uow1.getCache().put(k2, w2); + + try (UnitOfWork uow2 = session.begin(uow1)) { + Object o1 = uow2.getCache().get(k1); + Object o2 = uow2.getCache().get(k2); + Assert.assertEquals(w1, o1); + Assert.assertEquals(w2, o2); + + // k1 should not be available in uow2, but available in uow1. + uow2.getCache().remove(k1); + Assert.assertNull(uow2.getCache().get(k1)); + Assert.assertNotNull(uow1.getCache().get(k1)); + + // Post-commit, k1 shouldn't be availble in uow1 either + uow2.commit(); + Assert.assertNull(uow2.getCache().get(k1)); + Assert.assertNull(uow1.getCache().get(k1)); + + try (UnitOfWork uow3 = session.begin(uow2)) { + uow3.getCache().get(k1); + uow3.getCache().get(k2); + uow3.getCache().remove(k2); + } + } + + uow1.getCache().get(k1); + uow1.getCache().get(k2); + } + } }