diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -10,9 +10,18 @@ // IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // // IMPL: ArrayAttr Test1Op::indexing_maps() { -// IMPL: AffineMap::get(2, 0, {d0, d1}, context), -// IMPL-NEXT: AffineMap::get(2, 0, {d1}, context), -// IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) }); +// IMPL: auto s0 = getAffineSymbolExpr(0, context); (void)s0; +// IMPL-NEXT: auto s1 = getAffineSymbolExpr(1, context); (void)s1; +// IMPL-NEXT: auto map0 = AffineMap::get(2, 2, {d0, d1}, context); +// IMPL-NEXT: map0 = map0.replaceDimsAndSymbols({}, { s0, s1 }, 2, 0); +// IMPL-NEXT: map0 = simplifyAffineMap(map0); +// IMPL-NEXT: auto map1 = AffineMap::get(2, 2, {d1}, context); +// IMPL-NEXT: map1 = map1.replaceDimsAndSymbols({}, { s0, s1 }, 2, 0); +// IMPL-NEXT: map1 = simplifyAffineMap(map1); +// IMPL-NEXT: auto map2 = AffineMap::get(2, 2, {d0}, context); +// IMPL-NEXT: map2 = map2.replaceDimsAndSymbols({}, { s0, s1 }, 2, 0); +// IMPL-NEXT: map2 = simplifyAffineMap(map2); +// IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 }); // // IMPL: void Test1Op::regionBuilder(Block &block) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); @@ -34,9 +43,9 @@ // IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // // IMPL: ArrayAttr Test2Op::indexing_maps() { -// IMPL: AffineMap::get(3, 0, {d0, d2}, context), -// IMPL-NEXT: AffineMap::get(3, 0, {d2, d1}, context), -// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) }); +// IMPL: AffineMap::get(3, 3, {d0, d2}, context) +// IMPL: AffineMap::get(3, 3, {d2, d1}, context) +// IMPL: AffineMap::get(3, 3, {d0, d1}, context) // // IMPL: Test2Op::regionBuilder(Block &block) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); @@ 
-58,9 +67,9 @@ // IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // // IMPL: ArrayAttr Test3Op::indexing_maps() { -// IMPL: AffineMap::get(4, 0, {d0, d1, d3}, context), -// IMPL-NEXT: AffineMap::get(4, 0, {d3, d2}, context), -// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) }); +// IMPL: AffineMap::get(4, 4, {d0, d1, d3}, context) +// IMPL: AffineMap::get(4, 4, {d3, d2}, context) +// IMPL: AffineMap::get(4, 4, {d0, d1, d2}, context) // // IMPL: Test3Op::regionBuilder(Block &block) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); @@ -94,3 +103,25 @@ ) { C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(k, n))); } + +// Test attribute usage in affine expressions: each `strides[i]` becomes a constant. +// IMPL-LABEL: ArrayAttr Test5Op::indexing_maps() { +// IMPL: auto cst0 = getAffineConstantExpr(strides().getValue({ 0 }), context); +// IMPL: auto cst1 = getAffineConstantExpr(strides().getValue({ 1 }), context); +// IMPL: auto map0 = AffineMap::get(7, 9, {d0, d1 * s7 + d4, d2 * s8 + d5, d6}, context); +// IMPL: map0 = map0.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst0, cst1 }, 7, 0); +// IMPL: map0 = simplifyAffineMap(map0); +// IMPL: auto map1 = AffineMap::get(7, 9, {d3, d4, d5, d6}, context); +// IMPL: map1 = map1.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst0, cst1 }, 7, 0); +// IMPL: map1 = simplifyAffineMap(map1); +// IMPL: auto map2 = AffineMap::get(7, 7, {d0, d1, d2, d3}, context); +// IMPL: map2 = map2.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst0, cst1 }, 7, 0); +// IMPL: map2 = simplifyAffineMap(map2); +// IMPL: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 }); +// +ods_def: +def test5(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) + attr(strides: 2xi32) { + O(n, h, w, f) = std_addf(std_mulf( + I(n, h * strides[0] + kh, w * strides[1] + kw, c), K(f, kh, kw, c))); +} diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp 
b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -19,8 +19,12 @@ #include "mlir/Support/FileUtilities.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/Twine.h" @@ -366,6 +370,14 @@ // Lexer Utilities //===--------------------------------------------------------------------===// + LogicalResult parseInteger(uint64_t &value) { + if (!curToken.is(Token::Kind::integer)) + return emitError(curToken.getLoc(), "expected integer"); + value = curToken.getUInt64IntegerValue().getValue(); + consumeToken(); + return success(); + } + /// Advance the current lexer onto the next token. void consumeToken() { assert(curToken.getKind() != Token::Kind::eof && @@ -447,6 +459,30 @@ }; } // namespace +/// Encodes an attribute use of the form: +/// +/// index-list ::= integer-literal (`,` integer-literal)* +/// attr-use ::= bare-id `[` index-list `]` +struct AttrUse { + // Referenced attribute + StringRef attrName; + // Indices into the attribute + SmallVector indices; + /// Affine symbol for this usage. + /// This is represented as an affine symbol because at the time of parsing the + /// spec and generating the op's ODS/C++, we don't know the concrete constant + /// value. But they should be replaced with constants read from the attribute + /// and thus folded away for concrete op instances. 
+ AffineExpr symbol; + + std::string getKey() const { + SmallVector indexStrs; + for (uint64_t index : indices) + indexStrs.push_back(std::to_string(index)); + return llvm::formatv("{0}[{1}]", attrName, llvm::join(indexStrs, ",")); + } +}; + //===----------------------------------------------------------------------===// // Affine parsing. //===----------------------------------------------------------------------===// @@ -479,10 +515,20 @@ /// This is a specialized parser for affine expressions. class AffineParser { public: - explicit AffineParser(Parser &p, - std::function bareIdParsingHook, - AffineDimList &dimList, AffineSymbolList &symbolList) - : parser(p), bareIdFallback(bareIdParsingHook), dims(dimList), + /// Creates an affine parser that parses tokens from `p`. + /// + /// Upon encountering a new identifier in the token stream, this parser will + /// first invoke `attrUseParsingHook` to try to parse it as an attribute use. + /// If unsuccessful (seeing llvm::None), then query `dimList` and `symbolList` + /// to get affine expressions for known identifiers. If that's still + /// unsuccessful, then invoke `bareIdParsingHook` to create a new affine + /// symbol/dimension expression. + explicit AffineParser( + Parser &p, std::function bareIdParsingHook, + std::function()> attrUseParsingHook, + AffineDimList &dimList, AffineSymbolList &symbolList) + : parser(p), bareIdFallback(bareIdParsingHook), + attrUseCallback(attrUseParsingHook), dims(dimList), + symbols(symbolList) {} /// Parse a comma-separated list of affine exprs. 
@@ -502,6 +548,7 @@ AffineExpr parseParentheticalExpr(); AffineExpr parseNegateExpression(AffineExpr lhs); AffineExpr parseIntegerExpr(); + AffineExpr parseAttrUseOrBareIdExpr(); AffineExpr parseBareIdExpr(); AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, @@ -515,6 +562,7 @@ Parser &parser; std::function bareIdFallback; + std::function()> attrUseCallback; AffineDimList &dims; AffineSymbolList &symbols; }; @@ -688,6 +736,12 @@ return (-1) * operand; } +AffineExpr AffineParser::parseAttrUseOrBareIdExpr() { + if (llvm::Optional attrUse = attrUseCallback()) + return attrUse.getValue(); + return parseBareIdExpr(); +} + /// Parse a bare id that may appear in an affine expression. /// /// affine-expr ::= bare-id @@ -739,7 +793,7 @@ AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { switch (parser.curToken.getKind()) { case Token::Kind::id: - return parseBareIdExpr(); + return parseAttrUseOrBareIdExpr(); case Token::Kind::integer: return parseIntegerExpr(); case Token::Kind::l_paren: @@ -994,8 +1048,12 @@ LogicalResult parseTensorUse(TensorUse &result, ComprehensionParsingState &state); + /// Parses an attribute definition. LogicalResult parseAttrDef(); + /// Parses a use of a registered attribute, e.g. `strides[0]`. + LogicalResult parseAttrUse(AttrUse &result); + /// Parses a tensor expression. LogicalResult parseExpression(TensorUse currentDefinition, std::unique_ptr &result, @@ -1053,6 +1111,10 @@ SmallVector vectorDims; bool isArray; bool isOptional; + + /// Returns the C++ accessor call reading the value at `indices` from this + /// attribute, or an empty string if the attribute kind is unsupported. + std::string getValueFn(ArrayRef indices) const; }; //===--------------------------------------------------------------------===// @@ -1061,6 +1123,9 @@ /// Symbols are per TC def. AffineSymbolList symbols; + /// Attribute usages in all affine expressions. + SmallVector attrUses; + /// Tensors are per TC def. 
llvm::StringMap registeredTensors; unsigned nextRegisteredTensorIndex; @@ -1147,20 +1212,45 @@ TCParser::parseAffineExprs(EagerDiscoveryMode discoveryMode, AffineDimList &dims, Token::Kind lDelim, Token::Kind rDelim) { - AffineParser affineParser( - parser, - [&](StringRef sRef) { - AffineExpr expr; - if (discoveryMode == EagerDiscoveryMode::Symbols) { - expr = getAffineSymbolExpr(symbols.size(), parser.context); - symbols.emplace_back(sRef, expr); - } else if (discoveryMode == EagerDiscoveryMode::Dimensions) { - expr = getAffineDimExpr(dims.size(), parser.context); - dims.emplace_back(sRef, expr); - } - return expr; - }, - dims, symbols); + auto createAffineBareId = [&](StringRef sRef) { + AffineExpr expr; + if (discoveryMode == EagerDiscoveryMode::Symbols) { + expr = getAffineSymbolExpr(symbols.size(), parser.context); + symbols.emplace_back(sRef, expr); + } else if (discoveryMode == EagerDiscoveryMode::Dimensions) { + expr = getAffineDimExpr(dims.size(), parser.context); + dims.emplace_back(sRef, expr); + } + return expr; + }; + + auto tryToParseAttrUse = [&]() -> llvm::Optional { + if (!parser.curToken.is(Token::Kind::id)) + return llvm::None; + + StringRef attrName = parser.curToken.getSpelling(); + auto it = registeredAttrs.find(attrName.str()); + if (it == registeredAttrs.end()) + return llvm::None; + + AttrUse result; + if (failed(parseAttrUse(result))) + return AffineExpr(); // Error emitted; null expr fails the parse. + + // We create a new symbol for each attribute usage without reuse. This is + // fine given these symbols will be replaced with constants and folded away + // for concrete op instances. + result.symbol = getAffineSymbolExpr(symbols.size(), parser.context); + // Merely for taking the index. We don't reuse anyway. 
+ symbols.emplace_back("", result.symbol); + + attrUses.push_back(result); + + return result.symbol; + }; + + AffineParser affineParser(parser, createAffineBareId, tryToParseAttrUse, dims, + symbols); return affineParser.parseAffineExprs(lDelim, rDelim); } @@ -1241,8 +1331,9 @@ // Parse potential dimension list SmallVector vectorDims; while (parser.curToken.is(Token::Kind::integer)) { - vectorDims.push_back(parser.curToken.getUInt64IntegerValue().getValue()); - parser.consumeToken(); + uint64_t value; + parser.parseInteger(value); + vectorDims.push_back(value); StringRef spelling = parser.curToken.getSpelling(); if (spelling[0] != 'x') @@ -1286,6 +1377,44 @@ return success(); } +LogicalResult TCParser::parseAttrUse(AttrUse &result) { + result.attrName = parser.curToken.getSpelling(); + if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) + return failure(); + + auto it = registeredAttrs.find(result.attrName.str()); + assert(it != registeredAttrs.end()); + const RegisteredAttr &attr = it->second; + + if (!attr.vectorDims.empty() || attr.isArray) { + // This is a vector/array attribute. Parse indices for it. + auto indexLoc = parser.curToken.getLoc(); + + if (failed(parser.parseToken(Token::Kind::l_square, "expected '['"))) + return failure(); + + auto parseIndex = [&]() { + uint64_t value; + if (failed(parser.parseInteger(value))) + return failure(); + result.indices.push_back(value); + return success(); + }; + if (failed(parser.parseCommaSeparatedListUntil( + Token::Kind::r_square, parseIndex, /*allowEmptyList=*/false))) + return failure(); + + size_t rank = attr.isArray ? 
1 : attr.vectorDims.size(); + if (result.indices.size() != rank) + return parser.emitError(indexLoc, + "number of indices mismatch: expected " + + std::to_string(rank) + ", but found " + + std::to_string(result.indices.size())); + } + + return success(); +} + /// Parses a tensor expression of the form: /// /// op-spec ::= bare-id `<` reduction-dims-list `>` @@ -1776,7 +1905,8 @@ MLIRContext *context = getContext(); AffineExpr {1}; bindDims(context, {1}); - return Builder(context).getAffineMapArrayAttr({ {2} }); + {2} + return Builder(context).getAffineMapArrayAttr({ {3} }); })FMT"; // 2. Print a comma-separated list of identifiers for the AffineExpr in @@ -1790,36 +1920,89 @@ [&](std::pair p) { ss << p.second; }); ss.flush(); - // 3. Print a comma-separated list of AffineMap constructors that use the - // identifiers from 1. The AffineExpr use the common arithmetic operators on - // AffineExpr. These AffineMap constructors will replace the `{2}` placeholder - // in return `SmallVector{{ {2} };`. + // 3. Emit code building one affine map per input/output. The AffineExprs + // use the common arithmetic operators on AffineExpr. This code replaces the + // `{2}` placeholder; the resulting map names replace `{3}`. std::string mapsStr; llvm::raw_string_ostream mapsStringStream(mapsStr); + SmallVector orderedUses(state.orderedTensorArgs.size()); for (const auto &it : state.orderedTensorArgs) orderedUses[it.second] = it.first; - llvm::interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) { - assert(u.indexingMap); - const char *mapFmt = "\n\tAffineMap::get({0}, 0, {1}, context)"; - if (u.indexingMap.isEmpty()) { - mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), "context"); + + // Create a list of all symbols. 
+ SmallVector symbolReplacements; + symbolReplacements.reserve(symbols.size()); + for (unsigned i = 0; i < symbols.size(); ++i) { + const char *symFmt = + "\n\tauto s{0} = getAffineSymbolExpr({0}, context); (void)s{0};"; + mapsStringStream << llvm::formatv(symFmt, i); + symbolReplacements.push_back(llvm::formatv("s{0}", i)); + } + + // Create the affine constant expressions to replace symbols for attributes. + for (auto attrUse : llvm::enumerate(attrUses)) { + StringRef attrName = attrUse.value().attrName; + auto it = registeredAttrs.find(attrName.str()); + assert(it != registeredAttrs.end()); + std::string getValueFn = it->second.getValueFn(attrUse.value().indices); + if (getValueFn.empty()) { + parser.emitError("unimplemented getValueFn for attribute: " + attrName); return; } + std::string cstVal = llvm::formatv("{0}().{1}", attrName, getValueFn); + const char *cstFmt = + "\n\tauto cst{0} = getAffineConstantExpr({1}, context);"; + mapsStringStream << llvm::formatv(cstFmt, attrUse.index(), cstVal); + + unsigned position = + attrUse.value().symbol.cast().getPosition(); + symbolReplacements[position] = llvm::formatv("cst{0}", attrUse.index()); + } + + std::string replaceSymbolList = + llvm::formatv("{ {0} }", llvm::join(symbolReplacements, ", ")); + + // For each tensor use, construct the affine map, replace symbols for + // attributes, and simplify the affine map. 
+ for (auto tensorUse : llvm::enumerate(orderedUses)) { + auto indexingMap = tensorUse.value().indexingMap; + const char *mapFmt = + "\n\tauto map{0} = AffineMap::get({1}, {2}, {3}, context);"; std::string exprsStr; llvm::raw_string_ostream exprsStringStream(exprsStr); exprsStringStream << "{"; - llvm::interleaveComma(u.indexingMap.getResults(), exprsStringStream); + llvm::interleaveComma(indexingMap.getResults(), exprsStringStream); exprsStringStream << "}"; exprsStringStream.flush(); + mapsStringStream << llvm::formatv(mapFmt, tensorUse.index(), + state.dims.size(), + indexingMap.getNumSymbols(), exprsStr); + + // Note that we use `0` as the result affine map's number of symbols. All + // symbols representing attribute usages should be folded away. But there + // may exist additional symbols for tensor dimension upper bounds. Linalg + // does not handle such cases right now. This needs to be fixed once we need + // that. + const char *replaceFmt = + "\n\tmap{0} = map{0}.replaceDimsAndSymbols({{}, {1}, {2}, 0);"; + mapsStringStream << llvm::formatv(replaceFmt, tensorUse.index(), + replaceSymbolList, state.dims.size()); + const char *simplifyFmt = "\n\tmap{0} = simplifyAffineMap(map{0});"; + mapsStringStream << llvm::formatv(simplifyFmt, tensorUse.index()); + } - mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), exprsStr); - }); mapsStringStream.flush(); + SmallVector mapList; + mapList.reserve(orderedUses.size()); + for (unsigned i = 0; i < orderedUses.size(); ++i) + mapList.push_back(llvm::formatv("map{0}", i)); + // 4. Apply format to 1. using 2. and 3. - os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr); + os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr, + llvm::join(mapList, ", ")); } /// Print the C++ StructuredOpsInterface impl of `regionBuilder`. 
@@ -1893,6 +2076,31 @@ expressionsStr, yieldStr); } +/// Returns the C++ accessor call that reads this attribute's value at +/// `indices`, or an empty string when the attribute kind is unsupported. +std::string +TCParser::RegisteredAttr::getValueFn(ArrayRef indices) const { + if (isArray) + return ""; + + if (vectorDims.empty()) { + if (elementType == "f32") + return "getValue().convertToFloat()"; + if (elementType == "i32") + return "getInt()"; + return ""; + } + + std::string indexList; + llvm::raw_string_ostream indexStream(indexList); + llvm::interleaveComma(indices, indexStream); + if (elementType == "f32") + return llvm::formatv("getValue({ {0} })", indexStream.str()); + if (elementType == "i32") + return llvm::formatv("getValue({ {0} })", indexStream.str()); + return ""; +} + /// Iterate over each Tensor Comprehension def. LogicalResult parseAndEmitAllTensorComprehensions(llvm::raw_ostream &os, Parser &parser) {