diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -1594,11 +1595,91 @@
   }
 };
 
+struct FoldTensorCastOfOutputIntoForallOp
+    : public OpRewritePattern<scf::ForallOp> {
+  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
+
+  struct TypeCast {
+    Type srcType;
+    Type dstType;
+  };
+
+  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
+                                PatternRewriter &rewriter) const final {
+    llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
+    llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
+    for (auto en : llvm::enumerate(newOutputTensors)) {
+      auto castOp = en.value().getDefiningOp<tensor::CastOp>();
+      if (!castOp)
+        continue;
+
+      // Only casts that preserve static information, i.e. that will make the
+      // loop result type "more" static than before, will be folded.
+      if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
+                                              castOp.getSource().getType())) {
+        continue;
+      }
+
+      tensorCastProducers[en.index()] =
+          TypeCast{castOp.getSource().getType(), castOp.getType()};
+      newOutputTensors[en.index()] = castOp.getSource();
+    }
+
+    if (tensorCastProducers.empty())
+      return failure();
+
+    // Create new loop.
+    Location loc = forallOp.getLoc();
+    auto newForallOp = rewriter.create<ForallOp>(
+        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+        forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
+        [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
+          auto castBlockArgs =
+              llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
+          for (auto [index, cast] : tensorCastProducers) {
+            Value &oldTypeBBArg = castBlockArgs[index];
+            oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
+                nestedLoc, cast.dstType, oldTypeBBArg);
+          }
+
+          // Move old body into new parallel loop.
+          SmallVector<Value> ivsBlockArgs =
+              llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
+          ivsBlockArgs.append(castBlockArgs);
+          rewriter.mergeBlocks(forallOp.getBody(),
+                               bbArgs.front().getParentBlock(), ivsBlockArgs);
+        });
+
+    // After `mergeBlocks` happened, the destinations in the terminator were
+    // mapped to the tensor.cast old-typed results of the output bbArgs. The
+    // destinations have to be updated to point to the output bbArgs directly.
+    auto terminator = newForallOp.getTerminator();
+    for (auto [yieldingOp, outputBlockArg] :
+         llvm::zip(terminator.getYieldingOps(),
+                   newForallOp.getOutputBlockArguments())) {
+      auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+      insertSliceOp.getDestMutable().assign(outputBlockArg);
+    }
+
+    // Cast results back to the original types.
+    rewriter.setInsertionPointAfter(newForallOp);
+    SmallVector<Value> castResults = newForallOp.getResults();
+    for (auto &item : tensorCastProducers) {
+      Value &oldTypeResult = castResults[item.first];
+      oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
+                                                      oldTypeResult);
+    }
+    rewriter.replaceOp(forallOp, castResults);
+    return success();
+  }
+};
+
 } // namespace
 
 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<DimOfForallOp, ForallOpControlOperandsFolder,
+  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
+              ForallOpControlOperandsFolder,
               ForallOpSingleOrZeroIterationDimsFolder>(context);
 }
 
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1651,3 +1651,52 @@
 // CHECK: %[[EMPTY:.*]] = tensor.empty
 // CHECK: return %[[EMPTY]]
 
+// -----
+
+func.func @fold_tensor_cast_into_forall(
+    %in: tensor<2xi32>, %out: tensor<2xi32>) -> tensor<2xi32> {
+  %cst = arith.constant dense<[100500]> : tensor<1xi32>
+
+
+  %out_cast = tensor.cast %out : tensor<2xi32> to tensor<?xi32>
+  %result = scf.forall (%i) = (0) to (2) step (1)
+      shared_outs (%out_ = %out_cast) -> tensor<?xi32> {
+
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %cst into %out_[%i] [1] [1]
+        : tensor<1xi32> into tensor<?xi32>
+    }
+  }
+  %result_cast = tensor.cast %result : tensor<?xi32> to tensor<2xi32>
+  func.return %result_cast : tensor<2xi32>
+}
+// CHECK-LABEL: @fold_tensor_cast_into_forall
+// CHECK-NOT: tensor.cast
+// CHECK: parallel_insert_slice
+// CHECK-SAME: : tensor<1xi32> into tensor<2xi32>
+// CHECK-NOT: tensor.cast
+
+// -----
+
+func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
+    %in: tensor<?xi32>, %out: tensor<?xi32>) -> tensor<?xi32> {
+  %cst = arith.constant dense<[100500]> : tensor<1xi32>
+
+
+  %out_cast = tensor.cast %out : tensor<?xi32> to tensor<2xi32>
+  %result = scf.forall (%i) = (0) to (2) step (1)
+      shared_outs (%out_ = %out_cast) -> tensor<2xi32> {
+
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %cst into %out_[%i] [1] [1]
+        : tensor<1xi32> into tensor<2xi32>
+    }
+  }
+  %result_cast = tensor.cast %result : tensor<2xi32> to tensor<?xi32>
+  func.return %result_cast : tensor<?xi32>
+}
+// CHECK-LABEL: @do_not_fold_tensor_cast_
+// CHECK: tensor.cast
+// CHECK: parallel_insert_slice
+// CHECK-SAME: : tensor<1xi32> into tensor<2xi32>
+// CHECK: tensor.cast