diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -184,7 +184,9 @@ MLIRContext *context = getContext(); auto r_i = getAffineDimExpr(0, context); return SmallVector<AffineMap, 8>{ - AffineMap::get(1, 0, {r_i}), AffineMap::get(1, 0, {r_i}), AffineMap()}; + AffineMap::get(1, 0, {r_i}), + AffineMap::get(1, 0, {r_i}), + AffineMap::get(context)}; } }]; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -351,20 +351,14 @@ return op.emitOpError("expected indexing_map #") << idx << " to have no symbols"; - if (m.getNumDims() != nLoops) + if (!m.isEmpty() && m.getNumDims() != nLoops) return op.emitOpError("expected indexing_map #") << idx << " to have " << nLoops << " dim(s) to match the number of loops"; - if (m.getNumResults() == 1 && view.getRank() == 0) { - auto cst = m.getResult(0).template dyn_cast<AffineConstantExpr>(); - if (!cst || cst.getValue() != 0) - return op.emitOpError("expected indexing_map #") - << idx << " to be 0 to match 0-D view: " << view; - } else if (m.getNumResults() != view.getRank()) { + if (m.getNumResults() != view.getRank()) return op.emitOpError("expected indexing_map #") << idx << " results to match view rank: " << view; - } } auto concatMap = concatAffineMaps(indexingMaps); @@ -886,7 +880,7 @@ if (maybeMap) return maybeMap.getValue(); if (rank == 0) - return AffineMap(); + return AffineMap::get(context); return AffineMap::getMultiDimIdentityMap(rank, context); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -37,6 +37,8 @@ static 
SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, ArrayRef<Value> vals) { + if (map.isEmpty()) + return {}; assert(map.getNumSymbols() == 0); assert(map.getNumInputs() == vals.size()); SmallVector<Value, 8> res; @@ -241,26 +243,17 @@ // 1.a. Emit std_load from input views. for (unsigned i = 0; i < nInputs; ++i) { - Value input = genericOp.getInput(i); - if (input.getType().cast<ShapedType>().getRank()) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, genericOp.getInputIndexingMap(i), allIvs)); - indexedValues[i] = std_load(input, indexing); - } else { - indexedValues[i] = std_load(input); - } + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getInputIndexingMap(i), allIvs)); + indexedValues[i] = std_load(genericOp.getInput(i), indexing); } // 1.b. Emit std_load from output views. for (unsigned i = 0; i < nOutputs; ++i) { Value output = genericOp.getOutputBuffer(i); - if (output.getType().cast<ShapedType>().getRank()) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - indexedValues[nInputs + i] = std_load(output, indexing); - } else { - indexedValues[nInputs + i] = std_load(output); - } + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); + indexedValues[nInputs + i] = std_load(output, indexing); } auto funcOp = genericOp.getFunction(); @@ -272,13 +265,9 @@ // 3. Emit std_store. 
for (unsigned i = 0; i < nOutputs; ++i) { Value output = genericOp.getOutputBuffer(i); - if (output.getType().cast<ShapedType>().getRank()) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), output, indexing); - } else { - std_store(callOp->getResult(i), output); - } + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); + std_store(callOp->getResult(i), output, indexing); } return; } @@ -297,15 +286,10 @@ auto *yieldOp = cast<YieldOp>(block.back()).getOperation(); assert(yieldOp->getNumOperands() == nOutputs); for (unsigned i = 0; i < nOutputs; ++i) { - Value output = genericOp.getOutputBuffer(i); - if (output.getType().cast<ShapedType>().getRank()) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(map.lookup(yieldOp->getOperand(i)), - genericOp.getOutputBuffer(i), indexing); - } else { - std_store(map.lookup(yieldOp->getOperand(i)), output); - } + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); + std_store(map.lookup(yieldOp->getOperand(i)), + genericOp.getOutputBuffer(i), indexing); } } }; diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -325,7 +325,8 @@ } AffineMap mlir::inversePermutation(AffineMap map) { - if (!map) + assert(map); + if (map.isEmpty()) return map; assert(map.getNumSymbols() == 0 && "expected map without symbols"); SmallVector<AffineExpr, 4> exprs(map.getNumDims()); @@ -349,20 +350,21 @@ } AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) { + assert(!maps.empty()); + assert(llvm::all_of(maps, [](AffineMap m) { return m; })); unsigned numResults = 0; for (auto m : maps) - numResults += (m && !m.isSingleConstant()) ? 
m.getNumResults() : 0; + numResults += m.getNumResults(); unsigned numDims = 0; SmallVector<AffineExpr, 8> results; results.reserve(numResults); for (auto m : maps) { - if (!m || m.isSingleConstant()) - continue; assert(m.getNumSymbols() == 0 && "expected map without symbols"); results.append(m.getResults().begin(), m.getResults().end()); numDims = std::max(m.getNumDims(), numDims); } - return numDims == 0 ? AffineMap() : AffineMap::get(numDims, 0, results); + return numDims == 0 ? AffineMap::get(maps.front().getContext()) + : AffineMap::get(numDims, 0, results); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -170,7 +170,7 @@ func @foo(%0: i32) -> i32 { return %0: i32 } -func @generic_wrong_dim_in_map(%arg0: memref<i32>) { +func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) { // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}} linalg.generic { args_in = 0, @@ -178,22 +178,7 @@ fun = @foo, indexing_maps = [ affine_map<() -> (0)> ], iterator_types = ["parallel"] - } %arg0: memref<i32> -} - -// ----- - -func @foo(%0: i32) -> i32 { return %0: i32 } - -func @generic_zero_d_view(%arg0: memref<i32>) { - // expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: 'memref<i32>'}} - linalg.generic { - args_in = 0, - args_out = 1, - fun = @foo, - indexing_maps = [ affine_map<() -> (1)> ], - iterator_types = [] - } %arg0: memref<i32> + } %arg0: memref<1xi32> } // ----- diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -360,7 +360,7 @@ // ----- #broadcast_access = [ - affine_map<(i, j) -> (0)>, + affine_map<() -> ()>, affine_map<(i, j) -> (i, j)> ] @@ -414,7 +414,7 @@ #reduce_1D_access = [ affine_map<(i) -> (i)>, 
- affine_map<(i) -> (0)> + affine_map<() -> ()> ] #trait_reduce_1D = { @@ -446,8 +446,8 @@ #reduce_init_1D_access = [ affine_map<(i) -> (i)>, - affine_map<(i) -> (0)>, - affine_map<(i) -> (0)> + affine_map<() -> ()>, + affine_map<() -> ()> ] #trait_reduce_init_1D = { diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -346,7 +346,7 @@ // ----- #broadcast_access = [ - affine_map<(i, j) -> (0)>, + affine_map<() -> ()>, affine_map<(i, j) -> (i, j)> ]