diff --git a/query-parser/src/parse.rs b/query-parser/src/parse.rs index f5b6a8d9..beee7a75 100644 --- a/query-parser/src/parse.rs +++ b/query-parser/src/parse.rs @@ -42,6 +42,8 @@ use self::mentat_query::{ FromValue, OrJoin, OrWhereClause, + NotJoin, + WhereNotClause, Pattern, PatternNonValuePlace, PatternValuePlace, @@ -135,6 +137,10 @@ def_matches_plain_symbol!(Where, or, "or"); def_matches_plain_symbol!(Where, or_join, "or-join"); +def_matches_plain_symbol!(Where, or_join, "not"); + +def_matches_plain_symbol!(Where, or_join, "not-join"); + def_parser!(Where, rule_vars, Vec, { seq() .of_exactly(many1(Query::variable())) @@ -182,6 +188,53 @@ def_parser!(Where, or_join_clause, WhereClause, { })) }); +def_value_parser_fn!(Where, not_pattern_clause, WhereNotClause, input, { + Where::clause().map(|clause| WhereNotClause::Clause(clause)).parse_stream(input) +}); + +def_value_parser_fn!(Where, where_not_clause, WhereNotClause, input, { + choice([Where::not_pattern_clause()]).parse_stream(input) +}); + +def_value_parser_fn!(Where, not_clause, WhereClause, input, { + satisfy_map(|x: edn::Value| { + seq(x).and_then(|items| { + let mut p = Where::not() + .with(many1(Where::where_not_clause())) + .skip(eof()) + .map(|clauses| { + WhereClause::NotJoin( + NotJoin { + unify_vars: UnifyVars::Implicit, + clauses: clauses, + }) + }); + let r: ParseResult = p.parse_lazy(&items[..]).into(); + Query::to_parsed_value(r) + }) + }).parse_stream(input) +}); + +def_value_parser_fn!(Where, not_join_clause, WhereClause, input, { + satisfy_map(|x: edn::Value| { + seq(x).and_then(|items| { + let mut p = Where::not_join() + .with(Where::rule_vars()) + .and(many1(Where::where_not_clause())) + .skip(eof()) + .map(|(vars, clauses)| { + WhereClause::NotJoin( + NotJoin { + unify_vars: UnifyVars::Explicit(vars), + clauses: clauses, + }) + }); + let r: ParseResult = p.parse_lazy(&items[..]).into(); + Query::to_parsed_value(r) + }) + }).parse_stream(input) +}); + /// A vector containing just a parenthesized filter expression. def_parser!(Where, pred, WhereClause, { // Accept either a nested list or a nested vector here: @@ -246,6 +299,8 @@ def_parser!(Where, clause, WhereClause, { // We don't yet handle source vars. try(Where::or_join_clause()), try(Where::or_clause()), + try(Where::not_join_clause()), + try(Where::not_clause()), try(Where::pred()), ]) @@ -548,6 +603,59 @@ mod test { })); } + #[test] + fn test_not() { + let oj = edn::PlainSymbol::new("not"); + let e = edn::PlainSymbol::new("?e"); + let a = edn::PlainSymbol::new("?a"); + let v = edn::PlainSymbol::new("?v"); + let input = [edn::Value::List( + vec![edn::Value::PlainSymbol(oj), + edn::Value::Vector(vec![edn::Value::PlainSymbol(e.clone()), + edn::Value::PlainSymbol(a.clone()), + edn::Value::PlainSymbol(v.clone())])].into_iter().collect())]; + assert_parses_to!(Where::not_clause, input, + WhereClause::NotJoin( + NotJoin { + unify_vars: UnifyVars::Implicit, + clauses: vec![WhereNotClause::Clause( + WhereClause::Pattern(Pattern { + source: None, + entity: PatternNonValuePlace::Variable(variable(e)), + attribute: PatternNonValuePlace::Variable(variable(a)), + value: PatternValuePlace::Variable(variable(v)), + tx: PatternNonValuePlace::Placeholder, + }))], + })); + } + + #[test] + fn test_not_join() { + let oj = edn::PlainSymbol::new("not-join"); + let e = edn::PlainSymbol::new("?e"); + let a = edn::PlainSymbol::new("?a"); + let v = edn::PlainSymbol::new("?v"); + let input = [edn::Value::List( + vec![edn::Value::PlainSymbol(oj), + edn::Value::Vector(vec![edn::Value::PlainSymbol(e.clone())]), + edn::Value::Vector(vec![edn::Value::PlainSymbol(e.clone()), + edn::Value::PlainSymbol(a.clone()), + edn::Value::PlainSymbol(v.clone())])].into_iter().collect())]; + assert_parses_to!(Where::not_join_clause, input, + WhereClause::NotJoin( + NotJoin { + unify_vars: UnifyVars::Explicit(vec![variable(e.clone())]), + clauses: vec![WhereNotClause::Clause( + WhereClause::Pattern(Pattern { + source: None, + entity: PatternNonValuePlace::Variable(variable(e)), + attribute: PatternNonValuePlace::Variable(variable(a)), + value: PatternValuePlace::Variable(variable(v)), + tx: PatternNonValuePlace::Placeholder, + }))], + })); + } + #[test] fn test_find_sp_variable() { let sym = edn::PlainSymbol::new("?x"); diff --git a/query/src/lib.rs b/query/src/lib.rs index b9037117..ebea7376 100644 --- a/query/src/lib.rs +++ b/query/src/lib.rs @@ -570,11 +570,30 @@ pub struct OrJoin { pub clauses: Vec, } +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum WhereNotClause { + Clause(WhereClause), +} + +impl WhereNotClause { + pub fn is_pattern_or_patterns(&self) -> bool { + match self { + &WhereNotClause::Clause(WhereClause::Pattern(_)) => true, + _ => false, + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct NotJoin { + pub unify_vars: UnifyVars, + pub clauses: Vec, +} + #[allow(dead_code)] #[derive(Clone, Debug, Eq, PartialEq)] pub enum WhereClause { - Not, - NotJoin, + NotJoin(NotJoin), OrJoin(OrJoin), Pred(Predicate), WhereFn, @@ -628,8 +647,7 @@ impl ContainsVariables for WhereClause { &OrJoin(ref o) => o.accumulate_mentioned_variables(acc), &Pred(ref p) => p.accumulate_mentioned_variables(acc), &Pattern(ref p) => p.accumulate_mentioned_variables(acc), - &Not => (), - &NotJoin => (), + &NotJoin(ref n) => n.accumulate_mentioned_variables(acc), &WhereFn => (), &RuleExpr => (), } @@ -654,6 +672,23 @@ impl ContainsVariables for OrJoin { } } +impl ContainsVariables for NotJoin { + fn accumulate_mentioned_variables(&self, acc: &mut BTreeSet) { + for clause in &self.clauses { + clause.accumulate_mentioned_variables(acc); + } + } +} + +impl ContainsVariables for WhereNotClause { + fn accumulate_mentioned_variables(&self, acc: &mut BTreeSet) { + use WhereNotClause::*; + match self { + &Clause(ref clause) => clause.accumulate_mentioned_variables(acc), + } + } +} + impl ContainsVariables for Predicate { fn accumulate_mentioned_variables(&self, acc: &mut BTreeSet) { for arg in &self.args {