Index: mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
===================================================================
--- mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -105,6 +105,10 @@
     RewritePatternSet &patterns,
     LinalgElementwiseFusionOptions options = LinalgElementwiseFusionOptions());
 
+/// Patterns to push expanding reshape ops past elementwise linalg ops, towards
+/// the end of the graph, in order to expose more fusion opportunities.
+void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
+
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
 /// The permutation is expressed as a list of integers that specify
Index: mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -932,6 +932,161 @@
   }
 };
 
+/// Returns, for each reassociation map of a tensor_reshape op, the positions
+/// of the dimensions it groups (every result must be an AffineDimExpr).
+static SmallVector<ReassociationIndices>
+getReassociationIndices(ArrayRef<AffineMap> maps) {
+  SmallVector<ReassociationIndices> reassociation;
+  for (AffineMap map : maps) {
+    ReassociationIndices indices;
+    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i)
+      indices.push_back(map.getDimPosition(i));
+    reassociation.push_back(indices);
+  }
+  return reassociation;
+}
+
+/// Pattern to move an expanding tensor_reshape feeding an elementwise linalg
+/// generic op to after the op, so that the op runs at the lower rank. This
+/// exposes more fusion opportunities between named ops and generic ops. It
+/// requires no broadcast or permutation within the dimensions to merge.
+///
+/// For example,
+///
+///  %0 = linalg.tensor_reshape %A [
+///    affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
+///  %2 = linalg.generic {indexing_maps = [
+///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///    affine_map<(d0, d1, d2) -> (d2)>,
+///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
+///    ["parallel", "parallel", "parallel"]} {
+///  } -> tensor<112x112x16xf32>
+///
+///  into
+///
+///  %2 = linalg.generic {indexing_maps = [
+///    affine_map<(d0, d1) -> (d0, d1)>,
+///    affine_map<(d0, d1) -> (d1)>,
+///    affine_map<(d0, d1) -> (d0, d1)>],
+///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
+///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
+///  } -> tensor<12544x16xf32>
+///  %3 = linalg.tensor_reshape %2 [
+///    affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
+template <typename GenericOpTy>
+struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
+  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Only apply to elementwise linalg on tensor. indexed_generic is rejected:
+    // its region takes one index argument per loop, which merging would break.
+    if (!op.hasTensorSemantics() ||
+        op.getNumParallelLoops() != op.getNumLoops() ||
+        isa<IndexedGenericOp>(op.getOperation()))
+      return failure();
+    // Only support identity output maps (could be extended to permutations).
+    if (llvm::any_of(op.getOutputIndexingMaps(),
+                     [](AffineMap map) { return !map.isIdentity(); }))
+      return failure();
+    // Map results are decomposed as AffineDimExprs below; reject others.
+    auto indexingMaps = op.getIndexingMaps();
+    for (AffineMap map : indexingMaps)
+      for (AffineExpr expr : map.getResults())
+        if (!expr.isa<AffineDimExpr>())
+          return failure();
+    SmallVector<Value, 4> newOperands = llvm::to_vector<4>(op.getInputs());
+    TensorReshapeOp reshapeFound;
+    // 1. Replace identity-mapped, expanding tensor_reshape inputs by sources.
+    for (auto operand : llvm::enumerate(op.getInputs())) {
+      TensorReshapeOp reshapeOp =
+          operand.value().template getDefiningOp<TensorReshapeOp>();
+      if (!reshapeOp || reshapeOp.getSrcType().getRank() >=
+                            reshapeOp.getResultType().getRank())
+        continue;
+      // TODO: We could support non-identity map as long as the merged
+      // dimensions are still contiguous.
+      if (!indexingMaps[operand.index()].isIdentity())
+        continue;
+      if (reshapeFound) {
+        // Only support a second reshape op if it has the same reassociation.
+        if (reshapeFound.getReassociationMaps() ==
+            reshapeOp.getReassociationMaps())
+          newOperands[operand.index()] = reshapeOp.src();
+        continue;
+      }
+      reshapeFound = reshapeOp;
+      newOperands[operand.index()] = reshapeOp.src();
+    }
+    if (!reshapeFound)
+      return failure();
+
+    SmallVector<ReassociationIndices> reassociation =
+        getReassociationIndices(reshapeFound.getReassociationMaps());
+    // Map each original loop dimension to the reassociation group holding it.
+    SmallVector<unsigned, 4> remap(op.getNumLoops());
+    for (auto indices : llvm::enumerate(reassociation))
+      for (int64_t index : indices.value())
+        remap[index] = indices.index();
+    // 2. Inputs that keep their value must not access merged dimensions; new
+    // reshape operands would defeat the purpose of the transformation.
+    for (auto operand : llvm::enumerate(op.getInputs())) {
+      if (operand.value() != newOperands[operand.index()])
+        continue;
+      AffineMap map = indexingMaps[operand.index()];
+      for (unsigned i = 0, e = map.getNumResults(); i < e; ++i)
+        if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
+          return failure();
+    }
+
+    // 3. Compute the new indexing maps: each group of merged dimensions
+    // becomes a single loop, represented by the last dimension of the group.
+    SmallVector<AffineMap, 4> newMaps;
+    unsigned newRank = reassociation.size();
+    for (AffineMap map : indexingMaps) {
+      SmallVector<AffineExpr> newExprs;
+      for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
+        unsigned position = map.getDimPosition(i);
+        if (reassociation[remap[position]].back() == position)
+          newExprs.push_back(
+              getAffineDimExpr(remap[position], op.getContext()));
+      }
+      newMaps.push_back(AffineMap::get(newRank, 0, newExprs, op.getContext()));
+    }
+
+    // 4. Collapse the outputs to the shape of the reshape source, keeping each
+    // output's own element type (it may differ from the inputs' one).
+    SmallVector<Value> newOutputs;
+    SmallVector<Type> newOutputTypes;
+    for (Value output : op.outputs()) {
+      auto newOutputType = RankedTensorType::get(
+          reshapeFound.getSrcType().getShape(),
+          output.getType().cast<ShapedType>().getElementType());
+      newOutputs.push_back(rewriter.create<TensorReshapeOp>(
+          op->getLoc(), newOutputType, output, reassociation));
+      newOutputTypes.push_back(newOutputType);
+    }
+    // 5. Create a new generic op with lower rank and move the body into it.
+    SmallVector<StringRef> iteratorTypes(newRank,
+                                         getParallelIteratorTypeName());
+    auto newOp =
+        rewriter.create<GenericOpTy>(op->getLoc(), newOutputTypes, newOperands,
+                                     newOutputs, newMaps, iteratorTypes);
+    rewriter.inlineRegionBefore(op.region(), newOp.region(),
+                                newOp.region().begin());
+    // 6. Expand the results back so that their types match the uses.
+    SmallVector<Value> newResults;
+    for (auto result : llvm::enumerate(newOp->getResults()))
+      newResults.push_back(rewriter.create<TensorReshapeOp>(
+          op->getLoc(), op.getOutputTensorTypes()[result.index()],
+          result.value(), reassociation));
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+};
+
 /// Pattern to fuse a tensor_reshape op with its consumer
 /// generic/indexed_generic op, when the reshape op is collapsing
 /// dimensions. The dimensionality of the loop in the consumer is expanded.
