diff --git a/query-pull/src/errors.rs b/query-pull/src/errors.rs index f0e747e6..94c38426 100644 --- a/query-pull/src/errors.rs +++ b/query-pull/src/errors.rs @@ -22,6 +22,11 @@ error_chain! { description("unnamed attribute") display("attribute {:?} has no name", id) } + + RepeatedDbId { + description(":db/id repeated") + display(":db/id repeated") + } } links { diff --git a/query-pull/src/lib.rs b/query-pull/src/lib.rs index b6f96030..c873bd4f 100644 --- a/query-pull/src/lib.rs +++ b/query-pull/src/lib.rs @@ -79,12 +79,14 @@ use std::iter::{ }; use mentat_core::{ + Binding, Cloned, Entid, HasSchema, NamespacedKeyword, Schema, StructuredMap, + TypedValue, ValueRc, }; @@ -143,6 +145,7 @@ pub struct Puller { // The range is the set of aliases to use in the output. attributes: BTreeMap>, attribute_spec: cache::AttributeSpec, + db_id_alias: Option>, } impl Puller { @@ -166,6 +169,9 @@ impl Puller { let mut names: BTreeMap> = Default::default(); let mut attrs: BTreeSet = Default::default(); + let db_id = ::std::rc::Rc::new(NamespacedKeyword::new("db", "id")); + let mut db_id_alias = None; + for attr in attributes.iter() { match attr { &PullAttributeSpec::Wildcard => { @@ -183,6 +189,14 @@ impl Puller { let alias = alias.as_ref() .map(|ref r| r.to_value_rc()); match attribute { + // Handle :db/id. + &PullConcreteAttribute::Ident(ref i) if i.as_ref() == db_id.as_ref() => { + // We only allow :db/id once. + if db_id_alias.is_some() { + bail!(ErrorKind::RepeatedDbId); + } + db_id_alias = Some(alias.unwrap_or_else(|| db_id.to_value_rc())); + }, &PullConcreteAttribute::Ident(ref i) => { if let Some(entid) = schema.get_entid(i) { let name = alias.unwrap_or_else(|| i.to_value_rc()); @@ -203,6 +217,7 @@ impl Puller { Ok(Puller { attributes: names, attribute_spec: cache::AttributeSpec::specified(&attrs, schema), + db_id_alias, }) } @@ -234,6 +249,17 @@ impl Puller { // TODO: should we walk `e` then `a`, or `a` then `e`? Possibly the right answer // is just to collect differently! let mut maps = BTreeMap::new(); + + // Collect :db/id if requested. + if let Some(ref alias) = self.db_id_alias { + for e in entities.iter() { + let mut r = maps.entry(*e) + .or_insert(ValueRc::new(StructuredMap::default())); + let mut m = ValueRc::get_mut(r).unwrap(); + m.insert(alias.clone(), Binding::Scalar(TypedValue::Ref(*e))); + } + } + for (name, cache) in self.attributes.iter().filter_map(|(a, name)| caches.forward_attribute_cache_for_attribute(schema, *a) .map(|cache| (name.clone(), cache))) { diff --git a/tests/pull.rs b/tests/pull.rs index b33952a3..4e2d70ba 100644 --- a/tests/pull.rs +++ b/tests/pull.rs @@ -116,7 +116,9 @@ fn test_simple_pull() { assert_eq!(pulled, expected); // Now test pull inside the query itself. - let query = r#"[:find ?hood (pull ?district [[:district/name :as :district/district] :district/region]) + let query = r#"[:find ?hood (pull ?district [:db/id + [:district/name :as :district/district] + :district/region]) :where (or [?hood :neighborhood/name "Beacon Hill"] [?hood :neighborhood/name "Capitol Hill"]) @@ -127,22 +129,24 @@ fn test_simple_pull() { .into_rel_result() .expect("results"); - let beacon_district: Vec<(NamespacedKeyword, TypedValue)> = vec![ + let beacon_district_pull: Vec<(NamespacedKeyword, TypedValue)> = vec![ + (kw!(:db/id), TypedValue::Ref(beacon_district)), (kw!(:district/district), "Greater Duwamish".into()), (kw!(:district/region), schema.get_entid(&NamespacedKeyword::new("region", "se")).unwrap().into()) ]; - let beacon_district: StructuredMap = beacon_district.into(); - let capitol_district: Vec<(NamespacedKeyword, TypedValue)> = vec![ + let beacon_district_pull: StructuredMap = beacon_district_pull.into(); + let capitol_district_pull: Vec<(NamespacedKeyword, TypedValue)> = vec![ + (kw!(:db/id), TypedValue::Ref(capitol_district)), (kw!(:district/district), "East".into()), (kw!(:district/region), schema.get_entid(&NamespacedKeyword::new("region", "e")).unwrap().into()) ]; - let capitol_district: StructuredMap = capitol_district.into(); + let capitol_district_pull: StructuredMap = capitol_district_pull.into(); let expected = RelResult { width: 2, values: vec![ - TypedValue::Ref(capitol).into(), capitol_district.into(), - TypedValue::Ref(beacon).into(), beacon_district.into(), + TypedValue::Ref(capitol).into(), capitol_district_pull.into(), + TypedValue::Ref(beacon).into(), beacon_district_pull.into(), ].into(), }; assert_eq!(results, expected.clone()); @@ -158,14 +162,19 @@ fn test_simple_pull() { // Execute a scalar query where the body is constant. // TODO: we shouldn't require `:where`; that makes this non-constant! - let query = r#"[:find (pull ?hood [:neighborhood/name]) . :in ?hood + let query = r#"[:find (pull ?hood [[:db/id :as :neighborhood/id] + :neighborhood/name]) . + :in ?hood :where [?hood :neighborhood/district _]]"#; let result = reader.q_once(query, QueryInputs::with_value_sequence(vec![(var!(?hood), TypedValue::Ref(beacon))])) .into_scalar_result() .expect("success") .expect("result"); - let expected: StructuredMap = vec![(kw!(:neighborhood/name), TypedValue::from("Beacon Hill"))].into(); + let expected: StructuredMap = vec![ + (kw!(:neighborhood/name), TypedValue::from("Beacon Hill")), + (kw!(:neighborhood/id), TypedValue::Ref(beacon)), + ].into(); assert_eq!(result, expected.into()); // Collect the names and regions of all districts.