diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -672,6 +672,8 @@
   let verifier = [{ return ::verify(*this); }];

   let hasFolder = 1;
+
+  let hasCanonicalizer = 1;
 }

 /// GenericOp with Indexing (i.e. multi-for style in which the region is passed
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -123,6 +123,10 @@
       "Return the range over inputs (irrespective of type) and output buffers.",
       "Operation::operand_range", "getInputsAndOutputBuffers"
     >,
+    InterfaceMethod<
+      "Return the shaped types for all the inputs and outputs.",
+      "SmallVector<ShapedType, 4>", "getInputOutputShapedTypes"
+    >,

     //===------------------------------------------------------------------===//
     // Other interface methods.
@@ -153,6 +157,10 @@
       "Return the indexing maps attribute within the current operation.",
       "ArrayAttr", "indexing_maps"
     >,
+    InterfaceMethod<
+      "Return the indexing maps within the current operation.",
+      "SmallVector<AffineMap, 4>", "getIndexingMaps"
+    >,
     InterfaceMethod<"Return the input or output indexing map at index `i`.",
       "AffineMap", "getIndexingMap", (ins "unsigned":$i)
     >,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -217,6 +217,18 @@
     return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()]
         .template cast<RankedTensorType>();
   }
+  /// Return the shaped types for all the inputs and outputs.
+  SmallVector<ShapedType, 4> getInputOutputShapedTypes() {
+    SmallVector<Type, 4> inputOutputTypes(
+        this->getOperation()->operand_type_begin(),
+        this->getOperation()->operand_type_end());
+    inputOutputTypes.append(this->getOperation()->result_type_begin(),
+                            this->getOperation()->result_type_end());
+    return llvm::to_vector<4>(
+        llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType {
+          return type.cast<ShapedType>();
+        }));
+  }

   //==========================================================================//
   // Other interface methods.
@@ -295,6 +307,13 @@
     return attr;
   }

+  SmallVector<AffineMap, 4> getIndexingMaps() {
+    return llvm::to_vector<4>(
+        llvm::map_range(indexing_maps(), [](Attribute attr) -> AffineMap {
+          return attr.cast<AffineMapAttr>().getValue();
+        }));
+  }
+
   AffineMap getIndexingMap(unsigned i) {
     assert(i < getNumInputsAndOutputs());
     return indexing_maps()
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -61,6 +61,121 @@
   return success(folded);
 }

+/// Given dims of the iteration space of a structured op that are known to be
+/// single trip count (`unitDims`), return the indexing maps to use in the
+/// canonicalized op with these dims removed, given the original `indexingMaps`.
+static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
+                                 ArrayRef<AffineMap> indexingMaps,
+                                 MLIRContext *context) {
+  if (indexingMaps.empty())
+    return nullptr;
+  unsigned numIterationDims = indexingMaps.front().getNumDims();
+  unsigned numSymbols = indexingMaps.front().getNumSymbols();
+
+  // Compute the replacement for each dim expr.
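+  // For example (an illustrative note, mirroring the @drop_one_trip_loops test
+  // below): with numIterationDims = 5 and unitDims = {2, 3}, the replacements
+  // computed here are [d0, d1, 0, 0, d2]: each unit dim becomes the constant 0
+  // and the surviving dims are renumbered compactly.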
+  SmallVector<AffineExpr, 4> dimReplacements;
+  dimReplacements.reserve(numIterationDims);
+  unsigned numKeptDims = 0;
+  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
+    if (unitDims.count(dim)) {
+      dimReplacements.push_back(getAffineConstantExpr(0, context));
+    } else {
+      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
+    }
+  }
+
+  // Symbols remain the same.
+  SmallVector<AffineExpr, 4> symReplacements;
+  symReplacements.reserve(numSymbols);
+  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
+    symReplacements.push_back(getAffineSymbolExpr(symbol, context));
+
+  SmallVector<AffineMap, 4> newIndexingMaps;
+  newIndexingMaps.reserve(indexingMaps.size());
+  for (AffineMap operandMap : indexingMaps) {
+    // Expected indexing maps to have no symbols.
+    if (operandMap.getNumSymbols())
+      return nullptr;
+    newIndexingMaps.push_back(simplifyAffineMap(
+        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
+                                         numIterationDims - unitDims.size(),
+                                         numSymbols)));
+  }
+
+  // Check that the new index maps are invertible. If not, something went
+  // wrong, so abort.
+  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
+    return nullptr;
+  return ArrayAttr::get(
+      llvm::to_vector<4>(llvm::map_range(
+          newIndexingMaps,
+          [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })),
+      context);
+}
+
+namespace {
+/// Pattern to fold unit-trip count loops in GenericOps.
+// TODO: Generalize this to indexed_generic ops by also modifying the region
+// arguments.
+struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
+    if (indexingMaps.empty())
+      return failure();
+
+    // Check if any of the iteration dimensions are unit-trip count. They will
+    // end up being unit-trip count if they are used to index into a unit-dim
+    // tensor/memref.
+    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
+    if (!invertedMap)
+      return failure();
+    SmallVector<int64_t, 4> dims;
+    for (ShapedType shapedType : genericOp.getInputOutputShapedTypes())
+      dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
+    DenseSet<unsigned> unitDims;
+    ArrayAttr iteratorTypes = genericOp.iterator_types();
+    for (auto expr : enumerate(invertedMap.getResults())) {
+      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
+        if (dims[dimExpr.getPosition()] == 1 &&
+            iteratorTypes[expr.index()].dyn_cast<StringAttr>().getValue() ==
+                getParallelIteratorTypeName())
+          unitDims.insert(expr.index());
+    }
+    if (unitDims.empty())
+      return failure();
+
+    // Compute the modified indexing maps.
+    MLIRContext *context = rewriter.getContext();
+    ArrayAttr newIndexingMapAttr =
+        replaceUnitDims(unitDims, indexingMaps, context);
+    if (!newIndexingMapAttr)
+      return genericOp.emitError("unable to compute modified indexing_maps");
+
+    // Compute the iterator types of the modified op by dropping the one-trip
+    // count loops.
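+    // For example, the five "parallel" iterators of @drop_one_trip_loops in
+    // the tests with unitDims = {2, 3} reduce to three "parallel" iterators.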
+    SmallVector<Attribute, 4> newIteratorTypes;
+    for (auto attr : llvm::enumerate(iteratorTypes)) {
+      if (unitDims.count(attr.index()))
+        continue;
+      newIteratorTypes.push_back(attr.value());
+    }
+
+    rewriter.startRootUpdate(genericOp);
+    genericOp.indexing_mapsAttr(newIndexingMapAttr);
+    genericOp.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
+    rewriter.finalizeRootUpdate(genericOp);
+    return success();
+  }
+};
+} // namespace
+
+void GenericOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  results.insert<FoldUnitDimLoops>(context);
+}
+
 ///////////////////// Operations defined with Tablegen /////////////////////////
 // For such operations that do not correspond to library calls (i.e. defined in
 // LinalgOps.td), we define an overloaded `print` function and a
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -546,8 +546,9 @@
     if (auto yieldOp = dyn_cast<YieldOp>(op)) {
       // Lookup the value the yield operation is mapped to.
       Value yieldVal = yieldOp.getOperand(0);
-      auto clonedVal = mapper.lookup(yieldVal);
-      mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal);
+      auto clonedVal = mapper.lookupOrNull(yieldVal);
+      if (clonedVal)
+        mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal);
       continue;
     }
     rewriter.clone(op, mapper);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s

 // CHECK-LABEL: func @memref_cast(
 func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
@@ -18,3 +18,114 @@
   linalg.matmul(%3, %3, %3) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
   return %4: memref<?x?xf32>
 }
+
+// -----
+
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
+{
+  %0 = linalg.generic #trait %arg0 {
+    ^bb0(%arg1 : f32) :
+      linalg.yield %arg1 : f32
+  } : tensor<?x1x?xf32> -> tensor<?x1x?x1x?xf32>
+  return %0 : tensor<?x1x?x1x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d1, 0, d2)>
+// CHECK-LABEL: func @drop_one_trip_loops
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+
+// -----
+
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
+{
+  %0 = linalg.generic #trait %arg0 {
+    ^bb0(%arg1: f32) :
+      linalg.yield %arg1 : f32
+  } : tensor<1x1xf32> -> tensor<1x1xf32>
+  return %0 : tensor<1x1xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<() -> (0, 0)>
+// CHECK-LABEL: func @drop_all_loops
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = []
+
+// -----
+
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+func @drop_all_loops(%arg0 : memref<1x1xf32>, %arg1 : memref<1x1xf32>)
+{
+  linalg.generic #trait %arg0, %arg1 {
+    ^bb0(%arg2: f32, %arg3 : f32) :
+      linalg.yield %arg2 : f32
+  } : memref<1x1xf32>, memref<1x1xf32>
+  return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<() -> (0, 0)>
+// CHECK-LABEL: func @drop_all_loops
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = []
+
+// -----
+
+#accesses = [
+  affine_map<(d0, d1) -> (d0, d1)>,
+  affine_map<(d0, d1) -> (d1)>
+]
+
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel"],
+  library_call = "some_external_fn"
+}
+
+func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
+  %0 = linalg.generic #trait %arg0 {
+    ^bb0(%arg2: f32): // no predecessors
+      linalg.yield %arg2 : f32
+  } : tensor<1x5xf32> -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @leading_dim_1_canonicalization
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel"]
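
To make the new canonicalization concrete, here is a sketch (hedged: the exact
printed attribute order and literal types, e.g. `args_in = 1 : i64`, may differ
from what `mlir-opt -canonicalize` actually emits) of what the
`@drop_one_trip_loops` op above becomes, as constrained by the CHECK lines:

  #map0 = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
  #map1 = affine_map<(d0, d1, d2) -> (d0, 0, d1, 0, d2)>
  %0 = linalg.generic {args_in = 1, args_out = 1,
         indexing_maps = [#map0, #map1],
         iterator_types = ["parallel", "parallel", "parallel"],
         library_call = "some_external_func"} %arg0 {
    ^bb0(%arg1: f32):
      linalg.yield %arg1 : f32
  } : tensor<?x1x?xf32> -> tensor<?x1x?x1x?xf32>

The op still accesses the same elements; the two unit-trip iteration
dimensions (d2 and d3 of the original five) are folded into the constant 0 in
the indexing maps instead of occupying loop dimensions.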