@@ -1266,6 +1421,12 @@
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
 }
 
+void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
+  // IndexedGenericOp is not registered: its region takes one index argument
+  // per loop, and those loops are merged by this rewrite.
+  patterns.add<PushExpandingReshape<GenericOp>>(patterns.getContext());
+}
+
 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
   return std::make_unique<LinalgFusionOfTensorOpsPass>();
 }
Index: mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -0,0 +1,98 @@
+// RUN: mlir-opt %s -test-linalg-push-reshape -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: func @reshape
+// CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>)
+// CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[INIT]] [#[[$MAP0]], #[[$MAP1]]] : tensor<?x112x16xf32> into tensor<?x16xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
+// CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] [#[[$MAP0]], #[[$MAP1]]] : tensor<?x16xf32> into tensor<?x112x16xf32>
+// CHECK: return %[[RR]] : tensor<?x112x16xf32>
+func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>) -> tensor<?x112x16xf32> {
+  %0 = linalg.tensor_reshape %A [
+    affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+      : tensor<?x16xf32> into tensor<?x112x16xf32>
+  %2 = linalg.generic {indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+  ins(%0, %B : tensor<?x112x16xf32>, tensor<16xf32>)
+  outs(%init : tensor<?x112x16xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):  // no predecessors
+      %s = subf %arg1, %arg2 : f32
+      linalg.yield %s : f32
+  } -> tensor<?x112x16xf32>
+  return %2 : tensor<?x112x16xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: func @reshape_multiple
+// CHECK-SAME: (%[[A:.*]]: tensor<12544x16xf32>, %[[B:.*]]: tensor<12544x16xf32>, %[[C:.*]]: tensor<16xf32>)
+// CHECK: %[[I:.*]] = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
+// CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[I]] [#[[$MAP0]], #[[$MAP1]]] : tensor<112x112x16xf32> into tensor<12544x16xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<12544x16xf32>, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>)
+// CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] [#[[$MAP0]], #[[$MAP1]]] : tensor<12544x16xf32> into tensor<112x112x16xf32>
+// CHECK: return %[[RR]] : tensor<112x112x16xf32>
+func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
+  %C: tensor<16xf32>) -> tensor<112x112x16xf32> {
+  %0 = linalg.tensor_reshape %A [
+    affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+      : tensor<12544x16xf32> into tensor<112x112x16xf32>
+  %1 = linalg.tensor_reshape %B [
+    affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+      : tensor<12544x16xf32> into tensor<112x112x16xf32>
+  %2 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
+  %3 = linalg.generic {indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+    affine_map<(d0, d1, d2) -> (d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+  ins(%0, %1, %C : tensor<112x112x16xf32>, tensor<112x112x16xf32>, tensor<16xf32>)
+  outs(%2 : tensor<112x112x16xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
+      %s = subf %arg1, %arg2 : f32
+      %m = mulf %s, %arg3 : f32
+      linalg.yield %m : f32
+  } -> tensor<112x112x16xf32>
+  return %3 : tensor<112x112x16xf32>
+}
+
+// -----
+
+// Negative test: the second input is broadcast along d1 only, so d0 and d1
+// cannot be merged.
+// CHECK-LABEL: func @reshape_negative
+// CHECK: linalg.tensor_reshape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32>
+// CHECK: linalg.generic
+// CHECK: } -> tensor<112x112x16xf32>
+func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> {
+  %20 = linalg.tensor_reshape %A [
+      affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+      : tensor<12544x16xf32> into tensor<112x112x16xf32>
+  %21 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
+  %22 = linalg.generic {indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>,
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+  ins(%20, %B : tensor<112x112x16xf32>, tensor<112xf32>)
+  outs(%21 : tensor<112x112x16xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):  // no predecessors
+      %s = subf %arg1, %arg2 : f32
+      linalg.yield %s : f32
+  } -> tensor<112x112x16xf32>
+  return %22 : tensor<112x112x16xf32>
+}
Index: mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp
===================================================================
--- mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp
+++ mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp
@@ -66,6 +66,23 @@
                                        std::move(fusionPatterns));
   }
 };
+
+/// Test pass applying populatePushReshapeOpsPatterns greedily to a function.
+struct TestPushExpandingReshape
+    : public PassWrapper<TestPushExpandingReshape, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
+  }
+
+  void runOnFunction() override {
+    MLIRContext *context = &this->getContext();
+    FuncOp funcOp = this->getFunction();
+    RewritePatternSet patterns(context);
+    linalg::populatePushReshapeOpsPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+  }
+};
 } // namespace
 
 namespace test {
@@ -74,6 +91,11 @@
       "test-linalg-elementwise-fusion-patterns",
       "Test Linalg element wise operation fusion patterns");
 }
+
+void registerTestPushExpandingReshape() {
+  PassRegistration<TestPushExpandingReshape> testPushExpandingReshapePass(
+      "test-linalg-push-reshape", "Test Linalg reshape push patterns");
+}
 } // namespace test
 } // namespace mlir
 
Index: mlir/tools/mlir-opt/mlir-opt.cpp
===================================================================
--- mlir/tools/mlir-opt/mlir-opt.cpp
+++ mlir/tools/mlir-opt/mlir-opt.cpp
@@ -78,6 +78,7 @@
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
 void registerTestLinalgElementwiseFusion();
+void registerTestPushExpandingReshape();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgTensorFusionTransforms();
 void registerTestLinalgGreedyFusion();
@@ -156,6 +157,7 @@
   test::registerTestInterfaces();
   test::registerTestLinalgCodegenStrategy();
   test::registerTestLinalgElementwiseFusion();
+  test::registerTestPushExpandingReshape();
   test::registerTestLinalgFusionTransforms();
   test::registerTestLinalgTensorFusionTransforms();
   test::registerTestLinalgGreedyFusion();