diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -36,6 +36,11 @@ continue; indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand)); } + if (indexingMaps.empty()) { + // If there are no indexing maps, the operands can only be dropped + // if the op has no loops. + return linalgOp.getNumLoops() == 0; + } return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -143,10 +143,10 @@ /// Generate the region of the fused tensor operation. The region of the fused /// op must be empty. -static void -generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, - AffineMap consumerToProducerLoopsMap, - OpOperand *fusedOperand, unsigned nloops) { +static void generateFusedElementwiseOpRegion( + RewriterBase &rewriter, GenericOp fusedOp, + AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, + unsigned nloops, const llvm::SmallDenseSet &preservedProducerResults) { auto producer = cast(fusedOperand->get().getDefiningOp()); auto consumer = cast(fusedOperand->getOwner()); // Build the region of the fused op. @@ -202,9 +202,13 @@ mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); // 6.
All of the producer's output operands - for (BlockArgument bbArg : - producerBlock.getArguments().take_back(producer.getNumDpsInits())) - mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); + for (auto bbArg : llvm::enumerate( + producerBlock.getArguments().take_back(producer.getNumDpsInits()))) { + if (!preservedProducerResults.count(bbArg.index())) + continue; + mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(), + bbArg.value().getLoc())); + } // 7. All of consumer's output operands. for (BlockArgument bbArg : @@ -247,8 +251,11 @@ SmallVector fusedYieldValues; fusedYieldValues.reserve(producerYieldOp.getNumOperands() + consumerYieldOp.getNumOperands()); - for (auto producerYieldVal : producerYieldOp.getOperands()) - fusedYieldValues.push_back(mapper.lookupOrDefault(producerYieldVal)); + for (auto producerYieldVal : llvm::enumerate(producerYieldOp.getOperands())) { + if (preservedProducerResults.count(producerYieldVal.index())) + fusedYieldValues.push_back( + mapper.lookupOrDefault(producerYieldVal.value())); + } for (auto consumerYieldVal : consumerYieldOp.getOperands()) fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal)); rewriter.create(fusedOp.getLoc(), fusedYieldValues); @@ -269,6 +276,26 @@ // TODO: allow fusing the producer of an output operand. assert(consumer.isDpsInput(fusedOperand) && "expected producer of input operand"); + // Find the producer results that must be kept by the fused op: those whose + // init value is read by the payload, those with uses outside of the + // consumer, and those whose init operand cannot be dropped. Droppability is + // checked against all init operands dropped so far, since dropping several + // at once can leave the indexing maps non-invertible.
+ llvm::SmallDenseSet preservedProducerResults; + SmallVector<OpOperand *> droppedOutputOperands; + for (auto producerResult : llvm::enumerate(producer->getResults())) { + auto *outputOperand = producer.getDpsInitOperand(producerResult.index()); + droppedOutputOperands.push_back(outputOperand); + if (producer.payloadUsesValueFromOperand(outputOperand) || + !producer.canOpOperandsBeDropped(droppedOutputOperands) || + llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) { + return user != consumer.getOperation(); + })) { + preservedProducerResults.insert(producerResult.index()); + // This result is kept, so its init operand is not dropped after all. + droppedOutputOperands.pop_back(); + } + } // Compute the fused operands list and indexing maps. SmallVector fusedInputOperands, fusedOutputOperands; @@ -276,9 +303,9 @@ SmallVector fusedIndexMaps; fusedInputOperands.reserve(producer.getNumDpsInputs() + consumer.getNumDpsInputs()); - fusedOutputOperands.reserve(producer.getNumDpsInits() + + fusedOutputOperands.reserve(preservedProducerResults.size() + consumer.getNumDpsInits()); - fusedResultTypes.reserve(producer.getNumDpsInits() + + fusedResultTypes.reserve(preservedProducerResults.size() + consumer.getNumDpsInits()); fusedIndexMaps.reserve(producer->getNumOperands() + consumer->getNumOperands()); @@ -313,13 +340,16 @@ } // 6. Collect all of the producer outputs. - for (OpOperand *opOperand : producer.getDpsInitOperands()) { - fusedOutputOperands.push_back(opOperand->get()); + for (auto opOperand : llvm::enumerate(producer.getDpsInitOperands())) { + if (!preservedProducerResults.count(opOperand.index())) + continue; + + fusedOutputOperands.push_back(opOperand.value()->get()); AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( - opOperand, producerResultIndexMap, + opOperand.value(), producerResultIndexMap, consumer.getMatchingIndexingMap(fusedOperand)); fusedIndexMaps.push_back(map); - fusedResultTypes.push_back(opOperand->get().getType()); + fusedResultTypes.push_back(opOperand.value()->get().getType()); } // 7. All of consumer's output operands (skip operands: added by the builder).
@@ -358,9 +388,9 @@ AffineMap consumerToProducerLoopsMap = invProducerResultIndexMap.compose(consumerResultIndexMap); - generateFusedElementwiseOpRegion(rewriter, fusedOp, - consumerToProducerLoopsMap, fusedOperand, - consumer.getNumLoops()); + generateFusedElementwiseOpRegion( + rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand, + consumer.getNumLoops(), preservedProducerResults); return fusedOp.getOperation(); } diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir --- a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops-control -split-input-file | FileCheck %s #map0 = affine_map<(d0, d1) -> (d0, d1)> #binary2Dpointwise = { diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops-control -split-input-file | FileCheck %s + +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @drop_unused_producer_result(%arg0 : tensor, + %arg1 : tensor) -> tensor { + %0:2 = linalg.generic { + indexing_maps = [#map, #map, #map], + iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor) outs(%arg0, %arg0 : tensor, tensor) { + ^bb0(%b0: f32, %b1: f32, %b2: f32):
%4 = arith.subf %b0, %b1 : f32 + linalg.yield %4 : f32 + } -> tensor + return %3 : tensor +} +// CHECK-LABEL: func @drop_unused_producer_result +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK: %[[FUSED_OP:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : +// CHECK: return %[[FUSED_OP]] diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -75,6 +75,12 @@ llvm::cl::desc("Test fusion of generic operations."), llvm::cl::init(false)}; + Option fuseGenericOpsControl{ + *this, "fuse-generic-ops-control", + llvm::cl::desc( + "Test fusion of generic operations with a control function."), + llvm::cl::init(false)}; + Option fuseWithReshapeByExpansion{ *this, "fuse-with-reshape-by-expansion", llvm::cl::desc( @@ -108,6 +114,15 @@ func::FuncOp funcOp = this->getOperation(); if (fuseGenericOps) { + RewritePatternSet fusionPatterns(context); + auto controlFn = [](OpOperand *operand) { return true; }; + linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn); + (void)applyPatternsAndFoldGreedily(funcOp.getBody(), + std::move(fusionPatterns)); + return; + } + + if (fuseGenericOpsControl) { RewritePatternSet fusionPatterns(context); linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, setFusedOpOperandLimit<4>);