diff --git a/query-algebrizer/src/clauses/mod.rs b/query-algebrizer/src/clauses/mod.rs index 8b58526c..bfa04102 100644 --- a/query-algebrizer/src/clauses/mod.rs +++ b/query-algebrizer/src/clauses/mod.rs @@ -648,6 +648,19 @@ impl ConjoiningClauses { } } + /// Eliminate any type extractions for variables whose types are definitely known. + pub fn prune_extracted_types(&mut self) { + if self.extracted_types.is_empty() || self.known_types.is_empty() { + return; + } + for (var, types) in self.known_types.iter() { + if types.len() == 1 { + self.extracted_types.remove(var); + } + } + } + + /// When a CC has accumulated all patterns, generate value_type_tag entries in `wheres` /// to refine value types for which two things are true: /// diff --git a/query-algebrizer/src/clauses/pattern.rs b/query-algebrizer/src/clauses/pattern.rs index a72140e1..6b4f8455 100644 --- a/query-algebrizer/src/clauses/pattern.rs +++ b/query-algebrizer/src/clauses/pattern.rs @@ -265,6 +265,8 @@ impl ConjoiningClauses { #[cfg(test)] mod testing { + extern crate mentat_query_parser; + use super::*; use std::collections::BTreeMap; @@ -281,6 +283,10 @@ mod testing { Variable, }; + use self::mentat_query_parser::{ + parse_find_string, + }; + use clauses::{ add_attribute, associate_ident, @@ -296,6 +302,13 @@ mod testing { SourceAlias, }; + use algebrize; + + fn alg(schema: &Schema, input: &str) -> ConjoiningClauses { + let parsed = parse_find_string(input).expect("parse failed"); + algebrize(schema.into(), parsed).expect("algebrize failed").cc + } + #[test] fn test_unknown_ident() { let mut cc = ConjoiningClauses::default(); @@ -815,4 +828,22 @@ mod testing { assert_eq!(cc.empty_because.unwrap(), EmptyBecause::TypeMismatch(x.clone(), unit_type_set(ValueType::Ref), ValueType::Boolean)); } + + #[test] + fn ensure_extracted_types_is_cleared() { + let query = r#"[:find ?e ?v :where [_ _ ?v] [?e :foo/bar ?v]]"#; + let mut schema = Schema::default(); + associate_ident(&mut schema, NamespacedKeyword::new("foo", "bar"), 99); + add_attribute(&mut schema, 99, Attribute { + value_type: ValueType::Boolean, + ..Default::default() + }); + let e = Variable::from_valid_name("?e"); + let v = Variable::from_valid_name("?v"); + let cc = alg(&schema, query); + assert_eq!(cc.known_types.get(&e), Some(&unit_type_set(ValueType::Ref))); + assert_eq!(cc.known_types.get(&v), Some(&unit_type_set(ValueType::Boolean))); + assert!(!cc.extracted_types.contains_key(&e)); + assert!(!cc.extracted_types.contains_key(&v)); + } } diff --git a/query-algebrizer/src/lib.rs b/query-algebrizer/src/lib.rs index 3cf044e9..597d81f1 100644 --- a/query-algebrizer/src/lib.rs +++ b/query-algebrizer/src/lib.rs @@ -78,6 +78,7 @@ pub fn algebrize(schema: &Schema, parsed: FindQuery) -> Result { cc.apply_clause(schema, where_clause)?; } cc.expand_column_bindings(); + cc.prune_extracted_types(); let limit = if parsed.find_spec.is_unit_limited() { Some(1) } else { None }; Ok(AlgebraicQuery {