diff --git a/query-algebrizer/Cargo.toml b/query-algebrizer/Cargo.toml index 4a1dfeb0..71c15f50 100644 --- a/query-algebrizer/Cargo.toml +++ b/query-algebrizer/Cargo.toml @@ -11,3 +11,7 @@ path = "../core" [dependencies.mentat_query] path = "../query" + +# Only for tests. +[dev-dependencies.mentat_query_parser] +path = "../query-parser" \ No newline at end of file diff --git a/query-algebrizer/src/cc.rs b/query-algebrizer/src/cc.rs index a74987d6..f09ecf4e 100644 --- a/query-algebrizer/src/cc.rs +++ b/query-algebrizer/src/cc.rs @@ -8,9 +8,6 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -extern crate mentat_core; -extern crate mentat_query; - use std::fmt::{ Debug, Formatter, @@ -24,7 +21,7 @@ use std::collections::{ use std::collections::btree_map::Entry; -use self::mentat_core::{ +use mentat_core::{ Attribute, Entid, Schema, @@ -32,7 +29,7 @@ use self::mentat_core::{ ValueType, }; -use self::mentat_query::{ +use mentat_query::{ FnArg, NamespacedKeyword, NonIntegerConstant, @@ -43,6 +40,7 @@ use self::mentat_query::{ Predicate, SrcVar, Variable, + WhereClause, }; use errors::{ @@ -62,6 +60,8 @@ use types::{ TableAlias, }; +use validate::validate_or_join; + /// A thing that's capable of aliasing a table name for us. /// This exists so that we can obtain predictable names in tests. pub type TableAliaser = Box TableAlias>; @@ -958,6 +958,27 @@ impl ConjoiningClauses { } } +impl ConjoiningClauses { + // This is here, rather than in `lib.rs`, because it's recursive: `or` can contain `or`, + // and so on. + pub fn apply_clause(&mut self, schema: &Schema, where_clause: WhereClause) -> Result<()> { + match where_clause { + WhereClause::Pattern(p) => { + self.apply_pattern(schema, p); + Ok(()) + }, + WhereClause::Pred(p) => { + self.apply_predicate(schema, p) + }, + WhereClause::OrJoin(o) => { + validate_or_join(&o) + // TODO: apply. + }, + _ => unimplemented!(), + } + } +} + #[cfg(test)] mod testing { use super::*; diff --git a/query-algebrizer/src/errors.rs b/query-algebrizer/src/errors.rs index 63904660..742f54a9 100644 --- a/query-algebrizer/src/errors.rs +++ b/query-algebrizer/src/errors.rs @@ -40,6 +40,12 @@ error_chain! { description("invalid argument") display("invalid argument to {}: expected numeric in position {}.", function, position) } + + NonMatchingVariablesInOrClause { + // TODO: flesh out. + description("non-matching variables in 'or' clause") + display("non-matching variables in 'or' clause") + } } } diff --git a/query-algebrizer/src/lib.rs b/query-algebrizer/src/lib.rs index a5660944..d6d92417 100644 --- a/query-algebrizer/src/lib.rs +++ b/query-algebrizer/src/lib.rs @@ -16,6 +16,7 @@ extern crate mentat_query; mod errors; mod types; +mod validate; mod cc; use mentat_core::{ @@ -74,15 +75,7 @@ pub fn algebrize(schema: &Schema, parsed: FindQuery) -> Result { let mut cc = cc::ConjoiningClauses::default(); let where_clauses = parsed.where_clauses; for where_clause in where_clauses { - match where_clause { - WhereClause::Pattern(p) => { - cc.apply_pattern(schema, p); - }, - WhereClause::Pred(p) => { - cc.apply_predicate(schema, p)?; - }, - _ => unimplemented!(), - } + cc.apply_clause(schema, where_clause)?; } let limit = if parsed.find_spec.is_unit_limited() { Some(1) } else { None }; diff --git a/query-algebrizer/src/validate.rs b/query-algebrizer/src/validate.rs new file mode 100644 index 00000000..5fea33c1 --- /dev/null +++ b/query-algebrizer/src/validate.rs @@ -0,0 +1,229 @@ +// Copyright 2016 Mozilla +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +use std::collections::BTreeSet; + +use mentat_query::{ + ContainsVariables, + OrJoin, + Variable, + WhereClause, + UnifyVars, +}; + +use errors::{ + ErrorKind, + Result, +}; + +/// In an `or` expression, every mentioned var is considered 'free'. +/// In an `or-join` expression, every var in the var list is 'required'. +/// +/// Every extracted variable must be used in the clauses. +/// The extracted var list cannot be empty. +/// +/// The original Datomic docs are poorly worded: +/// +/// "All clauses used in an or clause must use the same set of variables, which will unify with the +/// surrounding query. This includes both the arguments to nested expression clauses as well as any +/// bindings made by nested function expressions. Datomic will attempt to push the or clause down +/// until all necessary variables are bound, and will throw an exception if that is not possible." +/// +/// What this really means is: each pattern in the `or-join` clause must use the var list and unify +/// with the surrounding query. It does not mean that each leg must have the same set of vars. +/// +/// An `or` pattern must, because the set of vars is defined as every var mentioned in any clause, +/// so naturally they must all be the same. +/// +/// "As with rules, src-vars are not currently supported within the clauses of or, but are supported +/// on the or clause as a whole at top level." +pub fn validate_or_join(or_join: &OrJoin) -> Result<()> { + // Grab our mentioned variables and ensure that the rules are followed. + match or_join.unify_vars { + UnifyVars::Implicit => { + // Each 'leg' must have the same variable set. + if or_join.clauses.len() < 2 { + Ok(()) + } else { + let mut clauses = or_join.clauses.iter(); + let template = clauses.next().unwrap().collect_mentioned_variables(); + for clause in clauses { + if template != clause.collect_mentioned_variables() { + bail!(ErrorKind::NonMatchingVariablesInOrClause); + } + } + Ok(()) + } + }, + UnifyVars::Explicit(ref vars) => { + // Each leg must use the joined vars. + let var_set: BTreeSet = vars.iter().cloned().collect(); + for clause in &or_join.clauses { + if !var_set.is_subset(&clause.collect_mentioned_variables()) { + bail!(ErrorKind::NonMatchingVariablesInOrClause); + } + } + Ok(()) + }, + } +} + +#[cfg(test)] +mod tests { + extern crate mentat_core; + extern crate mentat_query; + extern crate mentat_query_parser; + + use self::mentat_query::{ + FindQuery, + NamespacedKeyword, + OrWhereClause, + Pattern, + PatternNonValuePlace, + PatternValuePlace, + PlainSymbol, + SrcVar, + UnifyVars, + Variable, + WhereClause, + }; + + use self::mentat_query_parser::parse_find_string; + + use super::validate_or_join; + + /// Tests that the top-level form is a valid `or`, returning the clauses. + fn valid_or_join(parsed: FindQuery, expected_unify: UnifyVars) -> Vec { + let mut wheres = parsed.where_clauses.into_iter(); + + // There's only one. + let clause = wheres.next().unwrap(); + assert_eq!(None, wheres.next()); + + match clause { + WhereClause::OrJoin(or_join) => { + // It's valid: the variables are the same in each branch. + assert_eq!((), validate_or_join(&or_join).unwrap()); + assert_eq!(expected_unify, or_join.unify_vars); + or_join.clauses + }, + _ => panic!(), + } + } + + /// Test that an `or` is valid if all of its arms refer to the same variables. + #[test] + fn test_success_or() { + let query = r#"[:find [?artist ...] + :where (or [?artist :artist/type :artist.type/group] + (and [?artist :artist/type :artist.type/person] + [?artist :artist/gender :artist.gender/female]))]"#; + let parsed = parse_find_string(query).expect("expected successful parse"); + let clauses = valid_or_join(parsed, UnifyVars::Implicit); + + // Let's do some detailed parse checks. + let mut arms = clauses.into_iter(); + match (arms.next(), arms.next(), arms.next()) { + (Some(left), Some(right), None) => { + assert_eq!( + left, + OrWhereClause::Clause(WhereClause::Pattern(Pattern { + source: None, + entity: PatternNonValuePlace::Variable(Variable(PlainSymbol::new("?artist"))), + attribute: PatternNonValuePlace::Ident(NamespacedKeyword::new("artist", "type")), + value: PatternValuePlace::IdentOrKeyword(NamespacedKeyword::new("artist.type", "group")), + tx: PatternNonValuePlace::Placeholder, + }))); + assert_eq!( + right, + OrWhereClause::And( + vec![ + WhereClause::Pattern(Pattern { + source: None, + entity: PatternNonValuePlace::Variable(Variable(PlainSymbol::new("?artist"))), + attribute: PatternNonValuePlace::Ident(NamespacedKeyword::new("artist", "type")), + value: PatternValuePlace::IdentOrKeyword(NamespacedKeyword::new("artist.type", "person")), + tx: PatternNonValuePlace::Placeholder, + }), + WhereClause::Pattern(Pattern { + source: None, + entity: PatternNonValuePlace::Variable(Variable(PlainSymbol::new("?artist"))), + attribute: PatternNonValuePlace::Ident(NamespacedKeyword::new("artist", "gender")), + value: PatternValuePlace::IdentOrKeyword(NamespacedKeyword::new("artist.gender", "female")), + tx: PatternNonValuePlace::Placeholder, + }), + ])); + }, + _ => panic!(), + }; + } + + /// Test that an `or` with differing variable sets in each arm will fail to validate. + #[test] + fn test_invalid_implicit_or() { + let query = r#"[:find [?artist ...] + :where (or [?artist :artist/type :artist.type/group] + [?artist :artist/type ?type])]"#; + let parsed = parse_find_string(query).expect("expected successful parse"); + match parsed.where_clauses.into_iter().next().expect("expected at least one clause") { + WhereClause::OrJoin(or_join) => assert!(validate_or_join(&or_join).is_err()), + _ => panic!(), + } + } + + /// Test that two arms of an `or-join` can contain different variables if they both + /// contain the required `or-join` list. + #[test] + fn test_success_differing_or_join() { + let query = r#"[:find [?artist ...] + :where (or-join [?artist] + [?artist :artist/type :artist.type/group] + (and [?artist :artist/type ?type] + [?type :artist/role :artist.role/parody]))]"#; + let parsed = parse_find_string(query).expect("expected successful parse"); + let clauses = valid_or_join(parsed, UnifyVars::Explicit(vec![Variable(PlainSymbol::new("?artist"))])); + + // Let's do some detailed parse checks. + let mut arms = clauses.into_iter(); + match (arms.next(), arms.next(), arms.next()) { + (Some(left), Some(right), None) => { + assert_eq!( + left, + OrWhereClause::Clause(WhereClause::Pattern(Pattern { + source: None, + entity: PatternNonValuePlace::Variable(Variable(PlainSymbol::new("?artist"))), + attribute: PatternNonValuePlace::Ident(NamespacedKeyword::new("artist", "type")), + value: PatternValuePlace::IdentOrKeyword(NamespacedKeyword::new("artist.type", "group")), + tx: PatternNonValuePlace::Placeholder, + }))); + assert_eq!( + right, + OrWhereClause::And( + vec![ + WhereClause::Pattern(Pattern { + source: None, + entity: PatternNonValuePlace::Variable(Variable(PlainSymbol::new("?artist"))), + attribute: PatternNonValuePlace::Ident(NamespacedKeyword::new("artist", "type")), + value: PatternValuePlace::Variable(Variable(PlainSymbol::new("?type"))), + tx: PatternNonValuePlace::Placeholder, + }), + WhereClause::Pattern(Pattern { + source: None, + entity: PatternNonValuePlace::Variable(Variable(PlainSymbol::new("?type"))), + attribute: PatternNonValuePlace::Ident(NamespacedKeyword::new("artist", "role")), + value: PatternValuePlace::IdentOrKeyword(NamespacedKeyword::new("artist.role", "parody")), + tx: PatternNonValuePlace::Placeholder, + }), + ])); + }, + _ => panic!(), + }; + } +} \ No newline at end of file diff --git a/query/src/lib.rs b/query/src/lib.rs index 3bc0eaef..dc3126d8 100644 --- a/query/src/lib.rs +++ b/query/src/lib.rs @@ -33,6 +33,7 @@ extern crate edn; extern crate mentat_core; +use std::collections::BTreeSet; use std::fmt; use edn::{BigInt, OrderedFloat}; pub use edn::{NamespacedKeyword, PlainSymbol}; @@ -538,3 +539,79 @@ pub struct FindQuery { pub where_clauses: Vec, // TODO: in_rules; } + +pub trait ContainsVariables { + fn accumulate_mentioned_variables(&self, acc: &mut BTreeSet); + fn collect_mentioned_variables(&self) -> BTreeSet { + let mut out = BTreeSet::new(); + self.accumulate_mentioned_variables(&mut out); + out + } +} + +impl ContainsVariables for WhereClause { + fn accumulate_mentioned_variables(&self, acc: &mut BTreeSet) { + use WhereClause::*; + match self { + &OrJoin(ref o) => o.accumulate_mentioned_variables(acc), + &Pred(ref p) => p.accumulate_mentioned_variables(acc), + &Pattern(ref p) => p.accumulate_mentioned_variables(acc), + &Not => (), + &NotJoin => (), + &WhereFn => (), + &RuleExpr => (), + } + } +} + +impl ContainsVariables for OrWhereClause { + fn accumulate_mentioned_variables(&self, acc: &mut BTreeSet) { + use OrWhereClause::*; + match self { + &And(ref clauses) => for clause in clauses { clause.accumulate_mentioned_variables(acc) }, + &Clause(ref clause) => clause.accumulate_mentioned_variables(acc), + } + } +} + +impl ContainsVariables for OrJoin { + fn accumulate_mentioned_variables(&self, acc: &mut BTreeSet) { + for clause in &self.clauses { + clause.accumulate_mentioned_variables(acc); + } + } +} + +impl ContainsVariables for Predicate { + fn accumulate_mentioned_variables(&self, acc: &mut BTreeSet) { + for arg in &self.args { + if let &FnArg::Variable(ref v) = arg { + acc_ref(acc, v) + } + } + } +} + +fn acc_ref(acc: &mut BTreeSet, v: &T) { + // Roll on, reference entries! + if !acc.contains(v) { + acc.insert(v.clone()); + } +} + +impl ContainsVariables for Pattern { + fn accumulate_mentioned_variables(&self, acc: &mut BTreeSet) { + if let PatternNonValuePlace::Variable(ref v) = self.entity { + acc_ref(acc, v) + } + if let PatternNonValuePlace::Variable(ref v) = self.attribute { + acc_ref(acc, v) + } + if let PatternValuePlace::Variable(ref v) = self.value { + acc_ref(acc, v) + } + if let PatternNonValuePlace::Variable(ref v) = self.tx { + acc_ref(acc, v) + } + } +} \ No newline at end of file