diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -66,8 +66,7 @@
 /// Function type which is used to control when to stop fusion. It is expected
 /// that OpOperand is not modified in the callback. The OpOperand is not marked
 /// as const to allow callers to use non-const methods.
-using ControlFusionFn =
-    std::function<bool(const OpResult &producer, OpOperand &consumer)>;
+using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
 
 /// Patterns for fusing linalg operation on tensors.
@@ -111,6 +110,20 @@
 /// Patterns that are used to bubble up extract slice op above linalg op.
 void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 
+/// Return true if two `linalg.generic` operations with producer/consumer
+/// relationship through `fusedOperand` can be fused using elementwise op
+/// fusion.
+bool areElementwiseOpsFusable(OpOperand *fusedOperand);
+
+/// Fuse two `linalg.generic` operations that have a producer-consumer
+/// relationship captured through `fusedOperand`. `controlFn` provides an
+/// additional check, applied once the structural requirements for fusing the
+/// producer with the consumer are satisfied. On success, returns the fused
+/// operation.
+FailureOr<Operation *> fuseElementwiseOps(
+    RewriterBase &rewriter, OpOperand *fusedOperand,
+    ControlFusionFn controlFn = [](OpOperand *fusedOperand) { return true; });
+
 /// Split the given `op` into two parts along the given iteration space
 /// `dimension` at the specified `splitPoint`, and return the two parts.
 ///
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -123,9 +123,7 @@
   // Identified this as a potential candidate for folding. Now check the
   // policy to see whether we are allowed to proceed.
   for (int i = 0; i < numInputs; ++i) {
-    OpOperand *consumer = genericOp.getInputOperand(i);
-    OpResult producer = consumer->get().cast<OpResult>();
-    if (!controlFn(producer, *consumer))
+    if (!controlFn(genericOp.getInputOperand(i)))
      return failure();
   }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -65,8 +65,14 @@
 }
 
 /// Conditions for elementwise fusion of generic operations.
-static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
-                                     OpOperand *consumerOpOperand) {
+bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
+  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
+  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
+
+  // Check that producer and consumer are generic ops.
+  if (!producer || !consumer)
+    return false;
+
   // Producer and consumer must have tensor semantics.
   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
     return false;
@@ -78,19 +84,15 @@
 
   // Only allow fusing the producer of an input operand for now.
   // TODO: allow fusing the producer of an output operand.
-  if (!consumer.isInputTensor(consumerOpOperand))
+  if (!consumer.isInputTensor(fusedOperand))
     return false;
 
   // Get the consumer index map. The number of results of the consumer index
   // map must match the number of loops of the producer.
-  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
+  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(fusedOperand);
   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
     return false;
 
-  // Currently support only operations with single result.
-  if (producer.getNumOutputs() != 1)
-    return false;
-
   // Finally the index_map for the result must be invertible. For now just
   // verify it is a permutation.
   AffineMap producerResultIndexMap =
@@ -114,7 +116,7 @@
   for (auto pair :
        llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
     Value operand = std::get<0>(pair);
-    if (operand == consumerOpOperand->get())
+    if (operand == fusedOperand->get())
       continue;
     AffineMap operandMap = std::get<1>(pair);
     addToCoveredDims(operandMap);
@@ -136,12 +138,11 @@
 /// Generate the region of the fused tensor operation. The region of the fused
 /// op must be empty.
 static void
-generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
+generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp,
                                  AffineMap consumerToProducerLoopsMap,
-                                 OpOperand *consumerOpOperand,
-                                 unsigned nloops) {
-  auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
-  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
+                                 OpOperand *fusedOperand, unsigned nloops) {
+  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
+  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
   // Build the region of the fused op.
   Block &producerBlock = producer->getRegion(0).front();
   Block &consumerBlock = consumer->getRegion(0).front();
@@ -172,11 +173,11 @@
     }
   }
   // TODO: allow fusing the producer of an output operand.
-  assert(consumer.isInputTensor(consumerOpOperand) &&
+  assert(consumer.isInputTensor(fusedOperand) &&
         "expected producer of input operand");
   // 3. Consumer input operands up to consumerIdx (exclusive).
   for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
-           consumerOpOperand->getOperandNumber())) // input assumption.
+           fusedOperand->getOperandNumber())) // input assumption.
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
 
   // Replacing consumerIdx requires getting the cloned, yielded, value from
@@ -187,29 +188,22 @@
       producerBlock.getArguments().take_front(producer.getNumInputs()))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
 
-  // 4.b. Producer output operand/map that is fused needs to be mapped to the
-  // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
-  assert(producer->getNumResults() == 1 && "expected single result producer");
-  if (producer.isInitTensor(producer.getOutputOperand(0))) {
-    BlockArgument bbArg = producerBlock.getArguments()
-                              .drop_front(producer.getNumInputs())
-                              // TODO: bbArg index of
-                              .front();
-    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
-  }
   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
   for (BlockArgument bbArg :
        consumerBlock.getArguments()
           .take_front(consumer.getNumInputs())
-          .drop_front(consumerOpOperand->getOperandNumber() + 1))
+          .drop_front(fusedOperand->getOperandNumber() + 1))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
-  // 6. All of consumer's output operands.
+
+  // 6. All of the producer's output operands.
+  for (BlockArgument bbArg :
+       producerBlock.getArguments().take_back(producer.getNumOutputs()))
+    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
+
+  // 7. All of consumer's output operands.
   for (BlockArgument bbArg :
        consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
-  // 7. All of producer's output operands except the one fused.
-  // TODO: allow fusion of multi-result producers.
-  assert(producer->getNumResults() == 1 && "expected single result producer");
 
   // 8. Clone all producer operations except for the yield and index operations
   // to the fused operation.
@@ -219,15 +213,15 @@
   }
   // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
   // forward the yield operand.
-  auto yieldOp = cast<YieldOp>(producerBlock.getTerminator());
-  // TODO: allow fusion of multi-result producers.
-  assert(producer->getNumResults() == 1 && "expected single result producer");
-  unsigned producerResultNumber = 0;
+  auto producerYieldOp = cast<YieldOp>(producerBlock.getTerminator());
+  unsigned producerResultNumber =
+      fusedOperand->get().cast<OpResult>().getResultNumber();
   Value replacement =
-      mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
+      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
+
   // Sanity checks, if replacement is not already in the mapper then it must be
   // produced outside.
-  if (replacement == yieldOp.getOperand(producerResultNumber)) {
+  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
     if (auto bb = replacement.dyn_cast<BlockArgument>())
       assert(bb.getOwner() != &producerBlock &&
              "yielded block argument must have been mapped");
@@ -235,91 +229,110 @@
     assert(!producer->isAncestor(replacement.getDefiningOp()) &&
            "yielded value must have been mapped");
   }
-  mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
+  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
   // 10. Clone operations from the consumer to the fused op.
-  for (auto &op : consumerBlock.getOperations())
+  for (auto &op : consumerBlock.without_terminator())
     rewriter.clone(op, mapper);
 
+  // 11. Include the final yield (which remaps the yielded values of both the
+  // producer and the consumer).
+  auto consumerYieldOp = cast<YieldOp>(consumerBlock.getTerminator());
+  SmallVector<Value> fusedYieldValues;
+  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
+                           consumerYieldOp.getNumOperands());
+  for (auto producerYieldVal : producerYieldOp.getOperands())
+    fusedYieldValues.push_back(mapper.lookupOrDefault(producerYieldVal));
+  for (auto consumerYieldVal : consumerYieldOp.getOperands())
+    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
+  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
+
   // Sanity checks.
   assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
          "Ill-formed GenericOp region");
 }
 
-static Optional<SmallVector<Value>>
-fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
-                       const ControlFusionFn &controlFn,
-                       PatternRewriter &rewriter) {
-  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
-  if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
-      !controlFn(producer->getResult(0), *consumerOpOperand))
-    return llvm::None;
+FailureOr<Operation *>
+mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
+                                 OpOperand *fusedOperand,
+                                 ControlFusionFn controlFn) {
+  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
+  if (!consumer || !areElementwiseOpsFusable(fusedOperand)) {
+    return rewriter.notifyMatchFailure(
+        consumer, "failed elementwise op fusion structural checks");
+  }
+  if (!controlFn(fusedOperand)) {
+    return rewriter.notifyMatchFailure(consumer,
+                                       "fusion blocked by user control fn");
+  }
+
+  auto producerResult = fusedOperand->get().cast<OpResult>();
+  auto producer = cast<GenericOp>(producerResult.getOwner());
 
   // TODO: allow fusing the producer of an output operand.
-  assert(consumer.isInputTensor(consumerOpOperand) &&
+  assert(consumer.isInputTensor(fusedOperand) &&
          "expected producer of input operand");
 
   // Compute the fused operands list and indexing maps.
-  SmallVector<Value> fusedOperands;
+  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
+  SmallVector<Type> fusedResultTypes;
   SmallVector<AffineMap> fusedIndexMaps;
-  fusedOperands.reserve(producer->getNumOperands() +
-                        consumer->getNumOperands());
-  fusedIndexMaps.reserve(producer->getNumOperands() +
-                         consumer->getNumOperands());
+  fusedInputOperands.reserve(producer.getNumInputs() + consumer.getNumInputs());
+  fusedOutputOperands.reserve(producer.getNumOutputs() +
+                              consumer.getNumOutputs());
+  fusedResultTypes.reserve(producer.getNumOutputs() + consumer.getNumOutputs());
+  fusedIndexMaps.reserve(producer.getNumInputsAndOutputs() +
+                         consumer.getNumInputsAndOutputs());
   // In the following, numbering matches that of `generateFusedTensorOpRegion`.
   // 3. Consumer input operands/maps up to consumerIdx (exclusive).
   SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
   SmallVector<OpOperand *>::iterator it =
-      llvm::find(consumerInputs, consumerOpOperand);
+      llvm::find(consumerInputs, fusedOperand);
   assert(it != consumerInputs.end() && "expected to find the consumer operand");
   for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
-    fusedOperands.push_back(opOperand->get());
+    fusedInputOperands.push_back(opOperand->get());
     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
   }
   // 4. Splice in producer's input operands/maps.
-  assert(producer->getNumResults() == 1 && "expected single result producer");
   AffineMap producerResultIndexMap =
-      producer.getTiedIndexingMap(producer.getOutputOperand(0));
+      producer.getTiedIndexingMapForResult(producerResult);
   for (OpOperand *opOperand : producer.getInputOperands()) {
-    fusedOperands.push_back(opOperand->get());
+    fusedInputOperands.push_back(opOperand->get());
     // Compute indexing maps for the producer args in the fused operation.
     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
         opOperand, producerResultIndexMap,
-        consumer.getTiedIndexingMap(consumerOpOperand));
-    fusedIndexMaps.push_back(map);
-  }
-  // 4.b. Producer output operand/map that is fused needs to be passed if it is
-  // an "initTensor" (i.e. its value is actually read).
-  assert(producer->getNumResults() == 1 && "expected single result producer");
-  if (producer.isInitTensor(producer.getOutputOperand(0))) {
-    fusedOperands.push_back(producer.getOutputOperand(0)->get());
-    // Compute indexing maps for the producer args in the fused operation.
-    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
-        producer.getOutputOperand(0), producerResultIndexMap,
-        consumer.getTiedIndexingMap(consumerOpOperand));
+        consumer.getTiedIndexingMap(fusedOperand));
     fusedIndexMaps.push_back(map);
   }
   // 5. Remaining consumer's input operands/maps (drop past index
   // `consumerIdx`).
   for (OpOperand *opOperand :
        llvm::make_range(std::next(it), consumerInputs.end())) {
-    fusedOperands.push_back(opOperand->get());
+    fusedInputOperands.push_back(opOperand->get());
     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
   }
-  // 6. All of consumer's output operands (skip operands: added by the builder).
-  for (OpOperand *opOperand : consumer.getOutputOperands())
+
+  // 6. Collect all of the producer outputs.
+  for (OpOperand *opOperand : producer.getOutputOperands()) {
+    fusedOutputOperands.push_back(opOperand->get());
+    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+        opOperand, producerResultIndexMap,
+        consumer.getTiedIndexingMap(fusedOperand));
+    fusedIndexMaps.push_back(map);
+    fusedResultTypes.push_back(opOperand->get().getType());
+  }
+
+  // 7. All of consumer's output operands (skip operands: added by the builder).
+  for (OpOperand *opOperand : consumer.getOutputOperands()) {
+    fusedOutputOperands.push_back(opOperand->get());
     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
-  // 7. All of producer's output operands/maps except the one fused.
-  // TODO: allow fusion of multi-result producers.
-  assert(producer->getNumResults() == 1 && "expected single result producer");
+    fusedResultTypes.push_back(opOperand->get().getType());
+  }
 
   // Generate the fused op.
-  SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
   auto fusedOp = rewriter.create<GenericOp>(
-      consumer.getLoc(), consumer->getResultTypes(),
-      /*inputs=*/fusedOperands,
-      // TODO: handle outputs.
-      consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
+      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
       /*doc=*/nullptr,
       /*library_call=*/nullptr);
@@ -328,13 +341,13 @@
     // in the input, but going ahead here would result in verification errors.
     // So cleanup and abort.
     rewriter.eraseOp(fusedOp);
-    return llvm::None;
+    return rewriter.notifyMatchFailure(
+        fusedOp, "fused op failed loop bound computation check");
   }
 
   // Construct an AffineMap from consumer loops to producer loops.
   // consumer loop -> tensor index
-  AffineMap consumerResultIndexMap =
-      consumer.getTiedIndexingMap(consumerOpOperand);
+  AffineMap consumerResultIndexMap = consumer.getTiedIndexingMap(fusedOperand);
   // tensor index -> producer loop
   AffineMap invProducerResultIndexMap =
       inversePermutation(producerResultIndexMap);
@@ -345,19 +358,9 @@
       invProducerResultIndexMap.compose(consumerResultIndexMap);
 
   generateFusedElementwiseOpRegion(rewriter, fusedOp,
-                                   consumerToProducerLoopsMap,
-                                   consumerOpOperand, consumer.getNumLoops());
-  return SmallVector<Value>(fusedOp->getResults());
-}
-
-static Optional<SmallVector<Value>>
-fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
-                   GenericOp producer, const ControlFusionFn &controlFn) {
-  if (producer->getNumResults() != 1)
-    return llvm::None;
-
-  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
-                                rewriter);
+                                   consumerToProducerLoopsMap, fusedOperand,
+                                   consumer.getNumLoops());
+  return fusedOp.getOperation();
 }
 
 namespace {
@@ -373,14 +376,12 @@
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      auto producer =
-          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
-      if (!producer || !producer.hasTensorSemantics())
-        continue;
-      Optional<SmallVector<Value>> fusedOpResults =
-          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
-      if (fusedOpResults) {
-        rewriter.replaceOp(genericOp, *fusedOpResults);
+      FailureOr<Operation *> fusedOp =
+          fuseElementwiseOps(rewriter, opOperand, controlFn);
+      if (succeeded(fusedOp)) {
+        auto replacements = fusedOp.getValue()->getResults().take_back(
+            genericOp.getNumResults());
+        rewriter.replaceOp(genericOp, replacements);
         return success();
       }
     }
@@ -713,6 +714,10 @@
         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
       }));
 
+  // Set insertion point to the generic op.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(genericOp);
+
   SmallVector<Value> expandedOpOperands;
   expandedOpOperands.reserve(genericOp.getNumInputs());
   for (OpOperand *opOperand : genericOp.getInputOperands()) {
@@ -792,7 +797,7 @@
   SmallVector<Value> resultVals;
   for (OpResult opResult : genericOp->getOpResults()) {
     int64_t resultNumber = opResult.getResultNumber();
-    if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
+    if (resultTypes[resultNumber] != opResult.getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(
              genericOp.getTiedIndexingMap(
@@ -834,7 +839,7 @@
       // - The tensor reshape op is folding.
       // - All constraints of fusing with reshape by expansion are met.
       if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
-          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
+          (!controlFoldingReshapes(opOperand)))
        continue;
 
       Optional<SmallVector<Value>> replacementValues =
@@ -865,18 +870,49 @@
   LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
     // Fold only if all constraints of fusing with reshape by expansion are met.
-    GenericOp producer = reshapeOp.getSrc().getDefiningOp<GenericOp>();
-    if (!producer || producer.getNumOutputs() != 1 ||
-        !isFusableWithReshapeByDimExpansion(producer,
-                                            producer.getOutputOperand(0)) ||
-        !controlFoldingReshapes(producer->getResult(0),
-                                reshapeOp->getOpOperand(0)))
-      return failure();
+    OpResult producerResult = reshapeOp.getSrc().dyn_cast<OpResult>();
+    if (!producerResult) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "source not produced by an operation");
+    }
+
+    GenericOp producer = dyn_cast<GenericOp>(producerResult.getOwner());
+    if (!producer) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "producer not a generic op");
+    }
+
+    if (!isFusableWithReshapeByDimExpansion(
+            producer,
+            producer.getOutputOperand(producerResult.getResultNumber()))) {
+      return rewriter.notifyMatchFailure(
+          reshapeOp, "failed preconditions of fusion with producer generic op");
+    }
+
+    if (!controlFoldingReshapes(&reshapeOp->getOpOperand(0))) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "fusion blocked by control function");
+    }
+
     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
-        producer, reshapeOp, producer.getOutputOperand(0), rewriter);
+        producer, reshapeOp,
+        producer.getOutputOperand(producerResult.getResultNumber()), rewriter);
     if (!replacementValues)
-      return failure();
-    rewriter.replaceOp(reshapeOp, *replacementValues);
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "fusion by expansion failed");
+
+    // Find the replacement for the reshape op. Since the replacements have the
+    // same type as the returns of the original generic op, the consumer reshape
+    // op can be replaced by the source of the collapse_shape op that defines
+    // the replacement.
+    Value reshapeReplacement = (*replacementValues)
+        [reshapeOp.getSrc().cast<OpResult>().getResultNumber()];
+    if (auto collapseOp =
+            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
+      reshapeReplacement = collapseOp.getSrc();
+    }
+    rewriter.replaceOp(reshapeOp, reshapeReplacement);
+    rewriter.replaceOp(producer, *replacementValues);
     return success();
   }
@@ -1469,7 +1505,7 @@
         getCollapsableIterationSpaceDims(genericOp, opOperand,
                                          reshapeOp.getReassociationIndices());
     if (collapsableIterationDims.empty() ||
-        !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
+        !controlFoldingReshapes(opOperand)) {
      continue;
     }
@@ -1726,9 +1762,9 @@
   RewritePatternSet patterns(context);
 
   // Add folding with reshape by expansion patterns.
-  ControlFusionFn defaultControlFn = [](const OpResult &producer,
-                                        const OpOperand &consumer) {
-    return producer.hasOneUse();
+  ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
+    Operation *producer = fusedOperand->get().getDefiningOp();
+    return producer && producer->hasOneUse();
   };
 
   // Add elementwise op fusion patterns.
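
Reviewer note: a minimal sketch of driving the new standalone entry points from a custom transform, assuming only the `areElementwiseOpsFusable` / `fuseElementwiseOps` declarations this patch adds to Transforms.h (the helper name `fuseFirstFusableOperand` is hypothetical, not part of the patch):

// Hypothetical driver: fuse the producer of the first fusable operand of
// `genericOp`, keeping only single-use producers, then replace the consumer
// with the trailing results of the fused op (the producer's results come
// first, the consumer's results last, per the new region generation above).
static LogicalResult fuseFirstFusableOperand(RewriterBase &rewriter,
                                             linalg::GenericOp genericOp) {
  for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
    // Structural checks first; the control function can still veto fusion.
    if (!linalg::areElementwiseOpsFusable(opOperand))
      continue;
    FailureOr<Operation *> fusedOp = linalg::fuseElementwiseOps(
        rewriter, opOperand, [](OpOperand *fusedOperand) {
          Operation *producer = fusedOperand->get().getDefiningOp();
          return producer && producer->hasOneUse();
        });
    if (failed(fusedOp))
      continue;
    rewriter.replaceOp(genericOp, fusedOp.getValue()->getResults().take_back(
                                      genericOp.getNumResults()));
    return success();
  }
  return failure();
}
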
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -662,12 +662,12 @@
     linalg.yield %r : f64
   } -> tensor<1x8xf64>
 
-  // CHECK-NEXT: %[[R:.*]] = linalg.generic
+  // CHECK-NEXT: %[[R:.*]]:2 = linalg.generic
   // CHECK: bb0(%[[BBA:[0-9a-z]*]]: f64, %[[BBB:[0-9a-z]*]]: i32):
   // CHECK-NEXT:   %[[A:.*]] = func.call @compute1(%[[BBA]]) : (f64) -> f64
   // CHECK-NEXT:   %[[B:.*]] = func.call @compute2(%[[A]], %[[BBB]]) : (f64, i32) -> i32
-  // CHECK-NEXT:   linalg.yield %[[B]] : i32
-  // CHECK-NEXT: } -> tensor<1x8xi32>
+  // CHECK-NEXT:   linalg.yield %[[A]], %[[B]] : f64, i32
+  // CHECK-NEXT: } -> (tensor<1x8xf64>, tensor<1x8xi32>)
   %1 = linalg.generic {
     indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
     iterator_types = ["parallel", "parallel"]}
@@ -678,7 +678,7 @@
     linalg.yield %r : i32
   } -> tensor<1x8xi32>
 
-  // CHECK-NEXT: return %[[R]] : tensor<1x8xi32>
+  // CHECK-NEXT: return %[[R]]#1 : tensor<1x8xi32>
   return %1 : tensor<1x8xi32>
 }
 
@@ -948,7 +948,7 @@
 
 // -----
 
-func.func @illegal_fusion(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tensor<5000xi32> {
+func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tensor<5000xi32> {
   %c1_i32 = arith.constant 1 : i32
   %0 = linalg.generic {
       indexing_maps = [affine_map<(d0) -> (d0)>],
@@ -971,10 +971,25 @@
   } -> tensor<5000xi32>
   return %2 : tensor<5000xi32>
 }
-// CHECK-LABEL: func @illegal_fusion(
-//       CHECK: %[[PRODUCER:.+]] = linalg.generic
-//       CHECK: linalg.generic
-//  CHECK-SAME:   ins(%[[PRODUCER]]
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)>
+//      CHECK: func @fusion_different_axes(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<5000xi64>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<5000xi32>
+//  CHECK-DAG:   %[[INIT0:.+]] = linalg.init_tensor [5000] : tensor<5000xi64>
+//  CHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [5000] : tensor<5000xi32>
+//      CHECK:   %[[RESULT:.+]]:2 = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:       outs(%[[INIT0]], %[[INIT1]] :
+// CHECK-NEXT:   ^bb0(
+// CHECK-SAME:       %[[B0:.+]]: i64
+// CHECK-SAME:       %[[B1:.+]]: i32
+//  CHECK-DAG:     %[[T0:.+]] = linalg.index 0
+//  CHECK-DAG:     %[[CAST1:.+]] = arith.index_cast %[[T0]] : index to i64
+//  CHECK-DAG:     %[[CAST2:.+]] = arith.index_cast %[[CAST1]] : i64 to index
+//      CHECK:     %[[EXTRACT:.+]] = tensor.extract %[[ARG1]][%[[CAST2]]]
+//      CHECK:     linalg.yield %[[CAST1]], %[[EXTRACT]]
+//      CHECK:   return %[[RESULT]]#1
 
 // -----
 
@@ -995,7 +1010,7 @@
   %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs (%3:tensor<?xf32>) {
     ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
       %5 = arith.addf %arg1, %arg2 : f32
-      linalg.yield %5 : f32
+      linalg.yield %5 : f32
  } -> tensor<?xf32>
   return %4 : tensor<?xf32>
 }
@@ -1024,7 +1039,7 @@
   %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%6:tensor<?x?xf32>) {
     ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
      %8 = arith.divf %arg1, %arg2 : f32
-      linalg.yield %8 : f32
+      linalg.yield %8 : f32
   } -> tensor<?x?xf32>
   return %7 : tensor<?x?xf32>
 }
+
+// -----
+
+#map = affine_map<() -> ()>
+module {
+  func.func @fuse_multi_result_producer(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> tensor<f32> {
+    %0 = linalg.init_tensor [] : tensor<f32>
+    %1 = linalg.init_tensor [] : tensor<f32>
+    %2:2 = linalg.generic {
+      indexing_maps = [#map, #map, #map, #map, #map], iterator_types = []}
+      ins(%arg0, %arg1, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>) outs(%0, %1 : tensor<f32>, tensor<f32>) {
+    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32, %arg9: f32):
+      %4 = arith.addf %arg5, %arg6 : f32
+      %5 = arith.addf %4, %arg7 : f32
+      linalg.yield %4, %5 : f32, f32
+    } -> (tensor<f32>, tensor<f32>)
+    %3 = linalg.generic {
+      indexing_maps = [#map, #map, #map], iterator_types = []}
+      ins(%2#1, %arg1 : tensor<f32>, tensor<f32>) outs(%arg4 : tensor<f32>) {
+    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
+      %4 = arith.addf %arg5, %arg6 : f32
+      %5 = arith.addf %4, %arg6 : f32
+      linalg.yield %5 : f32
+    } -> tensor<f32>
+    return %3 : tensor<f32>
+  }
+}
+// CHECK-LABEL: func.func @fuse_multi_result_producer
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<f32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<f32>
+//       CHECK:   %[[INIT:.+]] = linalg.init_tensor
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
+//  CHECK-SAME:       outs(%[[INIT]] :
+//  CHECK-NEXT:   ^bb0
+//  CHECK-SAME:       %[[B0:[a-zA-Z0-9]+]]: f32
+//  CHECK-SAME:       %[[B1:[a-zA-Z0-9]+]]: f32
+//   CHECK-DAG:     %[[T0:.+]] = arith.addf %[[B0]], %[[B1]]
+//   CHECK-DAG:     %[[T1:.+]] = arith.addf %[[T0]], %[[B1]]
+//   CHECK-DAG:     %[[T2:.+]] = arith.addf %[[T1]], %[[B1]]
+//   CHECK-DAG:     %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
+//       CHECK:     linalg.yield %[[T3]] : f32
+//       CHECK:   return %[[GENERIC]]
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file -canonicalize | FileCheck %s --check-prefix=CANONICALIZE
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #binary2Dpointwise = {
@@ -58,5 +59,17 @@
 // CHECK-SAME: %[[ARG5:[a-zA-z0-9_]+]]: tensor<?x?xf32>
 //      CHECK: %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
 //      CHECK: %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
-//      CHECK: %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
-//      CHECK: return %[[OP3]]
+//      CHECK: %[[OP3:.+]]:2 = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
+//      CHECK: return %[[OP3]]#1
+
+// CANONICALIZE-LABEL: func @test_fusion_limit
+//  CANONICALIZE-SAME: %[[ARG0:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+//  CANONICALIZE-SAME: %[[ARG1:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+//  CANONICALIZE-SAME: %[[ARG2:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+//  CANONICALIZE-SAME: %[[ARG3:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+//  CANONICALIZE-SAME: %[[ARG4:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+//  CANONICALIZE-SAME: %[[ARG5:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+//       CANONICALIZE: %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
+//       CANONICALIZE: %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
+//       CANONICALIZE: %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
+//       CANONICALIZE: return %[[OP3]]
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -499,3 +499,43 @@
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:     ins(%[[RESHAPE]], %[[ARG1]] : tensor<2xi64>, tensor<?xi64>)
 //      CHECK:   return %[[GENERIC]]
+
+// -----
+
+func.func @reshape_as_consumer_permutation_with_multiple_results
+  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>)
+    -> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
+  %c:2 = linalg.generic {
+         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
+                          affine_map<(d0, d1, d2) -> (d1, d2)>,
+                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
+                          affine_map<(d0, d1, d2) -> (d2, d0, d1)>],
+         iterator_types = ["parallel", "parallel", "parallel"]}
+          ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
+         outs(%a, %a : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+       ^bb0(%arg0 : f32, %arg1: f32, %s: f32, %t : f32):
+         %1 = arith.addf %arg0, %arg1 : f32
+         linalg.yield %1, %1 : f32, f32
+       } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  %d = tensor.expand_shape %c#0 [[0, 1], [2], [3, 4, 5]]
+       : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+  %e = tensor.expand_shape %c#1 [[0], [1, 2], [3, 4, 5]]
+       : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
+  return %d, %e : tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5, d0, d1, d2, d3, d4)>
+//       CHECK: func @reshape_as_consumer_permutation_with_multiple_results
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3, 4], [5]{{\]}}
+//   CHECK-DAG:   %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1, 2], [3]{{\]}}
+//   CHECK-DAG:   %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3, 4, 5]{{\]}}
+//   CHECK-DAG:   %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3, 4, 5]{{\]}}
+//       CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
+//  CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+//  CHECK-SAME:       ins(%[[RESHAPE0]], %[[RESHAPE1]] :
+//  CHECK-SAME:       outs(%[[RESHAPE2]], %[[RESHAPE3]] :
+//       CHECK:   return %[[GENERIC]]#0, %[[GENERIC]]#1
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -35,14 +35,18 @@
 }
 
 template <int limit>
-static bool setFusedOpOperandLimit(const OpResult &producer,
-                                   const OpOperand &consumer) {
+static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
+  Operation *producer = fusedOperand->get().getDefiningOp();
+  if (!producer)
+    return false;
+
+  Operation *consumer = fusedOperand->getOwner();
   SetVector<Value> fusedOpOperands;
-  if (producer.getOwner()->getNumResults() != 1)
+  if (producer->getNumResults() != 1)
     return false;
-  addOperands(consumer.getOwner(), fusedOpOperands);
-  fusedOpOperands.remove(producer);
-  addOperands(producer.getOwner(), fusedOpOperands);
+  addOperands(consumer, fusedOpOperands);
+  fusedOpOperands.remove(producer->getResult(0));
+  addOperands(producer, fusedOpOperands);
   return fusedOpOperands.size() <= limit;
 }
 
@@ -113,8 +117,7 @@
     if (fuseWithReshapeByExpansion) {
      RewritePatternSet fusionPatterns(context);
      linalg::populateFoldReshapeOpsByExpansionPatterns(
-          fusionPatterns, [](const OpResult & /*producer*/,
-                             OpOperand & /*consumer*/) { return true; });
+          fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; });
       if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
                                               std::move(fusionPatterns))))
         return signalPassFailure();
@@ -125,15 +128,19 @@
      RewritePatternSet fusionPatterns(context);
 
       linalg::ControlFusionFn controlReshapeFusionFn =
-          [](const OpResult &producer, OpOperand &consumer) {
-            if (auto collapseOp =
-                    producer.getDefiningOp<tensor::CollapseShapeOp>()) {
+          [](OpOperand *fusedOperand) {
+            auto producer = fusedOperand->get().getDefiningOp();
+            if (!producer)
+              return false;
+
+            if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer)) {
               if (!collapseOp.getSrc().getDefiningOp<linalg::LinalgOp>()) {
                return false;
              }
            }
-            if (auto expandOp =
-                    dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
+
+            auto consumer = fusedOperand->getOwner();
+            if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
               if (expandOp->hasOneUse()) {
                 OpOperand &use = *expandOp->getUses().begin();
                 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
@@ -155,18 +162,17 @@
     if (fuseWithReshapeByCollapsing) {
       RewritePatternSet patterns(context);
       linalg::populateFoldReshapeOpsByCollapsingPatterns(
-          patterns, [](const OpResult & /*producer*/,
-                       OpOperand & /*consumer*/) { return true; });
+          patterns, [](OpOperand * /*fusedOperand */) { return true; });
       (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
     }
 
     if (fuseWithReshapeByCollapsingWithControlFn) {
       RewritePatternSet patterns(context);
-      linalg::ControlFusionFn controlFn = [](const OpResult &producer,
-                                             OpOperand &consumer) -> bool {
-        if (isa<tensor::CollapseShapeOp>(producer.getDefiningOp())) {
+      linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) -> bool {
+        auto producer = fusedOperand->get().getDefiningOp();
+        if (isa<tensor::CollapseShapeOp>(producer)) {
           // Skip fusing the first operand.
-          return consumer.getOperandNumber();
+          return fusedOperand->getOperandNumber();
         }
         return true;
       };
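
Migration note for downstream users: control callbacks written against the old `(const OpResult &, OpOperand &)` signature need a mechanical rewrite, since both endpoints are now recovered from the single fused operand. A minimal sketch, assuming the `populateElementwiseOpsFusionPatterns` entry point already declared in Transforms.h and an `MLIRContext *context` in scope:

// Migrating a control function to the OpOperand*-based ControlFusionFn.
// Old form: [](const OpResult &producer, OpOperand &consumer) {...}.
RewritePatternSet patterns(context);
linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
  // The producer is the defining op of the operand's value; guard against
  // block arguments, which have no defining op.
  Operation *producer = fusedOperand->get().getDefiningOp();
  // The consumer is simply the operand's owner, should the callback need it.
  Operation *consumer = fusedOperand->getOwner();
  (void)consumer;
  return producer && producer->hasOneUse();
};
linalg::populateElementwiseOpsFusionPatterns(patterns, controlFn);
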