diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1151,15 +1151,6 @@ PatternRewriter &rewriter) const override; }; -/// tensor::PadOp is not canonicalized away yet, so we provide a -/// transformation to `linalg.generic`. -struct PadOpTransformationPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::PadOp padOp, - PatternRewriter &rewriter) const override; -}; - using OptimizeCopyFn = std::function; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -139,12 +139,6 @@ opOperand->get(), paddingValue, nofold); } -static SmallVector -getNParallelLoopsAttrs(unsigned nParallelLoops) { - return SmallVector(nParallelLoops, - utils::IteratorType::parallel); -} - //===----------------------------------------------------------------------===// // Transformations exposed as functional-style API calls. //===----------------------------------------------------------------------===// @@ -1028,71 +1022,6 @@ return vectorizeCopy(rewriter, copyOp); } -/// -/// Pattern to rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp (to -/// initialize with pad_val) and GenericOp (to copy contents). -/// -LogicalResult -PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp, - PatternRewriter &rewriter) const { - - auto inputShapedType = cast(padOp.getSource().getType()); - auto resultShapedType = cast(padOp.getResult().getType()); - - // Bail on non-static shapes. - if (!inputShapedType.hasStaticShape()) - return failure(); - if (!resultShapedType.hasStaticShape()) - return failure(); - - // Only support padding with a constant for now, i.e. either: - // 1. A BBarg from a different block. - // 2. A value defined outside of the current block. - Block &block = padOp.getRegion().front(); - auto yieldOp = cast(block.getTerminator()); - Value padValue = yieldOp.getValue(); - Operation *definingOp = padValue.getDefiningOp(); - if (definingOp && definingOp->getBlock() == &block) - return failure(); - if (!definingOp && cast(padValue).getOwner() == &block) - return failure(); - - // Create tensor with the padded shape - Location loc = padOp.getLoc(); - SmallVector indices(resultShapedType.getRank(), - rewriter.create(loc, 0)); - Value emptyTensor = rewriter.create( - loc, resultShapedType.getShape(), resultShapedType.getElementType()); - - // Initialize tensor with the pad value - Value tmpTensor = rewriter - .create(loc, ValueRange{padValue}, - ValueRange{emptyTensor}) - .result(); - - // Copy original contents into new tensor - // Uses linalg.generic, but could be done with tensor.insert_slice - SmallVector outputExprs; - for (unsigned i = 0; i < resultShapedType.getRank(); ++i) { - outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) + - padOp.getStaticLow()[i]); - } - - SmallVector transferMaps = { - rewriter.getMultiDimIdentityMap(inputShapedType.getRank()), - AffineMap::get(resultShapedType.getRank(), - /*symbolCount=*/0, outputExprs, rewriter.getContext())}; - - rewriter.replaceOpWithNewOp( - padOp, resultShapedType, padOp.getSource(), tmpTensor, transferMaps, - getNParallelLoopsAttrs(resultShapedType.getRank()), - [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - nestedBuilder.create(nestedLoc, args[0]); - }); - - return success(); -} - /// Filling `dest` using FillOp constant padding value if possible. /// Otherwise, generate a tensor::GenerateOp. Value GeneralizePadOpPattern::createFillOrGenerateOp( diff --git a/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir b/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir +++ /dev/null @@ -1,63 +0,0 @@ -// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-transform-pad-tensor" %s | FileCheck --check-prefix=CHECK %s - -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 + 1, d1 + 1, d2 + 1, d3 + 2)> -// CHECK-LABEL: func @pad_tensor_with_memrefs -func.func @pad_tensor_with_memrefs(%arg0: memref<1x28x28x1xf32>) -> memref<2x31x31x3xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %0 = bufferization.to_tensor %arg0 : memref<1x28x28x1xf32> - %1 = tensor.pad %0 low[1, 1, 1, 2] high[0, 2, 2, 0] { - ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): - tensor.yield %cst : f32 - } : tensor<1x28x28x1xf32> to tensor<2x31x31x3xf32> - %2 = bufferization.to_memref %1 : memref<2x31x31x3xf32> - return %2 : memref<2x31x31x3xf32> -} - -// CHECK: linalg.fill -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] - -// ----- - -// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 + 1, d1 + 2, d2 + 2)> -// CHECK-LABEL: func @pad_tensor_no_memrefs -func.func @pad_tensor_no_memrefs(%arg0: tensor<1x28x28xf32>) -> tensor<2x32x32xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.pad %arg0 low[1, 2, 2] high[0, 2, 2] { - ^bb0(%arg1: index, %arg2: index, %arg3: index): - tensor.yield %cst : f32 - } : tensor<1x28x28xf32> to tensor<2x32x32xf32> - return %0 : tensor<2x32x32xf32> -} - -// CHECK: linalg.fill -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP2]], #[[$MAP3]]] - -// ----- - -// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 + 2, d2 + 2, d3)> -// CHECK-LABEL: func @pad_tensor_detailed -func.func @pad_tensor_detailed(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] { - ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): - tensor.yield %cst : f32 - } : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32> - return %0 : tensor<1x32x32x1xf32> -} - -// CHECK: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> -// CHECK: %[[CTE:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[TMP:.+]] = tensor.empty() : tensor<1x32x32x1xf32> -// CHECK: %[[R1c:.+]] = linalg.fill -// CHECK: %[[R2c:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP4]], #[[$MAP5]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] -// CHECK: ins(%{{.*}} : tensor<1x28x28x1xf32>) outs(%{{.*}} : tensor<1x32x32x1xf32>) -// CHECK: ^bb0(%[[VAL:.+]]: f32, %{{.*}}: f32) -// CHECK: linalg.yield %[[VAL]] : f32 -// CHECK: return %[[R2c:.+]] diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -70,10 +70,6 @@ llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " "in vector.contract form"), llvm::cl::init(false)}; - Option testTransformPadTensor{ - *this, "test-transform-pad-tensor", - llvm::cl::desc("Test transform pad tensor by copying with generic ops"), - llvm::cl::init(false)}; Option testGeneralizePadTensor{ *this, "test-generalize-pad-tensor", llvm::cl::desc("Test transform pad tensor by copying with generic ops"), @@ -163,12 +159,6 @@ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } -static void applyPadTensorToGenericPatterns(func::FuncOp funcOp) { - RewritePatternSet patterns(funcOp.getContext()); - patterns.add(funcOp.getContext()); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); -} - static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); @@ -225,8 +215,6 @@ return applyVectorTransferForwardingPatterns(getOperation()); if (testGenericToVectorPattern) return applyLinalgToVectorPatterns(getOperation()); - if (testTransformPadTensor) - return applyPadTensorToGenericPatterns(getOperation()); if (testGeneralizePadTensor) return applyGeneralizePadTensorPatterns(getOperation()); if (testGeneralizeTensorPackOp)