diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -871,6 +871,15 @@
 // Op-specific patterns.
 //===----------------------------------------------------------------------===//
 
+/// PadTensorOp is not canonicalized away yet, so we provide a transformation to
+/// `linalg.generic`.
+struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 /// PadTensorOp does not implement the LinalgStructuredOpInterface `LinalgOp`,
 /// it needs a specific pattern to vectorize.
 struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -637,3 +637,75 @@
 
   return failure();
 }
+
+static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
+  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
+}
+
+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
+/// with the padding value) and GenericOp (to copy the source contents at the
+/// low-padding offset). Only static shapes, static low padding and a padding
+/// value defined outside the PadTensorOp region are supported.
+LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
+    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
+  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
+  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
+
+  // Bail on non-static shapes.
+  if (!inputShapedType.hasStaticShape())
+    return failure();
+  if (!resultShapedType.hasStaticShape())
+    return failure();
+
+  // Only support padding with a constant for now, i.e. either:
+  //   1. A BBarg from a different block.
+  //   2. A value defined outside of the current block.
+  Block &block = padOp.region().front();
+  auto yieldOp = cast<linalg::YieldOp>(block.getTerminator());
+  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
+  Value padValue = yieldOp.values().front();
+  Operation *definingOp = padValue.getDefiningOp();
+  if (definingOp && definingOp->getBlock() == &block)
+    return failure();
+  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
+    return failure();
+
+  // The output indexing map shifts every dimension by its low padding, which
+  // therefore has to be a compile-time constant. A static result shape does
+  // not imply static low padding, so check it explicitly. This is done before
+  // creating any IR so that a failed match leaves the IR untouched.
+  MLIRContext *ctx = rewriter.getContext();
+  int64_t rank = resultShapedType.getRank();
+  SmallVector<AffineExpr, 4> outputExprs;
+  outputExprs.reserve(rank);
+  for (int64_t i = 0; i < rank; ++i) {
+    int64_t lowPad = padOp.static_low()[i].cast<IntegerAttr>().getInt();
+    if (ShapedType::isDynamic(lowPad))
+      return failure();
+    outputExprs.push_back(getAffineDimExpr(i, ctx) + lowPad);
+  }
+
+  // Create tensor with the padded shape.
+  Location loc = padOp.getLoc();
+  Value initTensor = rewriter.create<InitTensorOp>(
+      loc, resultShapedType.getShape(), resultShapedType.getElementType());
+
+  // Initialize tensor with the pad value.
+  Value tmpTensor =
+      rewriter.create<FillOp>(loc, initTensor, padValue).result();
+
+  // Copy original contents into new tensor.
+  // Uses linalg.generic, but could be done with std.subtensor_insert.
+  SmallVector<AffineMap, 2> transferMaps = {
+      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
+      AffineMap::get(rank, /*symbolCount=*/0, outputExprs, ctx)};
+
+  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
+      getNParallelLoopsAttrs(rank),
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+      });
+
+  return success();
+}
diff --git a/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir b/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-transform-pad-tensor" %s | FileCheck --check-prefix=CHECK %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 + 1, d1 + 1, d2 + 1, d3 + 2)>
+// CHECK-LABEL: func @pad_tensor_with_memrefs
+func @pad_tensor_with_memrefs(%arg0: memref<1x28x28x1xf32>) -> memref<2x31x31x3xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %0 = memref.tensor_load %arg0 : memref<1x28x28x1xf32>
+  %1 = linalg.pad_tensor %0 low[1, 1, 1, 2] high[0, 2, 2, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<1x28x28x1xf32> to tensor<2x31x31x3xf32>
+  %2 = memref.buffer_cast %1 : memref<2x31x31x3xf32>
+  return %2 : memref<2x31x31x3xf32>
+}
+
+// CHECK: linalg.fill
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+
+// -----
+
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 + 1, d1 + 2, d2 + 2)>
+// CHECK-LABEL: func @pad_tensor_no_memrefs
+func @pad_tensor_no_memrefs(%arg0: tensor<1x28x28xf32>) -> tensor<2x32x32xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.pad_tensor %arg0 low[1, 2, 2] high[0, 2, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<1x28x28xf32> to tensor<2x32x32xf32>
+  return %0 : tensor<2x32x32xf32>
+}
+
+// CHECK: linalg.fill
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP2]], #[[$MAP3]]]
+
+// -----
+
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 + 2, d2 + 2, d3)>
+// CHECK-LABEL: func @pad_tensor_detailed
+func @pad_tensor_detailed(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32>
+  return %0 : tensor<1x32x32x1xf32>
+}
+
+// CHECK: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32>
+// CHECK: %[[CTE:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[TMP:.+]] = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32>
+// CHECK: %[[R1c:.+]] = linalg.fill
+// CHECK: %[[R2c:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP4]], #[[$MAP5]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<1x28x28x1xf32>) outs(%[[R1c]] : tensor<1x32x32x1xf32>)
+// CHECK: ^bb0(%[[VAL:[a-zA-Z0-9_]+]]: f32, %{{.*}}: f32)
+// CHECK: linalg.yield %[[VAL]] : f32
+// CHECK: return %[[R2c]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -87,6 +87,10 @@
   Option<unsigned> testHoistPadding{*this, "test-hoist-padding",
                                     llvm::cl::desc("Test hoist padding"),
                                     llvm::cl::init(0)};
+  Option<bool> testTransformPadTensor{
+      *this, "test-transform-pad-tensor",
+      llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
+      llvm::cl::init(false)};
   ListOption<int64_t> tileSizesForPadding{
       *this, "tile-sizes-for-padding",
       llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
@@ -508,6 +512,12 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
   RewritePatternSet foldPattern(funcOp.getContext());
   foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
@@ -583,6 +593,8 @@
     return applyVectorTransferForwardingPatterns(getFunction());
   if (testGenericToVectorPattern)
     return applyLinalgToVectorPatterns(getFunction());
+  if (testTransformPadTensor)
+    return applyPadTensorToGenericPatterns(getFunction());
   if (testAffineMinSCFCanonicalizationPatterns)
     return applyAffineMinSCFCanonicalizationPatterns(getFunction());
   if (testTileAndPadPattern)