diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -1071,6 +1071,13 @@
 LogicalResult getRelationFromMap(const AffineValueMap &map,
                                  FlatAffineRelation &rel);
 
+/// Parses `str` as a single IntegerSet in `context` and converts it to
+/// FlatAffineConstraints. Returns failure if `str` is not a valid IntegerSet
+/// or if it contains additional tokens that are not part of the IntegerSet.
+/// Parse errors are emitted as diagnostics (see mlir::parseIntegerSet).
+FailureOr<FlatAffineConstraints>
+parseFlatAffineConstraints(llvm::StringRef str, MLIRContext *context);
+
 } // end namespace mlir.
 
 #endif // MLIR_ANALYSIS_AFFINESTRUCTURES_H
diff --git a/mlir/include/mlir/Parser.h b/mlir/include/mlir/Parser.h
--- a/mlir/include/mlir/Parser.h
+++ b/mlir/include/mlir/Parser.h
@@ -256,6 +256,14 @@
 /// `typeStr`. The number of characters of `typeStr` parsed in the process is
 /// returned in `numRead`.
 Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t &numRead);
+
+/// Parses `str` as a single IntegerSet in `context`. Returns a null IntegerSet
+/// if `str` is not a valid IntegerSet or if it contains additional tokens that
+/// are not part of the IntegerSet. Errors are emitted through a new
+/// SourceMgrDiagnosticHandler constructed from a new SourceMgr with a single
+/// MemoryBuffer wrapping `str`.
+IntegerSet parseIntegerSet(llvm::StringRef str, MLIRContext *context);
+
 } // end namespace mlir
 
 #endif // MLIR_PARSER_H
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/Parser.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/STLExtras.h"
@@ -3827,3 +3828,14 @@
 
   return success();
 }
+
+FailureOr<FlatAffineConstraints>
+mlir::parseFlatAffineConstraints(llvm::StringRef str, MLIRContext *context) {
+  IntegerSet set = parseIntegerSet(str, context);
+
+  // parseIntegerSet returns a null set when parsing fails.
+  if (!set)
+    return failure();
+
+  return FlatAffineConstraints(set);
+}
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -66,6 +66,7 @@
   MLIRCallInterfaces
   MLIRControlFlowInterfaces
   MLIRInferTypeOpInterface
+  MLIRParser
   MLIRPresburger
   MLIRSCF
   )
diff --git a/mlir/lib/Parser/AffineParser.cpp b/mlir/lib/Parser/AffineParser.cpp
--- a/mlir/lib/Parser/AffineParser.cpp
+++ b/mlir/lib/Parser/AffineParser.cpp
@@ -13,10 +13,13 @@
 #include "Parser.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/IntegerSet.h"
+#include "llvm/Support/SourceMgr.h"
 
 using namespace mlir;
 using namespace mlir::detail;
+using llvm::MemoryBuffer;
 using llvm::SMLoc;
+using llvm::SourceMgr;
 
 namespace {
 
@@ -717,3 +720,27 @@
   return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement)
       .parseAffineExprOfSSAIds(expr);
 }
