diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -24,6 +24,8 @@
 class OwningRewritePatternList;
 class Pass;
 
+std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
+
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFusionPass();
 std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
 
@@ -59,6 +61,11 @@
 void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
                                            OwningRewritePatternList &patterns);
 
+/// Patterns to fold unit-extent dimensions in operands/results of linalg ops
+/// on tensors.
+void populateLinalgFoldUnitExtentDimsPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_LINALG_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -11,6 +11,11 @@
 
 include "mlir/Pass/PassBase.td"
 
+def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
+  let summary = "Remove unit-extent dimensions in Linalg ops on tensors";
+  let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
+}
+
 def LinalgFusion : FunctionPass<"linalg-fusion"> {
   let summary = "Fuse operations in the linalg dialect";
   let constructor = "mlir::createLinalgFusionPass()";
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -44,6 +44,80 @@
 template <typename NamedStructuredOpType>
 static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
 
+namespace mlir {
+// Forward declaring this method that is also declared in
+// mlir/Dialect/Linalg/Passes.h to avoid having to include that file here as
+// well.
+void populateLinalgFoldUnitExtentDimsPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns);
+} // namespace mlir
+
+/// Determines whether a memref_cast can be folded away in the parent Linalg op:
+///
+/// ```mlir
+///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
+///   %2 = linalg.slice %1 ... : memref<?x?xf32> ...
+///   // or
+///   %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+///          to memref<?x?xf32>
+///   linalg.generic(%1 ...) : memref<?x?xf32> ...
+/// ```
+///
+/// into
+///
+/// ```mlir
+///   %2 = linalg.slice %0 ... : memref<8x16xf32> ...
+///   // or
+///   linalg.generic(%0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// ```
+///
+static bool canFold(MemRefCastOp castOp) {
+  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
+  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
+
+  // If we don't have MemRefType as source and destination, bail out.
+  if (!sourceType || !resultType)
+    return false;
+
+  // If resultType has a layout map, it needs to be the same as the source
+  // type's map to canonicalize.
+  if (!resultType.getAffineMaps().empty() &&
+      sourceType.getAffineMaps() != resultType.getAffineMaps())
+    return false;
+
+  // Ensure that:
+  //   1. the source is static,
+  //   2. source and target have the same rank (will be extended when needed),
+  //   3. if the result is partially static, the sizes match.
+  if (!sourceType.hasStaticShape() ||
+      sourceType.getRank() != resultType.getRank())
+    return false;
+
+  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+    auto sourceSize = std::get<0>(it);
+    auto resultSize = std::get<1>(it);
+    if (ShapedType::isDynamic(resultSize))
+      continue;
+    if (sourceSize != resultSize)
+      return false;
+  }
+
+  // If source has a map, it can only canonicalize if it is the canonical
+  // strided layout map.
+  if (sourceType.getAffineMaps().empty())
+    return true;
+
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto res = getStridesAndOffset(sourceType, strides, offset);
+  (void)res;
+  assert(succeeded(res));
+  auto stridedMap =
+      makeStridedLinearLayoutMap(strides, offset, castOp.getContext());
+  AffineMap sourceMap = sourceType.getAffineMaps().front();
+  return sourceMap == stridedMap;
+}
+
 /// This is a common class used for patterns of the form
 /// ```
 /// someop(memrefcast) -> someop
@@ -169,6 +243,150 @@
     return success();
   }
 };
+
+struct UnitExtentReplacementInfo {
+  RankedTensorType type;
+  AffineMap indexMap;
+  ArrayAttr reassociation;
+};
+} // namespace
+
+/// Utility function for replacing operands/results of a linalg generic op on
+/// tensors that have unit-extent dimensions. Such an operand/result can be
+/// replaced by one with the unit-extent dimensions removed, but only if the
+/// indexing map result used to access a unit-extent dimension is an
+/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
+/// Linalg op, and its `indexMap`, the utility function returns:
+/// - the new type with dimensions of size 1 removed.
+/// - modified index map that can be used to access the replaced result/operand
+/// - the reassociation that converts from the original tensor type to the
+///   modified tensor type.
+static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
+                                                    RankedTensorType type,
+                                                    MLIRContext *context) {
+  ArrayRef<int64_t> shape = type.getShape();
+  ArrayRef<AffineExpr> exprs = indexMap.getResults();
+  SmallVector<AffineExpr, 4> reassociations;
+  SmallVector<Attribute, 4> reassociationMaps;
+  SmallVector<AffineExpr, 4> newIndexExprs;
+  SmallVector<int64_t, 4> newShape;
+
+  int64_t origRank = type.getRank();
+  auto isUnitExtent = [&](int64_t dim) -> bool {
+    return dim < origRank && exprs[dim].isa<AffineConstantExpr>() &&
+           exprs[dim].cast<AffineConstantExpr>().getValue() == 0 &&
+           shape[dim] == 1;
+  };
+
+  unsigned dim = 0;
+  // Fold dimensions that are unit-extent at the beginning of the tensor.
+  while (isUnitExtent(dim)) {
+    reassociations.push_back(getAffineDimExpr(dim++, context));
+  }
+  for (; dim < origRank; ++dim) {
+    reassociations.push_back(getAffineDimExpr(dim, context));
+    newIndexExprs.push_back(exprs[dim]);
+    newShape.push_back(shape[dim]);
+    // Fold all following dimensions that are unit-extent.
+    while (isUnitExtent(dim + 1)) {
+      reassociations.push_back(getAffineDimExpr(++dim, context));
+    }
+    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
+        origRank, /*numSymbols=*/0, reassociations, context)));
+    reassociations.clear();
+  }
+  UnitExtentReplacementInfo info = {
+      RankedTensorType::get(newShape, type.getElementType()),
+      AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
+                     newIndexExprs, context),
+      ArrayAttr::get(reassociationMaps, context)};
+  return info;
+}
+
+namespace {
+/// Pattern to replace tensor operands/results that have unit-extent dims.
+struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
+    MLIRContext *context = rewriter.getContext();
+    Location loc = genericOp.getLoc();
+    ArrayRef<AffineMap> indexingMaps = genericOp.getIndexingMaps();
+
+    SmallVector<AffineMap, 4> newIndexingMaps;
+    SmallVector<ArrayAttr, 4> reassociationMaps;
+    SmallVector<ShapedType, 4> newInputOutputTypes;
+    bool doCanonicalization = false;
+    for (auto it :
+         llvm::zip(indexingMaps, genericOp.getInputOutputShapedTypes())) {
+      auto replacementInfo = replaceUnitExtents(
+          std::get<0>(it), std::get<1>(it).cast<RankedTensorType>(), context);
+      reassociationMaps.push_back(replacementInfo.reassociation);
+      newIndexingMaps.push_back(replacementInfo.indexMap);
+      newInputOutputTypes.push_back(replacementInfo.type);
+      doCanonicalization =
+          doCanonicalization || replacementInfo.type != std::get<1>(it);
+    }
+
+    // If the indexing maps of the result operation are not invertible (i.e.
+    // not legal), abort.
+    if (!doCanonicalization ||
+        !inversePermutation(concatAffineMaps(newIndexingMaps)))
+      return failure();
+
+    // If any operand type changes, insert a reshape to convert from the
+    // original type to the new type.
+    SmallVector<Value, 4> newOperands;
+    newOperands.reserve(genericOp.getNumOperands());
+    for (auto operand : llvm::enumerate(genericOp.getOperands())) {
+      if (operand.value().getType() == newInputOutputTypes[operand.index()]) {
+        newOperands.push_back(operand.value());
+      } else {
+        newOperands.push_back(rewriter.create<linalg::TensorReshapeOp>(
+            loc, newInputOutputTypes[operand.index()], operand.value(),
+            reassociationMaps[operand.index()]));
+      }
+    }
+
+    // Gather the result types; in `newInputOutputTypes` the result types
+    // follow the operand types.
+    SmallVector<Type, 4> resultTypes;
+    resultTypes.reserve(genericOp.getNumResults());
+    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
+      resultTypes.push_back(
+          newInputOutputTypes[i + genericOp.getNumOperands()]);
+    GenericOp replacementOp = rewriter.create<GenericOp>(
+        loc, resultTypes, newOperands, genericOp.args_in(),
+        genericOp.args_out(), rewriter.getAffineMapArrayAttr(newIndexingMaps),
+        genericOp.iterator_types(),
+        /*doc=*/nullptr,
+        /*library_call=*/nullptr);
+    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
+                                replacementOp.region().begin());
+
+    // If any result tensor has a modified shape, then add a reshape to
+    // recover the original shape.
+    SmallVector<Value, 4> resultReplacements;
+    for (auto result : llvm::enumerate(replacementOp.getResults())) {
+      unsigned index = result.index() + replacementOp.getNumOperands();
+      RankedTensorType origResultType = genericOp.getResult(result.index())
+                                            .getType()
+                                            .cast<RankedTensorType>();
+      if (origResultType != result.value().getType()) {
+        resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
+            loc, origResultType, result.value(), reassociationMaps[index]));
+      } else {
+        resultReplacements.push_back(result.value());
+      }
+    }
+    rewriter.replaceOp(genericOp, resultReplacements);
+    return success();
+  }
+};
+
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
@@ -1255,6 +1473,14 @@
   return {};
 }
 
+/// Patterns that are used to canonicalize the use of unit-extent dims for
+/// broadcasting.
+void mlir::populateLinalgFoldUnitExtentDimsPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
+                  ReplaceUnitExtentTensors>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Auto-generated Linalg named ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -802,6 +802,17 @@
   }
 };
 
+/// Pass that removes unit-extent dims within generic ops.
+struct LinalgFoldUnitExtentDimsPass
+    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    FuncOp funcOp = getFunction();
+    populateLinalgFoldUnitExtentDimsPatterns(funcOp.getContext(), patterns);
+    applyPatternsAndFoldGreedily(funcOp.getBody(), patterns);
+  }
+};
+
 /// Pass that fuses generic ops on tensors. Used only for testing.
 struct FusionOfTensorOpsPass
     : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
@@ -824,6 +835,11 @@
       context);
 }
 
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgFoldUnitExtentDimsPass() {
+  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
+}
+
 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
   return std::make_unique<LinalgFusionPass>();
 }
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt %s -linalg-fold-unit-extent-dims -split-input-file | FileCheck %s
+
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
+{
+  %0 = linalg.generic #trait %arg0 {
+       ^bb0(%arg1 : f32) :
+         linalg.yield %arg1 : f32
+       } : tensor<?x1x?xf32> -> tensor<?x1x?x1x?xf32>
+  return %0 : tensor<?x1x?x1x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+// CHECK-DAG: #[[MAP6:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
+// CHECK-LABEL: func @drop_one_trip_loops
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP4]], #[[MAP5]], #[[MAP6]]]
+
+// -----
+
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
+{
+  %0 = linalg.generic #trait %arg0 {
+       ^bb0(%arg1: f32) :
+         linalg.yield %arg1 : f32
+       } : tensor<1x1xf32> -> tensor<1x1xf32>
+  return %0 : tensor<1x1xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<() -> ()>
+// CHECK-LABEL: func @drop_all_loops
+// CHECK: linalg.tensor_reshape %{{.*}} []
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = []
+
+// -----
+
+#accesses = [
+  affine_map<(d0) -> (0, d0)>,
+  affine_map<(d0) -> (d0)>
+]
+
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"],
+  library_call = "some_external_fn"
+}
+
+func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
+  %0 = linalg.generic #trait %arg0 {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+  } : tensor<1x5xf32> -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @leading_dim_1_canonicalization
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+
+// -----
+
+#accesses = [
+  affine_map<(d0, d1) -> (0, d1)>,
+  affine_map<(d0, d1) -> (d0, 0)>,
+  affine_map<(d0, d1) -> (d0, d1)>
+]
+
+#trait = {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel"],
+  library_call = "some_external_fn"
+}
+
+func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<5x5xf32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
+       tensor<5xf32> into tensor<1x5xf32>
+  %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
+       tensor<5xf32> into tensor<5x1xf32>
+  %2 = linalg.generic #trait %0, %1 {
+       ^bb0(%arg2: f32, %arg3: f32):
+         %3 = addf %arg2, %arg3 : f32
+         linalg.yield %3 : f32
+       } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
+  return %2 : tensor<5x5xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @broadcast_test
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-NOT: linalg.tensor_reshape
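
Not part of the patch itself: the sketch below shows one way a downstream pattern-based pass could reuse the `populateLinalgFoldUnitExtentDimsPatterns` entry point exposed in Passes.h, mirroring the calls made by `LinalgFoldUnitExtentDimsPass` in Fusion.cpp. The pass name `FoldUnitDimsInPipelinePass` is hypothetical, and the include paths and PassWrapper/FunctionPass bases are assumptions tied to an MLIR revision contemporary with this patch.

// Hypothetical example only; not part of this patch. Reuses the populate
// entry point added above so the unit-extent-dim folding patterns can be
// mixed into another pass's pattern set before a single greedy application.
// Include paths and the PassWrapper/FunctionPass bases are assumptions for
// an MLIR tree contemporary with this patch.
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

namespace {
struct FoldUnitDimsInPipelinePass
    : public mlir::PassWrapper<FoldUnitDimsInPipelinePass, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::OwningRewritePatternList patterns;
    mlir::FuncOp funcOp = getFunction();
    // Same population + greedy application as LinalgFoldUnitExtentDimsPass;
    // other pattern sets could be inserted into `patterns` here as well.
    mlir::populateLinalgFoldUnitExtentDimsPatterns(funcOp.getContext(),
                                                   patterns);
    mlir::applyPatternsAndFoldGreedily(funcOp.getBody(), patterns);
  }
};
} // namespace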