diff --git a/query-algebrizer/src/clauses/mod.rs b/query-algebrizer/src/clauses/mod.rs index 6d1856af..3b1a5c6e 100644 --- a/query-algebrizer/src/clauses/mod.rs +++ b/query-algebrizer/src/clauses/mod.rs @@ -83,6 +83,8 @@ use validate::{ validate_or_join, }; +use self::predicate::parse_type_predicate; + pub use self::inputs::QueryInputs; // We do this a lot for errors. @@ -366,7 +368,7 @@ impl ConjoiningClauses { // Are we also trying to figure out the type of the value when the query runs? // If so, constrain that! if let Some(qa) = self.extracted_types.get(&var) { - self.wheres.add_intersection(ColumnConstraint::has_type(qa.0.clone(), vt)); + self.wheres.add_intersection(ColumnConstraint::has_unit_type(qa.0.clone(), vt)); } // Finally, store the binding for future use. @@ -406,14 +408,6 @@ impl ConjoiningClauses { self.known_types.get(var).cloned().unwrap_or(ValueTypeSet::any()) } - fn required_type_set(&self, var: &Variable) -> ValueTypeSet { - self.required_types.get(var).cloned().unwrap_or(ValueTypeSet::any()) - } - - fn possible_type_set(&self, var: &Variable) -> ValueTypeSet { - self.known_type_set(var).intersection(&self.required_type_set(var)) - } - pub fn bind_column_to_var>(&mut self, schema: &Schema, table: TableAlias, column: C, var: Variable) { let column = column.into(); // Do we have an external binding for this? @@ -733,12 +727,13 @@ impl ConjoiningClauses { // TODO: see if the variable is projected, aggregated, or compared elsewhere in // the query. If it's not, we don't need to use all_datoms here. &PatternValuePlace::Variable(ref v) => { - // Do we know that this variable can't be a string? If so, we don't need - // AllDatoms. - if !self.possible_type_set(v).contains(ValueType::String) { - DatomsTable::Datoms - } else { + // If `required_types` and `known_types` don't exclude strings, + // we need to query `all_datoms`. + if self.required_types.get(v).map_or(true, |s| s.contains(ValueType::String)) && + self.known_types.get(v).map_or(true, |s| s.contains(ValueType::String)) { DatomsTable::AllDatoms + } else { + DatomsTable::Datoms } } &PatternValuePlace::Constant(NonIntegerConstant::Text(_)) => @@ -905,8 +900,8 @@ impl ConjoiningClauses { // the variable could take, then we know we're empty. empty_because = Some(EmptyBecause::TypeMismatch { var: var.clone(), - existing: already_known.clone(), - desired: types.clone(), + existing: *already_known, + desired: *types, }); break; } @@ -973,6 +968,26 @@ impl ConjoiningClauses { } impl ConjoiningClauses { + pub fn apply_clauses(&mut self, schema: &Schema, where_clauses: Vec) -> Result<()> { + let mut deferred = Vec::with_capacity(where_clauses.len()); + // We apply (top level) type predicates first as an optimization. + for c in where_clauses { + match &c { + &WhereClause::Pred(ref p) => { + if let Some(ty) = parse_type_predicate(p.operator.0.as_str()) { + self.apply_type_requirement(p, ty)?; + } + }, + _ => {} + }; + deferred.push(c); + } + // Then we apply everything else. + for c in deferred { + self.apply_clause(schema, c)?; + } + Ok(()) + } // This is here, rather than in `lib.rs`, because it's recursive: `or` can contain `or`, // and so on. pub fn apply_clause(&mut self, schema: &Schema, where_clause: WhereClause) -> Result<()> { diff --git a/query-algebrizer/src/clauses/not.rs b/query-algebrizer/src/clauses/not.rs index cd7ff46f..c4f31d6b 100644 --- a/query-algebrizer/src/clauses/not.rs +++ b/query-algebrizer/src/clauses/not.rs @@ -49,9 +49,7 @@ impl ConjoiningClauses { } } - for clause in not_join.clauses.into_iter() { - template.apply_clause(&schema, clause)?; - } + template.apply_clauses(&schema, not_join.clauses)?; if template.is_known_empty() { return Ok(()); diff --git a/query-algebrizer/src/clauses/or.rs b/query-algebrizer/src/clauses/or.rs index 7853121a..d4f73056 100644 --- a/query-algebrizer/src/clauses/or.rs +++ b/query-algebrizer/src/clauses/or.rs @@ -96,9 +96,7 @@ impl ConjoiningClauses { // [:find ?x :where (or (and [?x _ 5] [?x :foo/bar 7]))] // which is equivalent to dropping the `or` _and_ the `and`! OrWhereClause::And(clauses) => { - for clause in clauses { - self.apply_clause(schema, clause)?; - } + self.apply_clauses(schema, clauses)?; Ok(()) }, } @@ -564,9 +562,7 @@ impl ConjoiningClauses { let mut receptacle = template.make_receptacle(); match clause { OrWhereClause::And(clauses) => { - for clause in clauses { - receptacle.apply_clause(&schema, clause)?; - } + receptacle.apply_clauses(&schema, clauses)?; }, OrWhereClause::Clause(clause) => { receptacle.apply_clause(&schema, clause)?; diff --git a/query-algebrizer/src/clauses/pattern.rs b/query-algebrizer/src/clauses/pattern.rs index 353f9765..5bf746a6 100644 --- a/query-algebrizer/src/clauses/pattern.rs +++ b/query-algebrizer/src/clauses/pattern.rs @@ -201,7 +201,7 @@ impl ConjoiningClauses { } else { // It must be a keyword. self.constrain_column_to_constant(col.clone(), DatomsColumn::Value, TypedValue::Keyword(kw.clone())); - self.wheres.add_intersection(ColumnConstraint::has_type(col.clone(), ValueType::Keyword)); + self.wheres.add_intersection(ColumnConstraint::has_unit_type(col.clone(), ValueType::Keyword)); }; }, PatternValuePlace::Constant(ref c) => { @@ -237,7 +237,8 @@ impl ConjoiningClauses { // Because everything we handle here is unambiguous, we generate a single type // restriction from the value type of the typed value. if value_type.is_none() { - self.wheres.add_intersection(ColumnConstraint::has_type(col.clone(), typed_value_type)); + self.wheres.add_intersection( + ColumnConstraint::has_unit_type(col.clone(), typed_value_type)); } }, } @@ -445,7 +446,7 @@ mod testing { // TODO: implement expand_type_tags. assert_eq!(cc.wheres, vec![ ColumnConstraint::Equals(d0_v, QueryValue::TypedValue(TypedValue::Boolean(true))), - ColumnConstraint::has_type("datoms00".to_string(), ValueType::Boolean), + ColumnConstraint::has_unit_type("datoms00".to_string(), ValueType::Boolean), ].into()); } @@ -589,7 +590,7 @@ mod testing { // TODO: implement expand_type_tags. assert_eq!(cc.wheres, vec![ ColumnConstraint::Equals(d0_v, QueryValue::TypedValue(TypedValue::String(Rc::new("hello".to_string())))), - ColumnConstraint::has_type("all_datoms00".to_string(), ValueType::String), + ColumnConstraint::has_unit_type("all_datoms00".to_string(), ValueType::String), ].into()); } diff --git a/query-algebrizer/src/clauses/predicate.rs b/query-algebrizer/src/clauses/predicate.rs index 427774e1..ae89ec8e 100644 --- a/query-algebrizer/src/clauses/predicate.rs +++ b/query-algebrizer/src/clauses/predicate.rs @@ -34,7 +34,7 @@ use types::{ Inequality, }; -fn value_type_function_name(s: &str) -> Option { +pub fn parse_type_predicate(s: &str) -> Option { match s { "ref" => Some(ValueType::Ref), "boolean" => Some(ValueType::Boolean), @@ -62,8 +62,8 @@ impl ConjoiningClauses { // and ultimately allowing user-specified predicates, we match on the predicate name first. if let Some(op) = Inequality::from_datalog_operator(predicate.operator.0.as_str()) { self.apply_inequality(schema, op, predicate) - } else if let Some(ty) = value_type_function_name(predicate.operator.0.as_str()) { - self.apply_type_requirement(predicate, ty) + } else if let Some(ty) = parse_type_predicate(predicate.operator.0.as_str()) { + self.apply_type_requirement(&predicate, ty) } else { bail!(ErrorKind::UnknownFunction(predicate.operator.clone())) } @@ -76,14 +76,13 @@ impl ConjoiningClauses { } } - pub fn apply_type_requirement(&mut self, pred: Predicate, ty: ValueType) -> Result<()> { + pub fn apply_type_requirement(&mut self, pred: &Predicate, ty: ValueType) -> Result<()> { if pred.args.len() != 1 { bail!(ErrorKind::InvalidNumberOfArguments(pred.operator.clone(), pred.args.len(), 1)); } - let mut args = pred.args.into_iter(); - if let FnArg::Variable(v) = args.next().unwrap() { - self.add_type_requirement(v, ValueTypeSet::of_one(ty)); + if let &FnArg::Variable(ref v) = &pred.args[0] { + self.add_type_requirement(v.clone(), ValueTypeSet::of_one(ty)); Ok(()) } else { bail!(ErrorKind::InvalidArgument(pred.operator.clone(), "variable".into(), 0)) diff --git a/query-algebrizer/src/lib.rs b/query-algebrizer/src/lib.rs index 2199bfff..3bf22535 100644 --- a/query-algebrizer/src/lib.rs +++ b/query-algebrizer/src/lib.rs @@ -179,10 +179,8 @@ pub fn algebrize_with_inputs(schema: &Schema, // TODO: integrate default source into pattern processing. // TODO: flesh out the rest of find-into-context. - let where_clauses = parsed.where_clauses; - for where_clause in where_clauses { - cc.apply_clause(schema, where_clause)?; - } + cc.apply_clauses(schema, parsed.where_clauses)?; + cc.expand_column_bindings(); cc.prune_extracted_types(); cc.process_required_types()?; diff --git a/query-algebrizer/src/types.rs b/query-algebrizer/src/types.rs index d692fc3e..e37765dd 100644 --- a/query-algebrizer/src/types.rs +++ b/query-algebrizer/src/types.rs @@ -344,7 +344,7 @@ pub enum ColumnConstraint { } impl ColumnConstraint { - pub fn has_type(value: TableAlias, value_type: ValueType) -> ColumnConstraint { + pub fn has_unit_type(value: TableAlias, value_type: ValueType) -> ColumnConstraint { ColumnConstraint::HasTypes { value, value_types: ValueTypeSet::of_one(value_type), diff --git a/query-translator/tests/translate.rs b/query-translator/tests/translate.rs index 67406686..cc1ae2e9 100644 --- a/query-translator/tests/translate.rs +++ b/query-translator/tests/translate.rs @@ -293,10 +293,10 @@ fn test_type_required_long() { let query = r#"[:find ?x :where [?x _ ?e] [(long ?e)]]"#; let SQLQuery { sql, args } = translate(&schema, query); - assert_eq!(sql, "SELECT DISTINCT `all_datoms00`.e AS `?x` \ - FROM `all_datoms` AS `all_datoms00` \ - WHERE ((`all_datoms00`.value_type_tag = 5 AND \ - typeof(`all_datoms00`.v) = 'integer'))"); + assert_eq!(sql, "SELECT DISTINCT `datoms00`.e AS `?x` \ + FROM `datoms` AS `datoms00` \ + WHERE ((`datoms00`.value_type_tag = 5 AND \ + typeof(`datoms00`.v) = 'integer'))"); assert_eq!(args, vec![]); } @@ -308,10 +308,10 @@ fn test_type_required_double() { let query = r#"[:find ?x :where [?x _ ?e] [(double ?e)]]"#; let SQLQuery { sql, args } = translate(&schema, query); - assert_eq!(sql, "SELECT DISTINCT `all_datoms00`.e AS `?x` \ - FROM `all_datoms` AS `all_datoms00` \ - WHERE ((`all_datoms00`.value_type_tag = 5 AND \ - typeof(`all_datoms00`.v) = 'real'))"); + assert_eq!(sql, "SELECT DISTINCT `datoms00`.e AS `?x` \ + FROM `datoms` AS `datoms00` \ + WHERE ((`datoms00`.value_type_tag = 5 AND \ + typeof(`datoms00`.v) = 'real'))"); assert_eq!(args, vec![]); } @@ -323,29 +323,21 @@ fn test_type_required_boolean() { let query = r#"[:find ?x :where [?x _ ?e] [(boolean ?e)]]"#; let SQLQuery { sql, args } = translate(&schema, query); - assert_eq!(sql, "SELECT DISTINCT `all_datoms00`.e AS `?x` \ - FROM `all_datoms` AS `all_datoms00` \ - WHERE (`all_datoms00`.value_type_tag = 1)"); + assert_eq!(sql, "SELECT DISTINCT `datoms00`.e AS `?x` \ + FROM `datoms` AS `datoms00` \ + WHERE (`datoms00`.value_type_tag = 1)"); assert_eq!(args, vec![]); } #[test] -fn test_type_require_avoids_all_datoms() { +fn test_type_required_string() { let schema = Schema::default(); - // Since the constraint is first, we know we don't need to use all_datoms. - let query = r#"[:find ?x :where [(keyword ?e)] [?x _ ?e]]"#; - let SQLQuery { sql, args } = translate(&schema, query); - - assert_eq!(sql, "SELECT DISTINCT `datoms00`.e AS `?x` \ - FROM `datoms` AS `datoms00` \ - WHERE (`datoms00`.value_type_tag = 13)"); - assert_eq!(args, vec![]); - - // Strings always need to use all_datoms. - let query = r#"[:find ?x :where [(string ?e)] [?x _ ?e]]"#; + + let query = r#"[:find ?x :where [?x _ ?e] [(string ?e)]]"#; let SQLQuery { sql, args } = translate(&schema, query); + // Note: strings should use `all_datoms` and not `datoms`. assert_eq!(sql, "SELECT DISTINCT `all_datoms00`.e AS `?x` \ FROM `all_datoms` AS `all_datoms00` \ WHERE (`all_datoms00`.value_type_tag = 10)");