diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
@@ -578,6 +579,138 @@
   }
 };
 
+/// Perform a replacement of one iter OpOperand of an scf.for with the
+/// `replacement` value, which is expected to be the source of a tensor.cast.
+/// tensor.cast ops are inserted inside the block to account for the type cast,
+/// and the matching result of the new loop is cast back to the old type.
+/// Returns the values that replace the results of the original scf.for.
+static SmallVector<Value>
+replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
+                              Value replacement) {
+  Type oldType = operand.get().getType(), newType = replacement.getType();
+  assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
+         "expected ranked tensor types");
+
+  // 1. Create new iter operands, exactly 1 is replaced.
+  ForOp forOp = cast<ForOp>(operand.getOwner());
+  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
+         "expected an iter OpOperand");
+  assert(oldType != newType &&
+         "expected the replacement to have a different type");
+  SmallVector<Value> newIterOperands;
+  for (OpOperand &opOperand : forOp.getIterOpOperands()) {
+    if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
+      newIterOperands.push_back(replacement);
+      continue;
+    }
+    newIterOperands.push_back(opOperand.get());
+  }
+
+  // 2. Create the new forOp shell.
+  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+      forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(),
+      newIterOperands);
+  Block &newBlock = newForOp.region().front();
+  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
+                                             newBlock.getArguments().end());
+
+  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
+  // corresponding to the `replacement` value.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
+  BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
+      newForOp->getOpOperand(operand.getOperandNumber()));
+  Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
+                                                 newRegionIterArg);
+  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
+
+  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
+  Block &oldBlock = forOp.region().front();
+  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
+
+  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
+  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+  rewriter.setInsertionPoint(clonedYieldOp);
+  unsigned yieldIdx =
+      newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
+  Value castOut = rewriter.create<tensor::CastOp>(
+      newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
+  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
+  newYieldOperands[yieldIdx] = castOut;
+  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
+  rewriter.eraseOp(clonedYieldOp);
+
+  // 6. Cast the replaced result back to the old type after the new forOp.
+  rewriter.setInsertionPointAfter(newForOp);
+  SmallVector<Value> newResults = newForOp.getResults();
+  newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
+      newForOp.getLoc(), oldType, newResults[yieldIdx]);
+
+  return newResults;
+}
+
+/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
+/// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
+///
+/// ```
+///   %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
+///   %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
+///      -> (tensor<?x?xf32>) {
+///     %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+///     scf.yield %2 : tensor<?x?xf32>
+///   }
+///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
+///   use_of(%2)
+/// ```
+///
+/// folds into:
+///
+/// ```
+///   %0 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %t0)
+///       -> (tensor<32x1024xf32>) {
+///     %2 = tensor.cast %iter_t0 : tensor<32x1024xf32> to tensor<?x?xf32>
+///     %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+///     %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
+///     scf.yield %4 : tensor<32x1024xf32>
+///   }
+///   use_of(%0)
+/// ```
+struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp op,
+                                PatternRewriter &rewriter) const override {
+    for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
+      OpOperand &iterOpOperand = std::get<0>(it);
+      auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
+      // Skip non-casts and identity casts; the latter do not change the type,
+      // so there is nothing to pull into the loop.
+      if (!incomingCast ||
+          incomingCast.source().getType() == incomingCast.dest().getType())
+        continue;
+      if (!std::get<1>(it).hasOneUse())
+        continue;
+      auto outgoingCastOp =
+          dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
+      if (!outgoingCastOp)
+        continue;
+
+      // Must be a tensor.cast op pair with matching types.
+      if (outgoingCastOp.getResult().getType() !=
+          incomingCast.source().getType())
+        continue;
+
+      // Create a new ForOp with that iter operand replaced. The helper already
+      // casts the corresponding result back to the original type.
+      rewriter.replaceOp(
+          op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
+                                            incomingCast.source()));
+      return success();
+    }
+    return failure();
+  }
+};
+
 /// Canonicalize the iter_args of an scf::ForOp that involve a tensor_load and
 /// for which only the last loop iteration is actually visible outside of the
 /// loop. The canonicalization looks for a pattern such as:
@@ -706,7 +839,7 @@
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
-              LastTensorLoadCanonicalization>(context);
+              LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -580,3 +580,33 @@
 // CHECK: return %[[FOR_RES]] : i32
   return %0#0 : i32
 }
+
+// -----
+
+func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// CHECK-LABEL: matmul_on_tensors
+//  CHECK-SAME:   %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>
+//  CHECK-SAME:   %[[T1:[0-9a-z]*]]: tensor<1024x1024xf32>
+func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
+  %c0 = constant 0 : index
+  %c32 = constant 32 : index
+  %c1024 = constant 1024 : index
+//   CHECK-NOT:   tensor.cast
+//       CHECK:   %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[ITER_T0:.*]] = %[[T0]]) -> (tensor<32x1024xf32>) {
+//       CHECK:     %[[CAST:.*]] = tensor.cast %[[ITER_T0]] : tensor<32x1024xf32> to tensor<?x?xf32>
+//       CHECK:     %[[DONE:.*]] = call @do(%[[CAST]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+//       CHECK:     %[[UNCAST:.*]] = tensor.cast %[[DONE]] : tensor<?x?xf32> to tensor<32x1024xf32>
+//       CHECK:     scf.yield %[[UNCAST]] : tensor<32x1024xf32>
+  %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
+  %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
+    %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+    scf.yield %2 : tensor<?x?xf32>
+  }
+//   CHECK-NOT:   tensor.cast
+//       CHECK:   %[[RES:.*]] = subtensor_insert %[[FOR_RES]] into %[[T1]][0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
+//       CHECK:   return %[[RES]] : tensor<1024x1024xf32>
+  %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
+  %res = subtensor_insert %2 into %t1[0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
+  return %res : tensor<1024x1024xf32>
+}