diff --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h --- a/mlir/include/mlir/Parser/Parser.h +++ b/mlir/include/mlir/Parser/Parser.h @@ -246,6 +246,16 @@ IntegerSet parseIntegerSet(llvm::StringRef str, MLIRContext *context, bool printDiagnosticInfo = true); +/// Parses comma-separated IntegerSets that share one dimension/symbol list, +/// e.g. "(x)[] : (x >= 0), (x == 0)", in the given MLIR context. On a parse +/// error, a diagnostic is emitted through a new SourceMgrDiagnosticHandler +/// built from a SourceMgr wrapping `str`, and an empty vector is returned. The +/// same happens if `str` has trailing tokens after the last IntegerSet. +/// Diagnostics are printed on failure if `printDiagnosticInfo` is true. +SmallVector +parseMultipleIntegerSets(llvm::StringRef str, MLIRContext *context, + bool printDiagnosticInfo = true); + } // namespace mlir #endif // MLIR_PARSER_PARSER_H diff --git a/mlir/lib/Parser/AffineParser.cpp b/mlir/lib/Parser/AffineParser.cpp --- a/mlir/lib/Parser/AffineParser.cpp +++ b/mlir/lib/Parser/AffineParser.cpp @@ -53,6 +53,7 @@ AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); + ParseResult parseMultipleIntegerSets(SmallVectorImpl &unionSet); ParseResult parseAffineMapOfSSAIds(AffineMap &map, OpAsmParser::Delimiter delimiter); ParseResult parseAffineExprOfSSAIds(AffineExpr &expr); @@ -537,6 +538,30 @@ return failure(); } +/// Parses comma-separated integer sets that share a single dimension and +/// symbol identifier list, e.g. `(x, y)[n] : (x >= 0, y - n >= 0), (x == y)`. +/// Each parsed set is appended to `unionSet`; on failure, `unionSet` may hold +/// the sets parsed before the error. +ParseResult +AffineParser::parseMultipleIntegerSets(SmallVectorImpl &unionSet) { + unsigned numDims = 0, numSymbols = 0; + + // List of dimensional and optional symbol identifiers. 
+ if (parseDimAndOptionalSymbolIdList(numDims, numSymbols) || + parseToken(Token::colon, "expected ':'")) + return failure(); + + // One parenthesized constraint list per set, separated by commas. + do { + IntegerSet set = parseIntegerSetConstraints(numDims, numSymbols); + if (!set) + return failure(); + unionSet.push_back(set); + } while (consumeIf(Token::comma)); + + return success(); +} + /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map, @@ -700,6 +725,13 @@ return success(); } +/// Parses comma-separated integer sets into `set`; see +/// AffineParser::parseMultipleIntegerSets for the accepted syntax. +ParseResult +Parser::parseMultipleIntegerSetsReference(SmallVectorImpl &set) { + return AffineParser(state).parseMultipleIntegerSets(set); +} + /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to /// parse SSA value uses encountered while parsing affine expressions. ParseResult @@ -744,3 +776,29 @@ return set; } + +SmallVector +mlir::parseMultipleIntegerSets(StringRef inputStr, MLIRContext *context, + bool printDiagnosticInfo) { + llvm::SourceMgr sourceMgr; + auto memBuffer = llvm::MemoryBuffer::getMemBuffer( + inputStr, /*BufferName=*/"", + /*RequiresNullTerminator=*/false); + sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); + SymbolState symbolState; + ParserState state(sourceMgr, context, symbolState, /*asmState=*/nullptr); + Parser parser(state); + + raw_ostream &os = printDiagnosticInfo ? 
llvm::errs() : llvm::nulls(); + SourceMgrDiagnosticHandler handler(sourceMgr, context, os); + SmallVector set; + if (parser.parseMultipleIntegerSetsReference(set)) + return SmallVector(); + + Token endTok = parser.getToken(); + if (endTok.isNot(Token::eof)) { + parser.emitError(endTok.getLoc(), "encountered unexpected token"); + return SmallVector(); + } + return set; +} diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -293,6 +293,8 @@ IntegerSet &set); ParseResult parseAffineMapReference(AffineMap &map); ParseResult parseIntegerSetReference(IntegerSet &set); + ParseResult + parseMultipleIntegerSetsReference(SmallVectorImpl &set); /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. ParseResult diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp --- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp @@ -97,9 +97,8 @@ } TEST(SetTest, containsPoint) { - PresburgerSet setA = parsePresburgerSetFromPolyStrings( - 1, - {"(x) : (x - 2 >= 0, -x + 8 >= 0)", "(x) : (x - 10 >= 0, -x + 20 >= 0)"}); + PresburgerSet setA = parsePresburgerSet( + "(x) : (x - 2 >= 0, -x + 8 >= 0), (x - 10 >= 0, -x + 20 >= 0)"); for (unsigned x = 0; x <= 21; ++x) { if ((2 <= x && x <= 8) || (10 <= x && x <= 20)) EXPECT_TRUE(setA.containsPoint({x})); @@ -109,10 +108,10 @@ // A parallelogram with vertices {(3, 1), (10, -6), (24, 8), (17, 15)} union // a square with opposite corners (2, 2) and (10, 10). 
- PresburgerSet setB = parsePresburgerSetFromPolyStrings( - 2, {"(x,y) : (x + y - 4 >= 0, -x - y + 32 >= 0, " - "x - y - 2 >= 0, -x + y + 16 >= 0)", - "(x,y) : (x - 2 >= 0, y - 2 >= 0, -x + 10 >= 0, -y + 10 >= 0)"}); + PresburgerSet setB = parsePresburgerSet( + "(x,y) : (x + y - 4 >= 0, -x - y + 32 >= 0, " + "x - y - 2 >= 0, -x + y + 16 >= 0)," + "(x - 2 >= 0, y - 2 >= 0, -x + 10 >= 0, -y + 10 >= 0)"); for (unsigned x = 1; x <= 25; ++x) { for (unsigned y = -6; y <= 16; ++y) { @@ -127,9 +126,8 @@ } TEST(SetTest, Union) { - PresburgerSet set = parsePresburgerSetFromPolyStrings( - 1, - {"(x) : (x - 2 >= 0, -x + 8 >= 0)", "(x) : (x - 10 >= 0, -x + 20 >= 0)"}); + PresburgerSet set = parsePresburgerSet( + "(x) : (x - 2 >= 0, -x + 8 >= 0), (x - 10 >= 0, -x + 20 >= 0)"); // Universe union set. testUnionAtPoints(PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1)), diff --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h --- a/mlir/unittests/Analysis/Presburger/Utils.h +++ b/mlir/unittests/Analysis/Presburger/Utils.h @@ -35,6 +35,26 @@ return *poly; } +/// Parses a list of comma-separated IntegerSets that share a single +/// dimension/symbol list, e.g. "(x) : (x - 2 >= 0), (x - 10 >= 0)", converts +/// each of them to an IntegerPolyhedron and combines them into a +/// PresburgerSet using the union operation. It is expected that the string +/// is valid input for parseMultipleIntegerSetsToFAC. A parse failure is +/// flagged with EXPECT_TRUE, but the result is still dereferenced, so callers +/// must only pass valid strings. 
+inline PresburgerSet parsePresburgerSet(StringRef str) { + MLIRContext context(MLIRContext::Threading::DISABLED); + FailureOr> facs = + parseMultipleIntegerSetsToFAC(str, &context); + EXPECT_TRUE(succeeded(facs)); + + // The first set seeds the result; the remaining sets are unioned into it. + PresburgerSet set = PresburgerSet(IntegerPolyhedron(facs->front())); + for (const auto &fac : llvm::drop_begin(*facs)) + set.unionInPlace(IntegerPolyhedron(fac)); + return set; +} + /// Parse a list of StringRefs to IntegerRelation and combine them into a /// PresburgerSet be using the union operation. It is expected that the strings /// are all valid IntegerSet representation and that all of them have the same diff --git a/mlir/unittests/Dialect/Affine/Analysis/AffineStructuresParser.h b/mlir/unittests/Dialect/Affine/Analysis/AffineStructuresParser.h --- a/mlir/unittests/Dialect/Affine/Analysis/AffineStructuresParser.h +++ b/mlir/unittests/Dialect/Affine/Analysis/AffineStructuresParser.h @@ -29,6 +29,10 @@ parseIntegerSetToFAC(llvm::StringRef, MLIRContext *context, bool printDiagnosticInfo = true); +FailureOr> +parseMultipleIntegerSetsToFAC(llvm::StringRef str, MLIRContext *context, + bool printDiagnosticInfo = true); + } // namespace mlir #endif // MLIR_UNITTEST_ANALYSIS_AFFINESTRUCTURESPARSER_H diff --git a/mlir/unittests/Dialect/Affine/Analysis/AffineStructuresParser.cpp b/mlir/unittests/Dialect/Affine/Analysis/AffineStructuresParser.cpp --- a/mlir/unittests/Dialect/Affine/Analysis/AffineStructuresParser.cpp +++ b/mlir/unittests/Dialect/Affine/Analysis/AffineStructuresParser.cpp @@ -23,3 +23,19 @@ return FlatAffineValueConstraints(set); } + +FailureOr> +mlir::parseMultipleIntegerSetsToFAC(llvm::StringRef str, MLIRContext *context, + bool printDiagnosticInfo) { + SmallVector set = + parseMultipleIntegerSets(str, context, printDiagnosticInfo); + + if (set.empty()) + return failure(); + + SmallVector ret; + for (auto iSet : set) + ret.push_back(FlatAffineConstraints(iSet)); + + return 
ret; +} diff --git a/mlir/unittests/Dialect/Affine/Analysis/AffineStructuresParserTest.cpp b/mlir/unittests/Dialect/Affine/Analysis/AffineStructuresParserTest.cpp --- a/mlir/unittests/Dialect/Affine/Analysis/AffineStructuresParserTest.cpp +++ b/mlir/unittests/Dialect/Affine/Analysis/AffineStructuresParserTest.cpp @@ -83,6 +83,27 @@ return PresburgerSet(*fac).isEqual(PresburgerSet(ex)); } +/// Parses `str` as comma-separated IntegerSets and checks that it yields the +/// same number of sets as `ex`, each equal to the corresponding entry of `ex`. +static bool parseAndCompare(StringRef str, + SmallVectorImpl &ex, + MLIRContext *context) { + FailureOr> fac = + parseMultipleIntegerSetsToFAC(str, context); + if (failed(fac)) + return false; + + // Check the size first so that indexing `ex` below stays in bounds. + const auto &facVec = *fac; + if (facVec.size() != ex.size()) + return false; + + for (unsigned i = 0, e = facVec.size(); i < e; ++i) + if (!PresburgerSet(facVec[i]).isEqual(PresburgerSet(ex[i]))) + return false; + return true; +} + TEST(ParseFACTest, ParseAndCompareTest) { MLIRContext context; // simple ineq @@ -134,3 +155,28 @@ {{{0, 1, 0}, 2}, {{1, 0, 1, 0}, 3}}), &context)); } + +TEST(ParseMultipleFACTest, ParseAndCompareTest) { + MLIRContext context; + + // Parses a simple ineq and eq respectively. + SmallVector expectedFac{ + makeFACFromConstraints(1, 0, {{1, 0}}), + makeFACFromConstraints(1, 0, {}, {{1, 0}})}; + EXPECT_TRUE( + parseAndCompare("(x)[] : (x >= 0), (x == 0)", expectedFac, &context)); + + // Parses a simple, a multiple and a nested floordiv respectively. 
+ SmallVector divExpectedFac{ + makeFACFromConstraints(2, 0, {}, {{0, 1, -3, -42}}, {{{1, 1, -13}, 3}}), + makeFACFromConstraints(2, 0, {}, {{0, 1, -1, -1, 0}}, + {{{1, 0, 0}, 3}, {{0, 1, 0, 0}, 2}}), + makeFACFromConstraints(2, 0, {}, {{0, 1, 0, -1, 0}}, + {{{0, 1, 0}, 2}, {{1, 0, 1, 0}, 3}})}; + + EXPECT_TRUE( + parseAndCompare("(x, y) : (y - 3 * ((x + y - 13) floordiv 3) - 42 == " + "0), (y - x floordiv 3 - y floordiv 2 == 0), (y - (x + " + "y floordiv 2) floordiv 3 == 0)", + divExpectedFac, &context)); +}