+
+IntegerSet mlir::parseIntegerSet(StringRef inputStr, MLIRContext *context) {
+  SourceMgr sourceMgr;
+  auto memBuffer = MemoryBuffer::getMemBuffer(
+      inputStr, /*BufferName=*/"<mlir_parser_buffer>",
+      /*RequiresNullTerminator=*/false);
+  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
+  SymbolState symbolState;
+  ParserState state(sourceMgr, context, symbolState, /*asmState=*/nullptr);
+  Parser parser(state);
+  SourceMgrDiagnosticHandler handler(sourceMgr, context);
+  IntegerSet set;
+  if (parser.parseIntegerSetReference(set))
+    return IntegerSet();
+
+  // Reject trailing input: `inputStr` must contain exactly one IntegerSet.
+  Token endTok = parser.getToken();
+  if (endTok.isNot(Token::eof)) {
+    parser.emitError(endTok.getLoc(), "encountered unexpected token");
+    return IntegerSet();
+  }
+
+  return set;
+}
diff --git a/mlir/unittests/Analysis/AffineStructuresParserTest.cpp b/mlir/unittests/Analysis/AffineStructuresParserTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Analysis/AffineStructuresParserTest.cpp
@@ -0,0 +1,121 @@
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/PresburgerSet.h"
+
+#include <gtest/gtest.h>
+
+namespace mlir {
+
+/// Construct a FlatAffineConstraints from a set of inequality, equality, and
+/// division constraints.
+static FlatAffineConstraints makeFACFromConstraints(
+    unsigned dims, unsigned syms, ArrayRef<SmallVector<int64_t, 4>> ineqs,
+    ArrayRef<SmallVector<int64_t, 4>> eqs = {},
+    ArrayRef<std::pair<SmallVector<int64_t, 4>, int64_t>> divs = {}) {
+  FlatAffineConstraints fac(ineqs.size(), eqs.size(), dims + syms + 1, dims,
+                            syms, 0);
+  for (const auto &div : divs) {
+    fac.addLocalFloorDiv(div.first, div.second);
+  }
+  for (const auto &eq : eqs)
+    fac.addEquality(eq);
+  for (const auto &ineq : ineqs)
+    fac.addInequality(ineq);
+  return fac;
+}
+
+TEST(ParseFACTest, InvalidInputTest) {
+  MLIRContext context;
+  FailureOr<FlatAffineConstraints> fac;
+
+  fac = parseFlatAffineConstraints("(x)", &context);
+  EXPECT_TRUE(failed(fac))
+      << "should not accept strings with no constraint list";
+
+  fac = parseFlatAffineConstraints("(x)[] : ())", &context);
+  EXPECT_TRUE(failed(fac))
+      << "should not accept strings that contain remaining characters";
+
+  fac = parseFlatAffineConstraints("(x)[] : (x - >= 0)", &context);
+  EXPECT_TRUE(failed(fac))
+      << "should not accept strings that contain incomplete constraints";
+
+  fac = parseFlatAffineConstraints("(x)[] : (y == 0)", &context);
+  EXPECT_TRUE(failed(fac))
+      << "should not accept strings that contain unknown identifiers";
+
+  fac = parseFlatAffineConstraints("(x, x) : (2 * x >= 0)", &context);
+  EXPECT_TRUE(failed(fac))
+      << "should not accept strings that contain repeated identifier names";
+
+  fac = parseFlatAffineConstraints("(x)[x] : (2 * x >= 0)", &context);
+  EXPECT_TRUE(failed(fac))
+      << "should not accept strings that contain repeated identifier names";
+
+  fac = parseFlatAffineConstraints("(x) : (2 * x + 9223372036854775808 >= 0)",
+                                   &context);
+  EXPECT_TRUE(failed(fac)) << "should not accept strings with integer literals "
+                              "that do not fit into int64_t";
+}
+
+/// Parses `str` and compares the result to `ex`. The equality check is
+/// performed by using PresburgerSet::isEqual. Returns false if `str` cannot
+/// be parsed.
+static bool parseAndCompare(StringRef str, FlatAffineConstraints ex) {
+  MLIRContext context;
+  FailureOr<FlatAffineConstraints> fac =
+      parseFlatAffineConstraints(str, &context);
+
+  // EXPECT_TRUE is non-fatal, so bail out explicitly instead of dereferencing
+  // a failed result.
+  EXPECT_TRUE(succeeded(fac));
+  if (failed(fac))
+    return false;
+
+  return PresburgerSet(*fac).isEqual(PresburgerSet(ex));
+}
+
+TEST(ParseFACTest, ParseAndCompareTest) {
+  // simple ineq
+  EXPECT_TRUE(parseAndCompare("(x)[] : (x >= 0)",
+                              makeFACFromConstraints(1, 0, {{1, 0}})));
+
+  // simple eq
+  EXPECT_TRUE(parseAndCompare("(x)[] : (x == 0)",
+                              makeFACFromConstraints(1, 0, {}, {{1, 0}})));
+
+  // multiple constraints
+  EXPECT_TRUE(parseAndCompare("(x)[] : (7 * x >= 0, -7 * x + 5 >= 0)",
+                              makeFACFromConstraints(1, 0, {{7, 0}, {-7, 5}})));
+
+  // multiple dimensions
+  EXPECT_TRUE(parseAndCompare("(x,y,z)[] : (x + y - z >= 0)",
+                              makeFACFromConstraints(3, 0, {{1, 1, -1, 0}})));
+
+  // dimensions and symbols
+  EXPECT_TRUE(
+      parseAndCompare("(x,y,z)[a,b] : (x + y - z + 2 * a - 15 * b >= 0)",
+                      makeFACFromConstraints(3, 2, {{1, 1, -1, 2, -15, 0}})));
+
+  // only symbols
+  EXPECT_TRUE(parseAndCompare("()[a] : (2 * a - 4 == 0)",
+                              makeFACFromConstraints(0, 1, {}, {{2, -4}})));
+
+  // simple floordiv
+  EXPECT_TRUE(parseAndCompare(
+      "(x, y) : (y - 3 * ((x + y - 13) floordiv 3) - 42 == 0)",
+      makeFACFromConstraints(2, 0, {}, {{0, 1, -3, -42}}, {{{1, 1, -13}, 3}})));
+
+  // multiple floordiv
+  EXPECT_TRUE(parseAndCompare(
+      "(x, y) : (y - x floordiv 3 - y floordiv 2 == 0)",
+      makeFACFromConstraints(2, 0, {}, {{0, 1, -1, -1, 0}},
+                             {{{1, 0, 0}, 3}, {{0, 1, 0, 0}, 2}})));
+
+  // nested floordiv
+  EXPECT_TRUE(parseAndCompare(
+      "(x, y) : (y - (x + y floordiv 2) floordiv 3 == 0)",
+      makeFACFromConstraints(2, 0, {}, {{0, 1, 0, -1, 0}},
+                             {{{0, 1, 0}, 2}, {{1, 0, 1, 0}, 3}})));
+}
+
+} // namespace mlir
diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -98,11 +98,27 @@
   } while (std::next_permutation(perm.begin(), perm.end()));
 }
 
