diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -37,6 +37,11 @@ if (producer.getNumParallelLoops() != producer.getNumLoops()) return false; + // Only allow fusing the producer of an input operand for now. + // TODO: allow fusing the producer of an output operand. + if (consumerIdx >= consumer.getNumInputs()) + return false; + // Get the consumer index map. The number of results of the consumer index // map must match the number of loops of the producer. AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx); @@ -120,60 +125,88 @@ isa(consumer.getOperation())) ? std::max(producer.getNumLoops(), consumer.getNumLoops()) : 0; - // Firstly, add all the indices to the block arguments. + + // 0. Firstly, add all the indices to the block arguments. for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i) fusedBlock->addArgument(rewriter.getIndexType()); - // Map the arguments for the unmodified args from the consumer. - for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) { - if (consumerArg.index() == consumerIdx + numConsumerIndices) { - // Map the arguments for the args from the producer. - for (auto producerArg : - llvm::enumerate(producerBlock.getArguments().take_front( - producer.getNumInputs() + numProducerIndices))) { - // If producer is an indexed_generic op, map the indices from consumer - // loop to producer loop (because the fusedOp is built based on - // consumer's perspective). - if (producerArg.index() < numProducerIndices) { - auto newIndex = rewriter.create( - producer.getLoc(), - consumerToProducerLoopsMap.getSubMap(producerArg.index()), - fusedBlock->getArguments().take_front(numFusedOpIndices)); - mapper.map(producerArg.value(), newIndex); - } else { - mapper.map(producerArg.value(), - fusedBlock->addArgument(producerArg.value().getType())); - } - } - continue; - } - - // If consumer is an indexed_generic op, map the indices to the block - // arguments directly. Otherwise, add the same type of argument and map to - // it. - if (consumerArg.index() < numConsumerIndices) { - mapper.map(consumerArg.value(), - fusedBlock->getArgument(consumerArg.index())); - } else { - mapper.map(consumerArg.value(), - fusedBlock->addArgument(consumerArg.value().getType())); - } + // 1. Map consumer indices to fusedBlock indices 1-1. + mapper.map(consumerBlock.getArguments().take_front(numConsumerIndices), + fusedBlock->getArguments().take_front(numConsumerIndices)); + // 2. Embed producer indices into fusedBlock index space 1-1. + for (auto it : + llvm::zip(producerBlock.getArguments().take_front(numProducerIndices), + fusedBlock->getArguments().take_front(numProducerIndices))) { + auto newIndex = rewriter.create( + producer.getLoc(), + consumerToProducerLoopsMap.getSubMap(std::get<0>(it).getArgNumber()), + fusedBlock->getArguments().take_front(numFusedOpIndices)); + mapper.map(std::get<0>(it), newIndex); } - - // Add operations from producer (except the yield operation) to the fused + // TODO: allow fusing the producer of an output operand. + assert(consumerIdx < consumer.getNumInputs() && + "expected producer of input operand"); + // 3. Consumer input operands up to consumerIdx (exclusive). + for (BlockArgument bbArg : consumerBlock.getArguments() + .drop_front(numConsumerIndices) + .take_front(consumer.getNumInputs()) + .take_front(consumerIdx)) + mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); + // 4. Replacing consumerIdx requires getting the cloned, yielded, value from + // the (cloned) producer block. + // 5. Splice in producer's input operands. + for (BlockArgument bbArg : producerBlock.getArguments() + .drop_front(numProducerIndices) + .take_front(producer.getNumInputs())) + mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); + // 6. Remaining consumer's input operands (drop past index `consumerIdx`). + for (BlockArgument bbArg : consumerBlock.getArguments() + .drop_front(numConsumerIndices) + .take_front(consumer.getNumInputs()) + .drop_front(consumerIdx + 1)) + mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); + // 7. All of consumer's output operands. + for (BlockArgument bbArg : + consumerBlock.getArguments().take_back(consumer.getNumOutputs())) + mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); + // 8. All of producer's output operands except the one fused. + // TODO: allow fusion of multi-result producers. + assert(producer->getNumResults() == 1 && "expected single result producer"); + // for (BlockArgument bbArg : + // producerBlock.getArguments().take_back(producer.getNumOutputs())) + // mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); + + // 9. Clone operations from producer (except the yield operation) to the fused // op. - for (auto &op : producerBlock.getOperations()) { - if (auto yieldOp = dyn_cast(op)) { - // Lookup the value the yield operation is mapped to. - Value yieldVal = yieldOp.getOperand(0); - if (Value clonedVal = mapper.lookupOrNull(yieldVal)) - mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices), - clonedVal); - continue; - } + for (auto &op : producerBlock.without_terminator()) rewriter.clone(op, mapper); + // 10. Now we can map the consumerBlock's `consumerIdx` block argument. Just + // forward the yield operand. + auto yieldOp = cast(producerBlock.getTerminator()); + // TODO: allow fusion of multi-result producers. + assert(producer->getNumResults() == 1 && "expected single result producer"); + unsigned producerResultNumber = 0; + Value replacement = + mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber)); + // Sanity checks, if replacement is not already in the mapper then it must be + // produced outside. + if (replacement == yieldOp.getOperand(producerResultNumber)) { + if (auto bb = replacement.dyn_cast()) + assert(bb.getOwner() != &producerBlock && + "yielded block argument must have been mapped"); + else + assert(!producer->isAncestor(replacement.getDefiningOp()) && + "yielded value must have been mapped"); } + mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices), + replacement); + // 11. Clone operations from the consumer to the fused op. for (auto &op : consumerBlock.getOperations()) rewriter.clone(op, mapper); + + // Sanity checks. + assert(fusedBlock->getNumArguments() == + fusedOp->getNumOperands() + numFusedOpIndices && + "Ill-formed LinalgOp region"); } static Optional> @@ -856,8 +889,6 @@ op->setOperands(fusedOperands); op.indexing_mapsAttr(rewriter.getAffineMapArrayAttr(fusedIndexMaps)); rewriter.finalizeRootUpdate(op); - if (reshapeOp.use_empty()) - rewriter.eraseOp(reshapeOp); return success(); } return failure(); @@ -897,8 +928,6 @@ if (!replacementValues) return failure(); rewriter.replaceOp(genericOp, replacementValues.getValue()); - if (reshapeOp.use_empty()) - rewriter.eraseOp(reshapeOp); return success(); } return failure(); @@ -963,8 +992,6 @@ rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion, fusedRegion.begin()); rewriter.replaceOp(reshapeOp, fusedOp->getResults()); - if (producer.use_empty()) - rewriter.eraseOp(producer); return success(); } }; @@ -995,8 +1022,6 @@ if (!replacementValues) return failure(); rewriter.replaceOp(reshapeOp, replacementValues.getValue()); - if (producer.use_empty()) - rewriter.eraseOp(producer); return success(); } }; @@ -1057,8 +1082,6 @@ rewriter.cloneRegionBefore(linalgOpRegion, fusedRegion, fusedRegion.begin(), mapping); rewriter.replaceOp(linalgOp, fusedOp->getResults()); - if (constantOp.use_empty()) - rewriter.eraseOp(constantOp); return success(); } return failure(); @@ -1092,15 +1115,14 @@ PatternRewriter &rewriter) const override { // Find the first operand that is defined by another generic op on tensors. for (OpOperand &opOperand : op.getShapedOpOperands()) { - Operation *producer = opOperand.get().getDefiningOp(); - if (!producer) + LinalgOp producerOp = + dyn_cast_or_null(opOperand.get().getDefiningOp()); + if (!producerOp || !producerOp.hasTensorSemantics()) continue; Optional> fusedOpResults = fuseTensorOps(rewriter, opOperand); if (fusedOpResults) { rewriter.replaceOp(op, *fusedOpResults); - if (producer->use_empty()) - rewriter.eraseOp(producer); return success(); } } @@ -1115,8 +1137,7 @@ Operation *op = getOperation(); OwningRewritePatternList patterns(op->getContext()); populateLinalgTensorOpsFusionPatterns(patterns); - (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns), - /*useTopDown=*/false); + (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); } }; diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir @@ -578,3 +578,41 @@ // CHECK: %[[T4:.+]] = addf %[[T3]], %[[T2]] : f32 // CHECK: linalg.yield %[[T4]] // CHECK: return %[[RES]] + +// ----- + +// CHECK-LABEL: func @sigmoid_dynamic_dim( +// CHECK: %[[RES:.*]] = linalg.generic +// CHECK-NOT: linalg.generic +// CHECK: return %[[RES]] +func @sigmoid_dynamic_dim(%0: tensor) -> tensor { + %cp5 = constant 5.000000e-01 : f32 + %c0 = constant 0 : index + %shape = shape.shape_of %0 : tensor -> tensor + %extend = shape.to_extent_tensor %shape : tensor -> tensor<2xindex> + %extracted = tensor.extract %extend[%c0] : tensor<2xindex> + %init0 = linalg.init_tensor [%extracted, 1] : tensor + %1 = linalg.generic {indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } + outs(%init0 : tensor) { + ^bb0(%a: f32): // no predecessors + linalg.yield %cp5 : f32 + } -> tensor + %d0 = memref.dim %0, %c0 : tensor + %init1 = linalg.init_tensor [%d0, 1] : tensor + %2 = linalg.generic {indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } + ins(%0, %1 : tensor, tensor) + outs(%init1 : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): // no predecessors + %m = mulf %a, %b : f32 + linalg.yield %m : f32 + } -> tensor + return %2 : tensor +}