diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md --- a/mlir/docs/Dialects/Linalg.md +++ b/mlir/docs/Dialects/Linalg.md @@ -582,8 +582,9 @@ resorting to more general MLIR parsing. 1. Reduction dimensions are specified with angle bracket notation on the operation they apply to (e.g. `std_add<k>` specifies that `k` is a reduction - dimension). In TC, a reduction is specified with `op=` operator and the - reduction dimensions are inferred. + dimension). In TC, the reduction dimensions are inferred. If one of the + operands is not used in any expression, it will be considered a shape-only + operand, and the results of its indexing_map will be the reduction dimensions. 1. The parallel and reduction dimension are ordered by the textual program order. For instance, in the comprehension `O(i, j) = std_add(...)`, `i` (resp. `j`) is a parallel iterator encoded by affine dimension of diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -190,3 +190,14 @@ { C(m) = std_subf(std_mulf(A(m, k), B(k)), C(m)); } + +// Test shape-only operand.
+// IMPL-LABEL: ArrayAttr Test9Op::indexing_maps() { +// IMPL: auto map0 = AffineMap::get(2, 2, {d0, d1}, context); +// IMPL: auto map1 = AffineMap::get(2, 2, {d1}, context); +// IMPL: auto map2 = AffineMap::get(2, 2, {d0}, context); +ods_def<Test9Op>: +def test9(A: f32(M, K), B: f32(K)) -> (C: f32(M)) +{ + C(m) = std_addf(C(m), A(m, k)); +} diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1634,7 +1634,26 @@ tensor.indexingMap = use.indexingMap; state.orderedTensorArgs[use] = tensor.index; }); - state.numArgs = seenDefs.size(); + // If some tensor definitions are not used in any expression, they are + // shape-only operands, which are used to define reduction loops. For now, + // only accept exactly one shape-only operand. + if (state.numArgs > seenDefs.size() + 1) { + failed = true; + } else if (state.numArgs == seenDefs.size() + 1) { + for (auto &tensorIter : registeredTensors) { + auto &tensor = tensorIter.getValue(); + if (tensor.indexingMap) + continue; + if (auto *pTensorExpr = + dyn_cast<TensorExpr>(state.expressions[0].get())) { + SmallVector<AffineExpr, 4> exprs; + for (auto dim : pTensorExpr->reductionDimensions) + exprs.push_back(getAffineDimExpr(dim, parser.context)); + tensor.indexingMap = AffineMap::get(state.dims.size(), symbols.size(), + exprs, parser.context); + } + } + } if (failed) return failure(); @@ -1762,6 +1781,7 @@ SmallVector<ComprehensionParsingState, 4> perComprehensionStates; while (parser.curToken.isNot(Token::Kind::r_brace)) { perComprehensionStates.push_back(ComprehensionParsingState()); + perComprehensionStates.back().numArgs = registeredTensors.size(); if (failed(parseOneComprehension(cppOpName, tcName, perComprehensionStates.back()))) return failure(); @@ -2207,10 +2227,6 @@ std::string mapsStr; llvm::raw_string_ostream mapsStringStream(mapsStr); - SmallVector<TensorUse, 4> orderedUses(state.numArgs); - for (const auto &it :
state.orderedTensorArgs) - orderedUses[it.second] = it.first; - // Create a list of all symbols. SmallVector<std::string, 4> symbolReplacements; symbolReplacements.reserve(symbols.size()); @@ -2242,10 +2258,11 @@ symbolReplacements[position] = llvm::formatv("cst{0}", attrUse.index()); } - // For each tensor use, construct the affine map, replace symbols by the - // corresponding attribute values, and simplify the affine map. - for (auto tensorUse : llvm::enumerate(orderedUses)) { - auto indexingMap = tensorUse.value().indexingMap; + // For each registered tensor, construct the affine map, replace symbols by + // the corresponding attribute values, and simplify the affine map. + for (auto &tensorIter : registeredTensors) { + auto &tensor = tensorIter.getValue(); + auto indexingMap = tensor.indexingMap; const char *mapFmt = "\n\tauto map{0} = AffineMap::get({1}, {2}, {3}, context);"; @@ -2255,8 +2272,7 @@ llvm::interleaveComma(indexingMap.getResults(), exprsStringStream); exprsStringStream << "}"; exprsStringStream.flush(); - mapsStringStream << llvm::formatv(mapFmt, tensorUse.index(), - state.dims.size(), + mapsStringStream << llvm::formatv(mapFmt, tensor.index, state.dims.size(), indexingMap.getNumSymbols(), exprsStr); std::string replaceSymbolList = @@ -2269,17 +2285,17 @@ // need that.
const char *replaceFmt = "\n\tmap{0} = map{0}.replaceDimsAndSymbols({{}, {1}, {2}, 0);"; - mapsStringStream << llvm::formatv(replaceFmt, tensorUse.index(), + mapsStringStream << llvm::formatv(replaceFmt, tensor.index, replaceSymbolList, state.dims.size()); const char *simplifyFmt = "\n\tmap{0} = simplifyAffineMap(map{0});"; - mapsStringStream << llvm::formatv(simplifyFmt, tensorUse.index()); + mapsStringStream << llvm::formatv(simplifyFmt, tensor.index); } mapsStringStream.flush(); SmallVector<std::string, 4> mapList; - mapList.reserve(orderedUses.size()); - for (unsigned i = 0; i < orderedUses.size(); ++i) + mapList.reserve(state.numArgs); + for (auto i : llvm::seq<int>(0, state.numArgs)) mapList.push_back(llvm::formatv("map{0}", i)); // 4. Apply format to 1. using 2. and 3.