diff --git a/db/src/cache.rs b/db/src/cache.rs index bcb3c12f..2e0ba742 100644 --- a/db/src/cache.rs +++ b/db/src/cache.rs @@ -1369,6 +1369,14 @@ impl<'a> InProgressCacheTransactWatcher<'a> { } impl<'a> TransactWatcher for InProgressCacheTransactWatcher<'a> { + type Result = (); + + + + fn tx_id(&mut self) -> Option { + None + } + fn datom(&mut self, op: OpType, e: Entid, a: Entid, v: &TypedValue) { if !self.active { return; @@ -1402,7 +1410,7 @@ impl<'a> TransactWatcher for InProgressCacheTransactWatcher<'a> { } } - fn done(&mut self, schema: &Schema) -> Result<()> { + fn done(&mut self, _t: &Entid, schema: &Schema) -> Result { // Oh, I wish we had impl trait. Without it we have a six-line type signature if we // try to break this out as a helper function. let collected_retractions = mem::replace(&mut self.collected_retractions, Default::default()); diff --git a/db/src/lib.rs b/db/src/lib.rs index debed886..35afa039 100644 --- a/db/src/lib.rs +++ b/db/src/lib.rs @@ -89,6 +89,7 @@ pub use tx::{ }; pub use tx_observer::{ + InProgressObserverTransactWatcher, TxObservationService, TxObserver, }; diff --git a/db/src/tx.rs b/db/src/tx.rs index 20d35bfe..11a22e81 100644 --- a/db/src/tx.rs +++ b/db/src/tx.rs @@ -714,7 +714,7 @@ impl<'conn, 'a, W> Tx<'conn, 'a, W> where W: TransactWatcher { } db::update_partition_map(self.store, &self.partition_map)?; - self.watcher.done(self.schema)?; + self.watcher.done(&self.tx_id, self.schema)?; if tx_might_update_metadata { // Extract changes to metadata from the store. @@ -739,7 +739,6 @@ impl<'conn, 'a, W> Tx<'conn, 'a, W> where W: TransactWatcher { tx_id: self.tx_id, tx_instant, tempids: tempids, - changeset: affected_attrs, }) } } @@ -752,7 +751,6 @@ fn start_tx<'conn, 'a, W>(conn: &'conn rusqlite::Connection, watcher: W) -> Result> where W: TransactWatcher { let tx_id = partition_map.allocate_entid(":db.part/tx"); - conn.begin_tx_application()?; Ok(Tx::new(conn, partition_map, schema_for_mutation, schema, watcher, tx_id)) diff --git a/db/src/tx_observer.rs b/db/src/tx_observer.rs index 456bd51a..d5444b34 100644 --- a/db/src/tx_observer.rs +++ b/db/src/tx_observer.rs @@ -8,8 +8,12 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. +use std::collections::{ + BTreeMap, +}; use std::sync::{ Arc, + Mutex, Weak, }; use std::sync::mpsc::{ @@ -24,37 +28,49 @@ use indexmap::{ IndexMap, }; -use smallvec::{ - SmallVec, +use mentat_core::{ + Entid, + Schema, + TypedValue, +}; +use mentat_tx::entities::{ + OpType, }; -use types::{ - AttributeSet, - TxReport, +use errors::{ + Result, }; +use types::{ + AccumulatedTxids, + AttributeSet, +}; +use watcher::TransactWatcher; pub struct TxObserver { - notify_fn: Arc) + Send + Sync>>, + notify_fn: Arc) + Send + Sync>>, attributes: AttributeSet, } impl TxObserver { - pub fn new(attributes: AttributeSet, notify_fn: F) -> TxObserver where F: Fn(&str, SmallVec<[&TxReport; 4]>) + 'static + Send + Sync { + pub fn new(attributes: AttributeSet, notify_fn: F) -> TxObserver where F: Fn(&str, BTreeMap<&Entid, &AttributeSet>) + 'static + Send + Sync { TxObserver { notify_fn: Arc::new(Box::new(notify_fn)), attributes, } } - pub fn applicable_reports<'r>(&self, reports: &'r SmallVec<[TxReport; 4]>) -> SmallVec<[&'r TxReport; 4]> { - reports.into_iter().filter_map(|report| { - self.attributes.intersection(&report.changeset) + pub fn applicable_reports<'r>(&self, reports: &'r BTreeMap) -> BTreeMap<&'r Entid, &'r AttributeSet> { + reports.into_iter().filter_map(|(txid, changeset)| { + self.attributes.intersection(changeset) .next() - .and_then(|_| Some(report)) - }).collect() + .and_then(|_| Some((txid, changeset))) + }).fold(BTreeMap::new(), |mut map, (txid, changeset)| { + map.insert(txid, changeset); + map + }) } - fn notify(&self, key: &str, reports: SmallVec<[&TxReport; 4]>) { + fn notify(&self, key: &str, reports: BTreeMap<&Entid, &AttributeSet>) { (*self.notify_fn)(key, reports); } } @@ -64,12 +80,12 @@ pub trait Command { } pub struct TxCommand { - reports: SmallVec<[TxReport; 4]>, + reports: BTreeMap, observers: Weak>>, } impl TxCommand { - fn new(observers: &Arc>>, reports: SmallVec<[TxReport; 4]>) -> Self { + fn new(observers: &Arc>>, reports: BTreeMap) -> Self { TxCommand { reports, observers: Arc::downgrade(observers), @@ -92,16 +108,16 @@ impl Command for TxCommand { pub struct TxObservationService { observers: Arc>>, + transactions: BTreeMap, executor: Option>>, - in_progress_count: i32, } impl TxObservationService { pub fn new() -> Self { TxObservationService { observers: Arc::new(IndexMap::new()), + transactions: Default::default(), executor: None, - in_progress_count: 0, } } @@ -122,49 +138,96 @@ impl TxObservationService { !self.observers.is_empty() } - pub fn transaction_did_start(&mut self) { - self.in_progress_count += 1; + pub fn add_transaction(&mut self, tx_id: Entid, attributes: AttributeSet) { + self.transactions.insert(tx_id, attributes); } - pub fn transaction_did_commit(&mut self, reports: SmallVec<[TxReport; 4]>) { - { - let executor = self.executor.get_or_insert_with(||{ - let (tx, rx): (Sender>, Receiver>) = channel(); - let mut worker = CommandExecutor::new(rx); + pub fn transaction_did_commit(&mut self, txids: &AccumulatedTxids) { + // collect the changesets relating to this commit + let reports: BTreeMap = txids.into_iter().filter_map(|tx_id| { + self.transactions.remove(&tx_id).map_or(None, |changeset| Some((tx_id, changeset))) + }) + .fold(BTreeMap::new(), |mut map, (tx_id, changeset)| { + map.insert(*tx_id, changeset); + map + }); - thread::spawn(move || { - worker.main(); - }); + let executor = self.executor.get_or_insert_with(||{ + let (tx, rx): (Sender>, Receiver>) = channel(); + let mut worker = CommandExecutor::new(rx); - tx + thread::spawn(move || { + worker.main(); }); - let cmd = Box::new(TxCommand::new(&self.observers, reports)); - executor.send(cmd).unwrap(); - } + tx + }); - self.in_progress_count -= 1; + let cmd = Box::new(TxCommand::new(&self.observers, reports)); + executor.send(cmd).unwrap(); + } +} - if self.in_progress_count == 0 { - self.executor = None; +impl Drop for TxObservationService { + fn drop(&mut self) { + self.executor = None; + } +} + +pub struct InProgressObserverTransactWatcher<'a> { + collected_datoms: AttributeSet, + observer_service: &'a Mutex, + active: bool +} + +impl<'a> InProgressObserverTransactWatcher<'a> { + pub fn new(observer_service: &'a Mutex) -> InProgressObserverTransactWatcher { + let mut w = InProgressObserverTransactWatcher { + collected_datoms: Default::default(), + observer_service, + active: true + }; + + w.active = observer_service.lock().unwrap().has_observers(); + w + } +} + +impl<'a> TransactWatcher for InProgressObserverTransactWatcher<'a> { + type Result = (); + + fn tx_id(&mut self) -> Option { + None + } + + fn datom(&mut self, _op: OpType, _e: Entid, a: Entid, _v: &TypedValue) { + if !self.active { + return } + self.collected_datoms.insert(a); + } + + fn done(&mut self, t: &Entid, _schema: &Schema) -> Result { + let collected_datoms = ::std::mem::replace(&mut self.collected_datoms, Default::default()); + self.observer_service.lock().unwrap().add_transaction(t.clone(), collected_datoms); + Ok(()) } } struct CommandExecutor { - reciever: Receiver>, + receiver: Receiver>, } impl CommandExecutor { fn new(rx: Receiver>) -> Self { CommandExecutor { - reciever: rx, + receiver: rx, } } fn main(&mut self) { loop { - match self.reciever.recv() { + match self.receiver.recv() { Err(RecvError) => { eprintln!("Disconnected, terminating CommandExecutor"); return diff --git a/db/src/types.rs b/db/src/types.rs index 5e19da73..f91642ff 100644 --- a/db/src/types.rs +++ b/db/src/types.rs @@ -16,6 +16,8 @@ use std::collections::{ BTreeSet, }; +use smallvec::SmallVec; + extern crate mentat_core; pub use self::mentat_core::{ @@ -88,6 +90,8 @@ pub type AVMap<'a> = HashMap<&'a AVPair, Entid>; // represents a set of entids that are correspond to attributes pub type AttributeSet = BTreeSet; +pub type AccumulatedTxids = SmallVec<[Entid; 4]>; + /// A transaction report summarizes an applied transaction. #[derive(Clone, Debug, Eq, Hash, Ord, PartialOrd, PartialEq)] pub struct TxReport { @@ -103,7 +107,4 @@ pub struct TxReport { /// existing entid, or is allocated a new entid. (It is possible for multiple distinct string /// literal tempids to all unify to a single freshly allocated entid.) pub tempids: BTreeMap, - - // A set of entids for attributes that were affected inside this transaction - pub changeset: AttributeSet, } diff --git a/db/src/watcher.rs b/db/src/watcher.rs index fa22cb6c..d8ce9302 100644 --- a/db/src/watcher.rs +++ b/db/src/watcher.rs @@ -32,22 +32,33 @@ use errors::{ }; pub trait TransactWatcher { + type Result; + + fn tx_id(&mut self) -> Option; + fn datom(&mut self, op: OpType, e: Entid, a: Entid, v: &TypedValue); /// Only return an error if you want to interrupt the transact! /// Called with the schema _prior to_ the transact -- any attributes or /// attribute changes transacted during this transact are not reflected in /// the schema. - fn done(&mut self, schema: &Schema) -> Result<()>; + fn done(&mut self, t: &Entid, schema: &Schema) -> Result; } pub struct NullWatcher(); impl TransactWatcher for NullWatcher { + type Result = (); + + + fn tx_id(&mut self) -> Option { + None + } + fn datom(&mut self, _op: OpType, _e: Entid, _a: Entid, _v: &TypedValue) { } - fn done(&mut self, _schema: &Schema) -> Result<()> { + fn done(&mut self, _t: &Entid, _schema: &Schema) -> Result { Ok(()) } } diff --git a/src/conn.rs b/src/conn.rs index bf407738..c8f343a8 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -32,8 +32,6 @@ use rusqlite::{ TransactionBehavior, }; -use smallvec::SmallVec; - use edn; use mentat_core::{ @@ -50,6 +48,7 @@ use mentat_core::{ use mentat_core::intern_set::InternSet; use mentat_db::cache::{ + InProgressCacheTransactWatcher, InProgressSQLiteAttributeCache, SQLiteAttributeCache, }; @@ -58,17 +57,26 @@ use mentat_db::db; use mentat_db::{ transact, transact_terms, + InProgressObserverTransactWatcher, PartitionMap, + TransactWatcher, TxObservationService, TxObserver, TxReport, }; +use mentat_db::types::{ + AccumulatedTxids, +}; + use mentat_db::internal_types::TermWithTempIds; use mentat_tx; -use mentat_tx::entities::TempId; +use mentat_tx::entities::{ + TempId, + OpType, +}; use mentat_tx_parser; @@ -209,9 +217,8 @@ pub struct InProgress<'a, 'c> { schema: Schema, cache: InProgressSQLiteAttributeCache, use_caching: bool, - // TODO: Collect txids/affected datoms in a better way - tx_reports: SmallVec<[TxReport; 4]>, - observer_service: Option<&'a Mutex>, + tx_ids: AccumulatedTxids, + tx_observer: &'a Mutex, } /// Represents an in-progress set of reads to the store. Just like `InProgress`, @@ -372,15 +379,20 @@ impl<'a, 'c> InProgress<'a, 'c> { } pub fn transact_terms(&mut self, terms: I, tempid_set: InternSet) -> Result where I: IntoIterator { - let (report, next_partition_map, next_schema, _watcher) = + let w = InProgressTransactWatcher::new( + InProgressObserverTransactWatcher::new(self.tx_observer), + self.cache.transact_watcher()); + let (report, next_partition_map, next_schema, mut watcher) = transact_terms(&self.transaction, self.partition_map.clone(), &self.schema, &self.schema, - self.cache.transact_watcher(), + w, terms, tempid_set)?; - self.tx_reports.push(report.clone()); + if let Some(tx_id) = watcher.tx_id() { + self.tx_ids.push(tx_id); + } self.partition_map = next_partition_map; if let Some(schema) = next_schema { self.schema = schema; @@ -397,14 +409,20 @@ impl<'a, 'c> InProgress<'a, 'c> { // `Metadata` on return. If we used `Cell` or other mechanisms, we'd be using // `Default::default` in those situations to extract the partition map, and so there // would still be some cost. - let (report, next_partition_map, next_schema, _watcher) = + let w = InProgressTransactWatcher::new( + InProgressObserverTransactWatcher::new(self.tx_observer), + self.cache.transact_watcher()); + let (report, next_partition_map, next_schema, mut watcher) = transact(&self.transaction, self.partition_map.clone(), &self.schema, - &self.schema, - self.cache.transact_watcher(), + & + self.schema, + w, entities)?; - self.tx_reports.push(report.clone()); + if let Some(tx_id) = watcher.tx_id() { + self.tx_ids.push(tx_id); + } self.partition_map = next_partition_map; if let Some(schema) = next_schema { @@ -460,11 +478,7 @@ impl<'a, 'c> InProgress<'a, 'c> { // TODO: consider making vocabulary lookup lazy -- we won't need it much of the time. } - // let the transaction observer know that there have been some transactions committed. - if let Some(ref observer_service) = self.observer_service { - let mut os = observer_service.lock().unwrap(); - os.transaction_did_commit(self.tx_reports); - } + self.tx_observer.lock().unwrap().transaction_did_commit(&self.tx_ids); Ok(()) } @@ -493,6 +507,42 @@ impl<'a, 'c> InProgress<'a, 'c> { } } +struct InProgressTransactWatcher<'a> { + cache_watcher: InProgressCacheTransactWatcher<'a>, + observer_watcher: InProgressObserverTransactWatcher<'a>, + tx_id: Option, +} + +impl<'a> InProgressTransactWatcher<'a> { + fn new(observer_watcher: InProgressObserverTransactWatcher<'a>, cache_watcher: InProgressCacheTransactWatcher<'a>) -> Self { + InProgressTransactWatcher { + cache_watcher: cache_watcher, + observer_watcher: observer_watcher, + tx_id: None, + } + } +} + +impl<'a> TransactWatcher for InProgressTransactWatcher<'a> { + type Result = (); + + fn tx_id(&mut self) -> Option { + self.tx_id.take() + } + + fn datom(&mut self, op: OpType, e: Entid, a: Entid, v: &TypedValue) { + self.cache_watcher.datom(op.clone(), e.clone(), a.clone(), v); + self.observer_watcher.datom(op.clone(), e.clone(), a.clone(), v); + } + + fn done(&mut self, t: &Entid, schema: &Schema) -> ::mentat_db::errors::Result { + self.cache_watcher.done(t, schema)?; + self.observer_watcher.done(t, schema)?; + self.tx_id = Some(t.clone()); + Ok(()) + } +} + impl Store { /// Intended for use from tests. pub fn sqlite_mut(&mut self) -> &mut rusqlite::Connection { @@ -718,14 +768,6 @@ impl Conn { current.attribute_cache.clone()) }; - let mut obs = self.tx_observer_service.lock().unwrap(); - let observer_service = if obs.has_observers() { - obs.transaction_did_start(); - Some(&self.tx_observer_service) - } else { - None - }; - Ok(InProgress { mutex: &self.metadata, transaction: tx, @@ -734,8 +776,8 @@ impl Conn { schema: (*current_schema).clone(), cache: InProgressSQLiteAttributeCache::from_cache(cache_cow), use_caching: true, - tx_reports: SmallVec::new(), - observer_service: observer_service, + tx_ids: Default::default(), + tx_observer: &self.tx_observer_service, }) } @@ -840,9 +882,10 @@ mod tests { use std::path::{ PathBuf, }; + use std::sync::mpsc; use std::time::{ Duration, - Instant + Instant, }; use mentat_core::{ @@ -872,7 +915,7 @@ mod tests { }; use ::vocabulary::attribute::{ - Unique + Unique, }; use mentat_db::USER0; @@ -1507,7 +1550,7 @@ mod tests { in_progress.commit().expect("Expected vocabulary committed"); } - #[derive(Default)] + #[derive(Default, Debug)] struct ObserverOutput { txids: Vec, changes: Vec>, @@ -1531,7 +1574,7 @@ mod tests { let output = Arc::new(Mutex::new(ObserverOutput::default())); let mut_output = Arc::downgrade(&output); - let (tx, rx): (::std::sync::mpsc::Sender<()>, ::std::sync::mpsc::Receiver<()>) = ::std::sync::mpsc::channel(); + let (tx, rx): (mpsc::Sender<()>, mpsc::Receiver<()>) = mpsc::channel(); // because the TxObserver is in an Arc and is therefore Sync, we have to wrap the Sender in a Mutex to also // make it Sync. let thread_tx = Mutex::new(tx); @@ -1539,9 +1582,9 @@ mod tests { if let Some(out) = mut_output.upgrade() { let mut o = out.lock().unwrap(); o.called_key = Some(obs_key.to_string()); - for report in batch.iter() { - o.txids.push(report.tx_id.clone()); - o.changes.push(report.changeset.clone()); + for (tx_id, changes) in batch.into_iter() { + o.txids.push(*tx_id); + o.changes.push(changes.clone()); } o.txids.sort(); } @@ -1553,21 +1596,26 @@ mod tests { let mut tx_ids = Vec::new(); let mut changesets = Vec::new(); + let uuid_entid: Entid = conn.current_schema().get_entid(&kw!(:todo/uuid)).expect("entid to exist for name").into(); { let mut in_progress = conn.begin_transaction(&mut sqlite).expect("expected transaction"); for i in 0..3 { + let mut changeset = BTreeSet::new(); let name = format!("todo{}", i); let uuid = Uuid::new_v4(); let mut builder = in_progress.builder().describe_tempid(&name); builder.add_kw(&kw!(:todo/uuid), TypedValue::Uuid(uuid)).expect("Expected added uuid"); + changeset.insert(uuid_entid.clone()); builder.add_kw(&kw!(:todo/name), TypedValue::typed_string(&name)).expect("Expected added name"); + changeset.insert(name_entid.clone()); if i % 2 == 0 { builder.add_kw(&kw!(:todo/completion_date), TypedValue::current_instant()).expect("Expected added date"); + changeset.insert(date_entid.clone()); } let (ip, r) = builder.transact(); let report = r.expect("expected a report"); tx_ids.push(report.tx_id.clone()); - changesets.push(report.changeset.clone()); + changesets.push(changeset); in_progress = ip; } let mut builder = in_progress.builder().describe_tempid("Label"); @@ -1579,18 +1627,11 @@ mod tests { let delay = Duration::from_millis(100); let _ = rx.recv_timeout(delay); - match Arc::try_unwrap(output) { - Ok(out) => { - let o = out.into_inner().expect("Expected an Output"); - assert_eq!(o.called_key, Some(key.clone())); - assert_eq!(o.txids, tx_ids); - assert_eq!(o.changes, changesets); - }, - _ => { - println!("Unable to unwrap output"); - assert!(false); - } - } + let out = Arc::try_unwrap(output).expect("unwrapped"); + let o = out.into_inner().expect("Expected an Output"); + assert_eq!(o.called_key, Some(key.clone())); + assert_eq!(o.txids, tx_ids); + assert_eq!(o.changes, changesets); } #[test] @@ -1610,15 +1651,15 @@ mod tests { let output = Arc::new(Mutex::new(ObserverOutput::default())); let mut_output = Arc::downgrade(&output); - let (tx, rx): (::std::sync::mpsc::Sender<()>, ::std::sync::mpsc::Receiver<()>) = ::std::sync::mpsc::channel(); + let (tx, rx): (mpsc::Sender<()>, mpsc::Receiver<()>) = mpsc::channel(); let thread_tx = Mutex::new(tx); let tx_observer = Arc::new(TxObserver::new(registered_attrs, move |obs_key, batch| { if let Some(out) = mut_output.upgrade() { let mut o = out.lock().unwrap(); o.called_key = Some(obs_key.to_string()); - for report in batch.iter() { - o.txids.push(report.tx_id.clone()); - o.changes.push(report.changeset.clone()); + for (tx_id, changes) in batch.into_iter() { + o.txids.push(*tx_id); + o.changes.push(changes.clone()); } o.txids.sort(); } @@ -1645,17 +1686,10 @@ mod tests { let delay = Duration::from_millis(100); let _ = rx.recv_timeout(delay); - match Arc::try_unwrap(output) { - Ok(out) => { - let o = out.into_inner().expect("Expected an Output"); - assert_eq!(o.called_key, None); - assert_eq!(o.txids, tx_ids); - assert_eq!(o.changes, changesets); - }, - _ => { - println!("Unable to unwrap output"); - assert!(false); - } - } + let out = Arc::try_unwrap(output).expect("unwrapped"); + let o = out.into_inner().expect("Expected an Output"); + assert_eq!(o.called_key, None); + assert_eq!(o.txids, tx_ids); + assert_eq!(o.changes, changesets); } }