diff --git a/mlir/docs/Dialects/Affine.md b/mlir/docs/Dialects/Affine.md --- a/mlir/docs/Dialects/Affine.md +++ b/mlir/docs/Dialects/Affine.md @@ -91,7 +91,8 @@ | bare-id | `-`? integer-literal -multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` +multi-dim-affine-expr ::= `(` `)` + | `(` affine-expr (`,` affine-expr)* `)` ``` `ceildiv` is the ceiling function which maps the result of the division of its diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -184,7 +184,9 @@ MLIRContext *context = getContext(); auto r_i = getAffineDimExpr(0, context); return SmallVector{ - AffineMap::get(1, 0, {r_i}), AffineMap::get(1, 0, {r_i}), AffineMap()}; + AffineMap::get(1, 0, {r_i}), + AffineMap::get(1, 0, {r_i}), + AffineMap::get(1, 0, context)}; } }]; diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -44,6 +44,11 @@ /// Returns a zero result affine map with no dimensions or symbols: () -> (). static AffineMap get(MLIRContext *context); + /// Returns a zero result affine map with `dimCount` dimensions and + /// `symbolCount` symbols, e.g.: `(...) -> ()`. 
+ static AffineMap get(unsigned dimCount, unsigned symbolCount, + MLIRContext *context); + static AffineMap get(unsigned dimCount, unsigned symbolCount, ArrayRef results); @@ -275,8 +280,7 @@ namespace llvm { // AffineExpr hash just like pointers -template <> -struct DenseMapInfo { +template <> struct DenseMapInfo { static mlir::AffineMap getEmptyKey() { auto pointer = llvm::DenseMapInfo::getEmptyKey(); return mlir::AffineMap(static_cast(pointer)); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -356,15 +356,9 @@ << idx << " to have " << nLoops << " dim(s) to match the number of loops"; - if (m.getNumResults() == 1 && view.getRank() == 0) { - auto cst = m.getResult(0).template dyn_cast(); - if (!cst || cst.getValue() != 0) - return op.emitOpError("expected indexing_map #") - << idx << " to be 0 to match 0-D view: " << view; - } else if (m.getNumResults() != view.getRank()) { + if (m.getNumResults() != view.getRank()) return op.emitOpError("expected indexing_map #") << idx << " results to match view rank: " << view; - } } auto concatMap = concatAffineMaps(indexingMaps); @@ -886,7 +880,7 @@ if (maybeMap) return maybeMap.getValue(); if (rank == 0) - return AffineMap(); + return AffineMap::get(context); return AffineMap::getMultiDimIdentityMap(rank, context); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -37,6 +37,8 @@ static SmallVector makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, ArrayRef vals) { + if (map.isEmpty()) + return {}; assert(map.getNumSymbols() == 0); assert(map.getNumInputs() == vals.size()); SmallVector res; @@ -241,26 +243,17 @@ // 1.a. Emit std_load from input views. 
for (unsigned i = 0; i < nInputs; ++i) { - Value input = genericOp.getInput(i); - if (input.getType().cast().getRank()) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, genericOp.getInputIndexingMap(i), allIvs)); - indexedValues[i] = std_load(input, indexing); - } else { - indexedValues[i] = std_load(input); - } + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getInputIndexingMap(i), allIvs)); + indexedValues[i] = std_load(genericOp.getInput(i), indexing); } // 1.b. Emit std_load from output views. for (unsigned i = 0; i < nOutputs; ++i) { Value output = genericOp.getOutputBuffer(i); - if (output.getType().cast().getRank()) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - indexedValues[nInputs + i] = std_load(output, indexing); - } else { - indexedValues[nInputs + i] = std_load(output); - } + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); + indexedValues[nInputs + i] = std_load(output, indexing); } auto funcOp = genericOp.getFunction(); @@ -272,13 +265,9 @@ // 3. Emit std_store. 
for (unsigned i = 0; i < nOutputs; ++i) { Value output = genericOp.getOutputBuffer(i); - if (output.getType().cast().getRank()) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), output, indexing); - } else { - std_store(callOp->getResult(i), output); - } + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); + std_store(callOp->getResult(i), output, indexing); } return; } @@ -297,15 +286,10 @@ auto *yieldOp = cast(block.back()).getOperation(); assert(yieldOp->getNumOperands() == nOutputs); for (unsigned i = 0; i < nOutputs; ++i) { - Value output = genericOp.getOutputBuffer(i); - if (output.getType().cast().getRank()) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(map.lookup(yieldOp->getOperand(i)), - genericOp.getOutputBuffer(i), indexing); - } else { - std_store(map.lookup(yieldOp->getOperand(i)), output); - } + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); + std_store(map.lookup(yieldOp->getOperand(i)), + genericOp.getOutputBuffer(i), indexing); } } }; diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -281,7 +281,8 @@ exprs.reserve(getResults().size()); for (auto expr : getResults()) exprs.push_back(expr.compose(newMap)); - return AffineMap::get(numDims, numSymbols, exprs); + return exprs.empty() ? 
AffineMap::get(numDims, numSymbols, map.getContext()) + : AffineMap::get(numDims, numSymbols, exprs); } bool AffineMap::isProjectedPermutation() { @@ -325,7 +326,7 @@ } AffineMap mlir::inversePermutation(AffineMap map) { - if (!map) + if (map.isEmpty()) return map; assert(map.getNumSymbols() == 0 && "expected map without symbols"); SmallVector exprs(map.getNumDims()); @@ -351,18 +352,18 @@ AffineMap mlir::concatAffineMaps(ArrayRef maps) { unsigned numResults = 0; for (auto m : maps) - numResults += (m && !m.isSingleConstant()) ? m.getNumResults() : 0; + numResults += m.getNumResults(); unsigned numDims = 0; SmallVector results; results.reserve(numResults); for (auto m : maps) { - if (!m || m.isSingleConstant()) - continue; assert(m.getNumSymbols() == 0 && "expected map without symbols"); results.append(m.getResults().begin(), m.getResults().end()); numDims = std::max(m.getNumDims(), numDims); } - return numDims == 0 ? AffineMap() : AffineMap::get(numDims, 0, results); + return results.empty() ? AffineMap::get(numDims, /*numSymbols=*/0, + maps.front().getContext()) + : AffineMap::get(numDims, /*numSymbols=*/0, results); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -612,6 +612,11 @@ } AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, + MLIRContext *context) { + return getImpl(dimCount, symbolCount, /*results=*/{}, context); +} + +AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, ArrayRef results) { // The number of results can't be zero. assert(!results.empty()); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3068,14 +3068,16 @@ }; // Parse a multi-dimensional affine expression (a comma-separated list of - // 1-d affine expressions); the list cannot be empty. 
Grammar: - // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) + // 1-d affine expressions); the list can be empty. Grammar: + // multi-dim-affine-expr ::= `(` `)` + // | `(` affine-expr (`,` affine-expr)* `)` if (parseCommaSeparatedListUntil(rightToken, parseElt, /*allowEmptyList=*/true)) return failure(); // Parsed a valid affine map. if (exprs.empty()) - map = AffineMap::get(getContext()); + map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, + getContext()); else map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, exprs); @@ -3101,13 +3103,14 @@ }; // Parse a multi-dimensional affine expression (a comma-separated list of - // 1-d affine expressions); the list cannot be empty. Grammar: - // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) + // 1-d affine expressions). Grammar: + // multi-dim-affine-expr ::= `(` `)` + // | `(` affine-expr (`,` affine-expr)* `)` if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) return AffineMap(); if (exprs.empty()) - return AffineMap::get(getContext()); + return AffineMap::get(numDims, numSymbols, getContext()); // Parsed a valid affine map. 
return AffineMap::get(numDims, numSymbols, exprs); diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -170,7 +170,7 @@ func @foo(%0: i32) -> i32 { return %0: i32 } -func @generic_wrong_dim_in_map(%arg0: memref) { +func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) { // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}} linalg.generic { args_in = 0, @@ -178,22 +178,7 @@ fun = @foo, indexing_maps = [ affine_map<() -> (0)> ], iterator_types = ["parallel"] - } %arg0: memref -} - -// ----- - -func @foo(%0: i32) -> i32 { return %0: i32 } - -func @generic_zero_d_view(%arg0: memref) { - // expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: 'memref'}} - linalg.generic { - args_in = 0, - args_out = 1, - fun = @foo, - indexing_maps = [ affine_map<() -> (1)> ], - iterator_types = [] - } %arg0: memref + } %arg0: memref<1xi32> } // ----- diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -360,7 +360,7 @@ // ----- #broadcast_access = [ - affine_map<(i, j) -> (0)>, + affine_map<(i, j) -> ()>, affine_map<(i, j) -> (i, j)> ] @@ -414,7 +414,7 @@ #reduce_1D_access = [ affine_map<(i) -> (i)>, - affine_map<(i) -> (0)> + affine_map<(i) -> ()> ] #trait_reduce_1D = { @@ -446,8 +446,8 @@ #reduce_init_1D_access = [ affine_map<(i) -> (i)>, - affine_map<(i) -> (0)>, - affine_map<(i) -> (0)> + affine_map<(i) -> ()>, + affine_map<(i) -> ()> ] #trait_reduce_init_1D = { diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -346,7 +346,7 @@ // ----- #broadcast_access = [ - affine_map<(i, j) -> (0)>, + affine_map<(i, j) -> ()>, affine_map<(i, j) 
-> (i, j)> ]