From 7024978517435c8124ff45dda7e643d0da2d3bb2 Mon Sep 17 00:00:00 2001 From: Richard Newman Date: Mon, 20 Mar 2017 14:11:32 +0000 Subject: [PATCH] Track ever-shrinking sets of types for variables, not a single type. (#381) r=nalexander --- query-algebrizer/src/cc.rs | 248 +++++++++++++++++++++++++--- query-algebrizer/src/errors.rs | 5 + query-projector/src/lib.rs | 2 +- query-translator/tests/translate.rs | 8 +- 4 files changed, 237 insertions(+), 26 deletions(-) diff --git a/query-algebrizer/src/cc.rs b/query-algebrizer/src/cc.rs index a37db09a..fdb35b14 100644 --- a/query-algebrizer/src/cc.rs +++ b/query-algebrizer/src/cc.rs @@ -19,8 +19,11 @@ use std::fmt::{ use std::collections::{ BTreeMap, BTreeSet, + HashSet, }; +use std::collections::btree_map::Entry; + use self::mentat_core::{ Attribute, Entid, @@ -36,6 +39,7 @@ use self::mentat_query::{ Pattern, PatternNonValuePlace, PatternValuePlace, + PlainSymbol, Predicate, SrcVar, Variable, @@ -83,6 +87,12 @@ impl OptionEffect for Option { } } +fn unit_type_set(t: ValueType) -> HashSet { + let mut s = HashSet::with_capacity(1); + s.insert(t); + s +} + /// A `ConjoiningClauses` (CC) is a collection of clauses that are combined with `JOIN`. /// The topmost form in a query is a `ConjoiningClauses`. /// @@ -143,7 +153,7 @@ pub struct ConjoiningClauses { /// A map from var to type. Whenever a var maps unambiguously to two different types, it cannot /// yield results, so we don't represent that case here. If a var isn't present in the map, it /// means that its type is not known in advance. - pub known_types: BTreeMap, + pub known_types: BTreeMap>, /// A mapping, similar to `column_bindings`, but used to pull type tags out of the store at runtime. /// If a var isn't present in `known_types`, it should be present here. @@ -153,7 +163,8 @@ pub struct ConjoiningClauses { #[derive(PartialEq)] pub enum EmptyBecause { // Var, existing, desired. - TypeMismatch(Variable, ValueType, ValueType), + TypeMismatch(Variable, HashSet, ValueType), + NonNumericArgument, UnresolvedIdent(NamespacedKeyword), InvalidAttributeIdent(NamespacedKeyword), InvalidAttributeEntid(Entid), @@ -170,6 +181,9 @@ impl Debug for EmptyBecause { write!(f, "Type mismatch: {:?} can't be {:?}, because it's already {:?}", var, desired, existing) }, + &NonNumericArgument => { + write!(f, "Non-numeric argument in numeric place") + }, &UnresolvedIdent(ref kw) => { write!(f, "Couldn't resolve keyword {}", kw) }, @@ -237,7 +251,7 @@ impl ConjoiningClauses { // Pre-fill our type mappings with the types of the input bindings. cc.known_types .extend(cc.value_bindings.iter() - .map(|(k, v)| (k.clone(), v.value_type()))); + .map(|(k, v)| (k.clone(), unit_type_set(v.value_type())))); cc } } @@ -247,6 +261,17 @@ impl ConjoiningClauses { self.value_bindings.get(var).cloned() } + /// Return a single `ValueType` if the given variable is known to have a precise type. + /// Returns `None` if the type of the variable is unknown. + /// Returns `None` if the type of the variable is known but not precise -- "double + /// or integer" isn't good enough. + pub fn known_type(&self, var: &Variable) -> Option { + match self.known_types.get(var) { + Some(types) if types.len() == 1 => types.iter().next().cloned(), + _ => None, + } + } + pub fn bind_column_to_var(&mut self, schema: &Schema, table: TableAlias, column: DatomsColumn, var: Variable) { // Do we have an external binding for this? if let Some(bound_val) = self.bound_value(&var) { @@ -291,7 +316,7 @@ impl ConjoiningClauses { let needs_type_extraction = !late_binding && // Never need to extract for bound vars. column == DatomsColumn::Value && // Never need to extract types for refs. - !self.known_types.contains_key(&var) && // We know the type! + self.known_type(&var).is_none() && // Don't need to extract if we know a single type. !self.extracted_types.contains_key(&var); // We're already extracting the type. let alias = QualifiedAlias(table, column); @@ -322,6 +347,36 @@ impl ConjoiningClauses { QueryValue::PrimitiveLong(value))) } + /// Mark the given value as one of the set of numeric types. + fn constrain_var_to_numeric(&mut self, variable: Variable) { + let mut numeric_types = HashSet::with_capacity(2); + numeric_types.insert(ValueType::Double); + numeric_types.insert(ValueType::Long); + + let entry = self.known_types.entry(variable); + match entry { + Entry::Vacant(vacant) => { + vacant.insert(numeric_types); + }, + Entry::Occupied(mut occupied) => { + let narrowed: HashSet = numeric_types.intersection(occupied.get()).cloned().collect(); + match narrowed.len() { + 0 => { + // TODO: can't borrow as mutable more than once! + //self.mark_known_empty(EmptyBecause::TypeMismatch(occupied.key().clone(), occupied.get().clone(), ValueType::Double)); // I know… + }, + 1 => { + // Hooray! + self.extracted_types.remove(occupied.key()); + }, + _ => { + }, + }; + occupied.insert(narrowed); + }, + } + } + /// Constrains the var if there's no existing type. /// Marks as known-empty if it's impossible for this type to apply because there's a conflicting /// type already known. @@ -343,13 +398,17 @@ impl ConjoiningClauses { // spot it here. if let Some(existing) = self.known_types.get(&variable).cloned() { // If so, the types must match. - if existing != this_type { + if !existing.contains(&this_type) { self.mark_known_empty(EmptyBecause::TypeMismatch(variable, existing, this_type)); + } else { + if existing.len() > 1 { + // Narrow. + self.known_types.insert(variable, unit_type_set(this_type)); + } } } else { // If not, record the one we just determined. - self.known_types.insert(variable, this_type); - + self.known_types.insert(variable, unit_type_set(this_type)); } } @@ -424,8 +483,14 @@ impl ConjoiningClauses { match value { // TODO: see if the variable is projected, aggregated, or compared elsewhere in // the query. If it's not, we don't need to use all_datoms here. - &PatternValuePlace::Variable(_) => - DatomsTable::AllDatoms, // TODO: check types. + &PatternValuePlace::Variable(ref v) => { + // Do we know that this variable can't be a string? If so, we don't need + // AllDatoms. None or String means it could be or definitely is. + match self.known_types.get(v).map(|types| types.contains(&ValueType::String)) { + Some(false) => DatomsTable::Datoms, + _ => DatomsTable::AllDatoms, + } + } &PatternValuePlace::Constant(NonIntegerConstant::Text(_)) => DatomsTable::AllDatoms, _ => @@ -792,6 +857,38 @@ impl ConjoiningClauses { /// Take a function argument and turn it into a `QueryValue` suitable for use in a concrete /// constraint. + /// Additionally, do two things: + /// - Mark the pattern as known-empty if any argument is known non-numeric. + /// - Mark any variables encountered as numeric. + fn resolve_numeric_argument(&mut self, function: &PlainSymbol, position: usize, arg: FnArg) -> Result { + use self::FnArg::*; + match arg { + FnArg::Variable(var) => { + self.constrain_var_to_numeric(var.clone()); + self.column_bindings + .get(&var) + .and_then(|cols| cols.first().map(|col| QueryValue::Column(col.clone()))) + .ok_or_else(|| Error::from_kind(ErrorKind::UnboundVariable(var))) + }, + // Can't be an entid. + EntidOrInteger(i) => Ok(QueryValue::TypedValue(TypedValue::Long(i))), + Ident(_) | + SrcVar(_) | + Constant(NonIntegerConstant::Boolean(_)) | + Constant(NonIntegerConstant::Text(_)) | + Constant(NonIntegerConstant::BigInteger(_)) => { + self.mark_known_empty(EmptyBecause::NonNumericArgument); + // We use Double because… well, we only have one slot! + bail!(ErrorKind::NonNumericArgument(function.clone(), position)); + }, + Constant(NonIntegerConstant::Float(f)) => Ok(QueryValue::TypedValue(TypedValue::Double(f))), + } + } + + + /// Take a function argument and turn it into a `QueryValue` suitable for use in a concrete + /// constraint. + #[allow(dead_code)] fn resolve_argument(&self, arg: FnArg) -> Result { use self::FnArg::*; match arg { @@ -843,8 +940,8 @@ impl ConjoiningClauses { // Any variables that aren't bound by this point in the linear processing of clauses will // cause the application of the predicate to fail. let mut args = predicate.args.into_iter(); - let left = self.resolve_argument(args.next().unwrap())?; - let right = self.resolve_argument(args.next().unwrap())?; + let left = self.resolve_numeric_argument(&predicate.operator, 0, args.next().unwrap())?; + let right = self.resolve_numeric_argument(&predicate.operator, 1, args.next().unwrap())?; // These arguments must be variables or numeric constants. // TODO: generalize argument resolution and validation for different kinds of predicates: @@ -944,7 +1041,7 @@ mod testing { assert_eq!(cc.from, vec![SourceAlias(DatomsTable::Datoms, "datoms00".to_string())]); // ?x must be a ref. - assert_eq!(cc.known_types.get(&x).unwrap(), &ValueType::Ref); + assert_eq!(cc.known_type(&x).unwrap(), ValueType::Ref); // ?x is bound to datoms0.e. assert_eq!(cc.column_bindings.get(&x).unwrap(), &vec![d0_e.clone()]); @@ -982,7 +1079,7 @@ mod testing { assert_eq!(cc.from, vec![SourceAlias(DatomsTable::Datoms, "datoms00".to_string())]); // ?x must be a ref. - assert_eq!(cc.known_types.get(&x).unwrap(), &ValueType::Ref); + assert_eq!(cc.known_type(&x).unwrap(), ValueType::Ref); // ?x is bound to datoms0.e. assert_eq!(cc.column_bindings.get(&x).unwrap(), &vec![d0_e.clone()]); @@ -1031,7 +1128,7 @@ mod testing { assert_eq!(cc.from, vec![SourceAlias(DatomsTable::Datoms, "datoms00".to_string())]); // ?x must be a ref. - assert_eq!(cc.known_types.get(&x).unwrap(), &ValueType::Ref); + assert_eq!(cc.known_type(&x).unwrap(), ValueType::Ref); // ?x is bound to datoms0.e. assert_eq!(cc.column_bindings.get(&x).unwrap(), &vec![d0_e.clone()]); @@ -1091,7 +1188,7 @@ mod testing { assert_eq!(cc.from, vec![SourceAlias(DatomsTable::AllDatoms, "all_datoms00".to_string())]); // ?x must be a ref. - assert_eq!(cc.known_types.get(&x).unwrap(), &ValueType::Ref); + assert_eq!(cc.known_type(&x).unwrap(), ValueType::Ref); // ?x is bound to datoms0.e. assert_eq!(cc.column_bindings.get(&x).unwrap(), &vec![d0_e.clone()]); @@ -1122,7 +1219,7 @@ mod testing { assert_eq!(cc.from, vec![SourceAlias(DatomsTable::AllDatoms, "all_datoms00".to_string())]); // ?x must be a ref. - assert_eq!(cc.known_types.get(&x).unwrap(), &ValueType::Ref); + assert_eq!(cc.known_type(&x).unwrap(), ValueType::Ref); // ?x is bound to datoms0.e. assert_eq!(cc.column_bindings.get(&x).unwrap(), &vec![d0_e.clone()]); @@ -1188,7 +1285,7 @@ mod testing { ]); // ?x must be a ref. - assert_eq!(cc.known_types.get(&x).unwrap(), &ValueType::Ref); + assert_eq!(cc.known_type(&x).unwrap(), ValueType::Ref); // ?x is bound to datoms0.e and datoms1.e. assert_eq!(cc.column_bindings.get(&x).unwrap(), @@ -1316,6 +1413,119 @@ mod testing { assert!(cc.is_known_empty); } + #[test] + /// Apply two patterns: a pattern and a numeric predicate. + /// Verify that after application of the predicate we know that the value + /// must be numeric. + fn test_apply_numeric_predicate() { + let mut cc = ConjoiningClauses::default(); + let mut schema = Schema::default(); + + associate_ident(&mut schema, NamespacedKeyword::new("foo", "bar"), 99); + add_attribute(&mut schema, 99, Attribute { + value_type: ValueType::Long, + ..Default::default() + }); + + let x = Variable(PlainSymbol::new("?x")); + let y = Variable(PlainSymbol::new("?y")); + cc.apply_pattern(&schema, Pattern { + source: None, + entity: PatternNonValuePlace::Variable(x.clone()), + attribute: PatternNonValuePlace::Placeholder, + value: PatternValuePlace::Variable(y.clone()), + tx: PatternNonValuePlace::Placeholder, + }); + assert!(!cc.is_known_empty); + + let op = PlainSymbol::new("<"); + let comp = NumericComparison::from_datalog_operator(op.plain_name()).unwrap(); + assert!(cc.apply_numeric_predicate(&schema, comp, Predicate { + operator: op, + args: vec![ + FnArg::Variable(Variable(PlainSymbol::new("?y"))), FnArg::EntidOrInteger(10), + ]}).is_ok()); + + assert!(!cc.is_known_empty); + + // Finally, expand column bindings to get the overlaps for ?x. + cc.expand_column_bindings(); + assert!(!cc.is_known_empty); + + // After processing those two clauses, we know that ?y must be numeric, but not exactly + // which type it must be. + assert_eq!(None, cc.known_type(&y)); // Not just one. + let expected: HashSet = vec![ValueType::Double, ValueType::Long].into_iter().collect(); + assert_eq!(Some(&expected), cc.known_types.get(&y)); + + let clauses = cc.wheres; + assert_eq!(clauses.len(), 1); + assert_eq!(clauses[0], ColumnConstraint::NumericInequality { + operator: NumericComparison::LessThan, + left: QueryValue::Column(cc.column_bindings.get(&y).unwrap()[0].clone()), + right: QueryValue::TypedValue(TypedValue::Long(10)), + }); + } + + #[test] + /// Apply three patterns: an unbound pattern to establish a value var, + /// a predicate to constrain the val to numeric types, and a third pattern to conflict with the + /// numeric types and cause the pattern to fail. + fn test_apply_conflict_with_numeric_range() { + let mut cc = ConjoiningClauses::default(); + let mut schema = Schema::default(); + + associate_ident(&mut schema, NamespacedKeyword::new("foo", "bar"), 99); + associate_ident(&mut schema, NamespacedKeyword::new("foo", "roz"), 98); + add_attribute(&mut schema, 99, Attribute { + value_type: ValueType::Long, + ..Default::default() + }); + add_attribute(&mut schema, 98, Attribute { + value_type: ValueType::String, + unique: Some(Unique::Identity), + ..Default::default() + }); + + let x = Variable(PlainSymbol::new("?x")); + let y = Variable(PlainSymbol::new("?y")); + cc.apply_pattern(&schema, Pattern { + source: None, + entity: PatternNonValuePlace::Variable(x.clone()), + attribute: PatternNonValuePlace::Placeholder, + value: PatternValuePlace::Variable(y.clone()), + tx: PatternNonValuePlace::Placeholder, + }); + assert!(!cc.is_known_empty); + + let op = PlainSymbol::new(">="); + let comp = NumericComparison::from_datalog_operator(op.plain_name()).unwrap(); + assert!(cc.apply_numeric_predicate(&schema, comp, Predicate { + operator: op, + args: vec![ + FnArg::Variable(Variable(PlainSymbol::new("?y"))), FnArg::EntidOrInteger(10), + ]}).is_ok()); + + assert!(!cc.is_known_empty); + cc.apply_pattern(&schema, Pattern { + source: None, + entity: PatternNonValuePlace::Variable(x.clone()), + attribute: PatternNonValuePlace::Ident(NamespacedKeyword::new("foo", "roz")), + value: PatternValuePlace::Variable(y.clone()), + tx: PatternNonValuePlace::Placeholder, + }); + + // Finally, expand column bindings to get the overlaps for ?x. + cc.expand_column_bindings(); + + assert!(cc.is_known_empty); + assert_eq!(cc.empty_because.unwrap(), + EmptyBecause::TypeMismatch(y.clone(), + vec![ValueType::Double, ValueType::Long].into_iter() + .collect(), + ValueType::String)); + } + #[test] /// Apply two patterns with differently typed attributes, but sharing a variable in the value /// place. No value can bind to a variable and match both types, so the CC is known to return @@ -1358,7 +1568,7 @@ mod testing { assert!(cc.is_known_empty); assert_eq!(cc.empty_because.unwrap(), - EmptyBecause::TypeMismatch(y.clone(), ValueType::String, ValueType::Boolean)); + EmptyBecause::TypeMismatch(y.clone(), unit_type_set(ValueType::String), ValueType::Boolean)); } #[test] @@ -1396,6 +1606,6 @@ mod testing { assert!(cc.is_known_empty); assert_eq!(cc.empty_because.unwrap(), - EmptyBecause::TypeMismatch(x.clone(), ValueType::Ref, ValueType::Boolean)); + EmptyBecause::TypeMismatch(x.clone(), unit_type_set(ValueType::Ref), ValueType::Boolean)); } } diff --git a/query-algebrizer/src/errors.rs b/query-algebrizer/src/errors.rs index 1db9b371..63904660 100644 --- a/query-algebrizer/src/errors.rs +++ b/query-algebrizer/src/errors.rs @@ -35,6 +35,11 @@ error_chain! { description("unbound variable in function call") display("unbound variable: {}", var.0) } + + NonNumericArgument(function: PlainSymbol, position: usize) { + description("invalid argument") + display("invalid argument to {}: expected numeric in position {}.", function, position) + } } } diff --git a/query-projector/src/lib.rs b/query-projector/src/lib.rs index fa0b4e3b..bdfcb2e7 100644 --- a/query-projector/src/lib.rs +++ b/query-projector/src/lib.rs @@ -219,7 +219,7 @@ fn project_elements<'a, I: IntoIterator>( let qa = columns[0].clone(); let name = column_name(var); - if let Some(t) = query.cc.known_types.get(var) { + if let Some(t) = query.cc.known_type(var) { cols.push(ProjectedColumn(ColumnOrExpression::Column(qa), name)); let tag = t.value_type_tag(); templates.push(TypedIndex::Known(i, tag)); diff --git a/query-translator/tests/translate.rs b/query-translator/tests/translate.rs index ac7b1bb7..145fddfd 100644 --- a/query-translator/tests/translate.rs +++ b/query-translator/tests/translate.rs @@ -186,9 +186,8 @@ fn test_numeric_less_than_unknown_attribute() { let input = r#"[:find ?x :where [?x _ ?y] [(< ?y 10)]]"#; let SQLQuery { sql, args } = translate(&schema, input, None); - // TODO: we don't infer numeric types from numeric predicates, because the _SQL_ type code - // is a single value (5), but the Datalog types are a set (Double and Long). - // When we do, this will correctly use `datoms` instead of `all_datoms`. + // Although we infer numericness from numeric predicates, we've already assigned a table to the + // first pattern, and so this is _still_ `all_datoms`. assert_eq!(sql, "SELECT `all_datoms00`.e AS `?x` FROM `all_datoms` AS `all_datoms00` WHERE `all_datoms00`.v < 10"); assert_eq!(args, vec![]); } @@ -202,14 +201,12 @@ fn test_numeric_gte_known_attribute() { ..Default::default() }); - let input = r#"[:find ?x :where [?x :foo/bar ?y] [(>= ?y 12.9)]]"#; let SQLQuery { sql, args } = translate(&schema, input, None); assert_eq!(sql, "SELECT `datoms00`.e AS `?x` FROM `datoms` AS `datoms00` WHERE `datoms00`.a = 99 AND `datoms00`.v >= 12.9"); assert_eq!(args, vec![]); } - #[test] fn test_numeric_not_equals_known_attribute() { let mut schema = Schema::default(); @@ -219,7 +216,6 @@ fn test_numeric_not_equals_known_attribute() { ..Default::default() }); - let input = r#"[:find ?x :where [?x :foo/bar ?y] [(!= ?y 12)]]"#; let SQLQuery { sql, args } = translate(&schema, input, None); assert_eq!(sql, "SELECT `datoms00`.e AS `?x` FROM `datoms` AS `datoms00` WHERE `datoms00`.a = 99 AND `datoms00`.v <> 12");