+/// Parses a FlatAffineConstraints from a StringRef. The string is expected to
+/// represent a valid IntegerSet; otherwise a non-fatal gtest failure is
+/// recorded and a default-constructed FlatAffineConstraints is returned.
+static FlatAffineConstraints parseFAC(StringRef str) {
+  MLIRContext context;
+  FailureOr<FlatAffineConstraints> fac =
+      parseFlatAffineConstraints(str, &context);
+
+  // EXPECT_TRUE is non-fatal, so do not dereference a failed result.
+  EXPECT_TRUE(succeeded(fac));
+  if (failed(fac))
+    return FlatAffineConstraints();
+
+  return *fac;
+}
+
 TEST(FlatAffineConstraintsTest, FindSampleTest) {
   // Bounded sets with only inequalities.
 
   // 0 <= 7x <= 5
-  checkSample(true, makeFACFromConstraints(1, {{7, 0}, {-7, 5}}, {}));
+  checkSample(true, parseFAC("(x) : (7 * x >= 0, -7 * x + 5 >= 0)"));
 
   // 1 <= 5x and 5x <= 4 (no solution).
   checkSample(false, makeFACFromConstraints(1, {{5, -1}, {-5, 4}}, {}));
diff --git a/mlir/unittests/Analysis/CMakeLists.txt b/mlir/unittests/Analysis/CMakeLists.txt
--- a/mlir/unittests/Analysis/CMakeLists.txt
+++ b/mlir/unittests/Analysis/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRAnalysisTests
+  AffineStructuresParserTest.cpp
   AffineStructuresTest.cpp
   LinearTransformTest.cpp
   PresburgerSetTest.cpp
 )