diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
@@ -45,6 +45,7 @@
   }];

+  let hasCanonicalizer = 1;
   let hasConstantMaterializer = 1;
   let dependentDialects = [
     "AffineDialect",
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1745,61 +1745,6 @@
   }
 };

-struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
-  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
-  LogicalResult matchAndRewrite(LinalgOp op,
-                                PatternRewriter &rewriter) const override {
-    // If no operand comes from a tensor::CastOp and can be folded then fail.
-    bool hasTensorCastOperand =
-        llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
-          if (opOperand.get().isa<BlockArgument>())
-            return false;
-          auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-          return castOp && canFoldIntoConsumerOp(castOp);
-        });
-    if (!hasTensorCastOperand)
-      return failure();
-
-    SmallVector<Type> newResultTypes;
-    newResultTypes.reserve(op->getNumResults());
-    SmallVector<Value> newOperands;
-    newOperands.reserve(op->getNumOperands());
-    // Inputs may fold.
-    for (auto *input : op.getDpsInputOperands()) {
-      auto tensorCastOp = input->get().getDefiningOp<tensor::CastOp>();
-      newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
-                                ? tensorCastOp.getSource()
-                                : input->get());
-    }
-    // Init tensors may fold, in which case the resultType must also change.
-    for (auto *output : op.getDpsInitOperands()) {
-      auto tensorCastOp = output->get().getDefiningOp<tensor::CastOp>();
-      bool fold = canFoldIntoConsumerOp(tensorCastOp);
-      newOperands.push_back(fold ? tensorCastOp.getOperand() : output->get());
-      if (!newOperands.back().getType().isa<MemRefType>())
-        newResultTypes.push_back(newOperands.back().getType());
-    }
-    // Clone op.
-    Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
-    SmallVector<Value> replacements;
-    replacements.reserve(newOp->getNumResults());
-    for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
-      Value oldResult = std::get<0>(result);
-      Value newResult = std::get<1>(result);
-      if (newResult.getType() != oldResult.getType()) {
-        replacements.push_back(rewriter.create<tensor::CastOp>(
-            op->getLoc(), oldResult.getType(), newResult));
-      } else {
-        replacements.push_back(newResult);
-      }
-    }
-    rewriter.replaceOp(op, replacements);
-
-    return success();
-  }
-};
-
 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
 /// result that is more static than the linalg op.
 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
@@ -2023,8 +1968,7 @@
 void LinalgDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
   results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
-              FoldTensorCastProducerOp, InferStaticShapeOfOperands>(
-      getContext());
+              InferStaticShapeOfOperands>(getContext());
 }

 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
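Note (illustrative, not part of the patch): the pattern deleted above is not lost; it reappears below in TensorOps.cpp, generalized from `LinalgOp` to `DestinationStyleOpInterface`. As a reminder of its effect, a minimal before/after sketch on `linalg.matmul`; the op, value names, and shapes are assumptions chosen for illustration:

```mlir
// Before: casts to dynamic tensors hide static shape information.
%in   = tensor.cast %arg0 : tensor<8x16xf32> to tensor<?x?xf32>
%init = tensor.cast %arg2 : tensor<8x4xf32> to tensor<?x4xf32>
%res  = linalg.matmul ins(%in, %arg1 : tensor<?x?xf32>, tensor<16x4xf32>)
                      outs(%init : tensor<?x4xf32>) -> tensor<?x4xf32>

// After: the op is cloned with the static operand types; a single cast
// restores the original, less static result type for remaining users.
%new = linalg.matmul ins(%arg0, %arg1 : tensor<8x16xf32>, tensor<16x4xf32>)
                     outs(%arg2 : tensor<8x4xf32>) -> tensor<8x4xf32>
%res = tensor.cast %new : tensor<8x4xf32> to tensor<?x4xf32>
```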
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3433,6 +3433,89 @@
   return success();
 }

+//===----------------------------------------------------------------------===//
+// Common Canonicalizers and Folders.
+//===----------------------------------------------------------------------===//
+
+/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
+/// the `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
+/// can add the pattern to their canonicalizers.
+struct FoldTensorCastProducerOp
+    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
+  using OpInterfaceRewritePattern<
+      DestinationStyleOpInterface>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    // InsertSliceOp has its own logic about folding tensor.cast ops.
+    if (isa<InsertSliceOp>(op.getOperation()))
+      return failure();
+
+    // If no operand comes from a tensor::CastOp and can be folded then fail.
+    bool hasTensorCastOperand =
+        llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+          if (opOperand.get().isa<BlockArgument>())
+            return false;
+          auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+          return castOp && canFoldIntoConsumerOp(castOp);
+        });
+    if (!hasTensorCastOperand)
+      return failure();
+
+    SmallVector<Type> newResultTypes;
+    newResultTypes.reserve(op->getNumResults());
+    SmallVector<Value> newOperands;
+    newOperands.reserve(op->getNumOperands());
+    for (OpOperand &opOperand : op->getOpOperands()) {
+      auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+      bool fold = canFoldIntoConsumerOp(tensorCastOp);
+      newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
+      if (op.isDpsInit(&opOperand) &&
+          !newOperands.back().getType().isa<MemRefType>())
+        newResultTypes.push_back(newOperands.back().getType());
+    }
+
+    // Clone op.
+    Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
+    SmallVector<Value> replacements;
+    replacements.reserve(newOp->getNumResults());
+    for (auto [oldResult, newResult] :
+         llvm::zip(op->getResults(), newOp->getResults())) {
+      if (newResult.getType() != oldResult.getType()) {
+        replacements.push_back(rewriter.create<tensor::CastOp>(
+            op->getLoc(), oldResult.getType(), newResult));
+      } else {
+        replacements.push_back(newResult);
+      }
+    }
+    rewriter.replaceOp(op, replacements);
+
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// TensorDialect
+//===----------------------------------------------------------------------===//
+
+void TensorDialect::getCanonicalizationPatterns(
+    RewritePatternSet &results) const {
+  results.add<FoldTensorCastProducerOp>(getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
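Note (illustrative, not part of the patch): because the pattern is now registered on the tensor dialect's canonicalizer and matches any `DestinationStyleOpInterface` op, ops such as `tensor.pack` pick it up without Linalg involvement. A minimal sketch of the fold, mirroring the test updates below; the value names (`%sub_in`, `%sub_out`) and shapes are assumptions:

```mlir
// Before canonicalization: tiling-style casts obscure the static types.
%c_in  = tensor.cast %sub_in  : tensor<128x16xf32> to tensor<?x?xf32>
%c_out = tensor.cast %sub_out : tensor<2x4x32x8xf32> to tensor<?x?x32x8xf32>
%p = tensor.pack %c_in outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [32, 8] into %c_out : tensor<?x?xf32> -> tensor<?x?x32x8xf32>
%r = tensor.cast %p : tensor<?x?x32x8xf32> to tensor<2x4x32x8xf32>

// After: the casts fold away and the pack uses the static slices directly.
%r = tensor.pack %sub_in outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [32, 8] into %sub_out
    : tensor<128x16xf32> -> tensor<2x4x32x8xf32>
```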
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -18,11 +18,9 @@
 // CHECK-DAG: %[[IN_C_SZ:.*]] = affine.min #[[MAP2]]
 // CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_N]], %[[IN_C]]] [%[[IN_N_SZ]], %[[IN_C_SZ]]] [1, 1] : tensor<128x256xf32> to tensor<?x?xf32>
 // CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[N]], %[[C]], 0, 0] [2, 4, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<2x4x32x32xf32>
-// CHECK: %[[CAST_OUT:.*]] = tensor.cast %[[SUB_OUT]]
 // CHECK: %[[SUB_RES:.*]] = tensor.pack
-// CHECK-SAME: %[[SUB_IN]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[CAST_OUT]]
-// CHECK: %[[CAST_RES:.*]] = tensor.cast %[[SUB_RES]]
-// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[CAST_RES]] into %[[ITER1]]
+// CHECK-SAME: %[[SUB_IN]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[SUB_OUT]]
+// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[SUB_RES]] into %[[ITER1]]
 // CHECK: scf.yield %[[INSERT]] : tensor<4x8x32x32xf32>
 // CHECK: }
 // CHECK: scf.yield %[[RES1:.*]] : tensor<4x8x32x32xf32>
@@ -55,12 +53,10 @@
 // CHECK-DAG: %[[IN_C_SZ:.+]] = affine.min #[[MAP1]](%[[C]])
 // CHECK: %[[INPUT_SLICE:.+]] = tensor.extract_slice %[[IN]]
 // CHECK-SAME: [0, %[[IN_C]]] [128, %[[IN_C_SZ]]]
-// CHECK: %[[CAST_IN:.+]] = tensor.cast %[[INPUT_SLICE]]
 // CHECK: %[[OUTPUT_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[C]], 0, 0, 0] [2, 4, 32, 8]
-// CHECK: %[[CAST_OUT:.+]] = tensor.cast %[[OUTPUT_SLICE]]
 // CHECK: tensor.pack
-// CHECK-SAME: %[[CAST_IN]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
-// CHECK-SAME: into %[[CAST_OUT]]
+// CHECK-SAME: %[[INPUT_SLICE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
+// CHECK-SAME: into %[[OUTPUT_SLICE]]
 func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>) -> tensor<32x4x32x8xf32> {
   %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<128x256xf32> -> tensor<32x4x32x8xf32>
   return %0 : tensor<32x4x32x8xf32>
 }
@@ -87,14 +83,11 @@
 // CHECK-DAG: %[[IN_J:.*]] = affine.apply #[[MAP0]](%[[J]])
 // CHECK-DAG: %[[IN_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])
 // CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][0, %[[IN_J]]] [13, %[[IN_J_SZ]]] [1, 1]
-// CHECK: %[[CAST_IN:.*]] = tensor.cast %[[SUB_IN]]
 // CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][0, %[[J]], 0, 0] [2, 4, 8, 2] [1, 1, 1, 1]
-// CHECK: %[[CAST_OUT:.*]] = tensor.cast %[[SUB_OUT]]
 // CHECK: %[[SUB_RES:.*]] = tensor.pack
-// CHECK-SAME: %[[CAST_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2]
-// CHECK-SAME: into %[[CAST_OUT]]
-// CHECK: %[[CAST_RES:.*]] = tensor.cast %[[SUB_RES]]
-// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[CAST_RES]] into %[[ITER1]]
+// CHECK-SAME: %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+// CHECK-SAME: into %[[SUB_OUT]]
+// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[SUB_RES]] into %[[ITER1]]
 // CHECK: scf.yield %[[INSERT]] : tensor<2x8x8x2xf32>
 // CHECK: }
 // CHECK: return %[[RES0:.*]] : tensor<2x8x8x2xf32>
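Note (illustrative, not part of the patch): the `tensor.insert_slice` ops in these CHECK lines are untouched because the new pattern explicitly bails out on `InsertSliceOp`; that op's own canonicalizations, registered via `InsertSliceOp::getCanonicalizationPatterns`, already fold casts on its operands. A sketch of the kind of IR those existing patterns handle, with assumed names and shapes:

```mlir
%cast = tensor.cast %src : tensor<4x4xf32> to tensor<?x?xf32>
// Folded by insert_slice's own cast folder, not by the new
// FoldTensorCastProducerOp (which skips insert_slice):
%r = tensor.insert_slice %cast into %dest[0, 0] [4, 4] [1, 1]
    : tensor<?x?xf32> into tensor<8x8xf32>
```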