diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -883,6 +883,28 @@
                                 PatternRewriter &rewriter) const override;
 };
 
+using OptimizeCopyFn =
+    std::function<LogicalResult(PatternRewriter &, PadTensorOp, Value)>;
+
+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp and
+/// InsertSliceOp. For now, only constant padding values are supported.
+/// `OptimizeCopyFn` can be used to customize the optimization of the copy
+/// step.
+struct GeneralizePadTensorOpPattern : public OpRewritePattern<PadTensorOp> {
+  GeneralizePadTensorOpPattern(MLIRContext *context,
+                               OptimizeCopyFn optimizeCopyFn = nullptr,
+                               PatternBenefit benefit = 1)
+      : OpRewritePattern<PadTensorOp>(context, benefit),
+        optimizeCopyFn(optimizeCopyFn) {}
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override;
+
+protected:
+  OptimizeCopyFn optimizeCopyFn;
+  Value createFillOrGenerateOp(PatternRewriter &rewriter, PadTensorOp padOp,
+                               Value dest,
+                               const SmallVector<Value> &dynSizes) const;
+};
+
 /// Populates `patterns` with patterns that vectorize linalg.pad_tensor.
 /// These patterns are meant to apply in a complementary fashion. Benefits
 /// are used to encode a certain ordering of pattern application. To avoid
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -334,6 +334,7 @@
   target.addDynamicallyLegalOp(isLegalOperation);
 
   RewritePatternSet patterns(&context);
+  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
   populateLinalgBufferizePatterns(typeConverter, patterns);
   if (failed(applyPartialConversion(getOperation(), target,
                                     std::move(patterns))))
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -699,6 +699,95 @@
   return success();
 }
 
+/// Fill `dest` using a FillOp with the constant padding value if possible.
+/// Otherwise, generate a tensor::GenerateOp.
+Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
+    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
+    const SmallVector<Value> &dynSizes) const {
+  auto padValue = padOp.getConstantPaddingValue();
+  if (padValue)
+    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
+
+  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
+  auto generateOp = rewriter.create<tensor::GenerateOp>(
+      padOp.getLoc(), padOp.getResultType(), dynSizes);
+  // Copy region to new op.
+  BlockAndValueMapping bvm;
+  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
+  // Rewrite linalg::YieldOp to tensor::YieldOp.
+  OpBuilder::InsertionGuard guard(rewriter);
+  auto yieldOp = dyn_cast<linalg::YieldOp>(
+      generateOp.getRegion().front().getTerminator());
+  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
+  assert(yieldOp.values().size() == 1);
+  rewriter.setInsertionPoint(yieldOp);
+  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
+  return generateOp;
+}
+
+LogicalResult
+GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
+                                              PatternRewriter &rewriter) const {
+  // Given an OpFoldResult, return an index-typed value.
+  auto getIdxValue = [&](OpFoldResult ofr) {
+    if (auto val = ofr.dyn_cast<Value>())
+      return val;
+    return rewriter
+        .create<ConstantIndexOp>(
+            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
+        .getResult();
+  };
+
+  auto resultType = padOp.getResultType();
+  // Compute size of InitTensorOp. Any combination of static/dynamic is
+  // supported.
+  SmallVector<Value> dynSizes;
+  SmallVector<int64_t> staticSizes;
+  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
+    if (resultType.isDynamicDim(dim)) {
+      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
+                                                          padOp.source(), dim);
+      // Add low and high padding value.
+      auto plusLow = rewriter.createOrFold<AddIOp>(
+          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
+      auto plusHigh = rewriter.createOrFold<AddIOp>(
+          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
+      dynSizes.push_back(plusHigh);
+    }
+    staticSizes.push_back(resultType.getDimSize(dim));
+  }
+
+  // Init tensor and fill it with padding.
+  Value init = rewriter.create<InitTensorOp>(
+      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
+  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);
+
+  // Try to optimize the copy of the source.
+  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
+    return success();
+
+  // The copy could not be optimized. Generate an InsertSliceOp instead
+  // for copying the PadOp source.
+  auto sourceType = padOp.getSourceType();
+  // Compute size of source of PadTensorOp.
+  SmallVector<OpFoldResult> srcSizes;
+  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
+    if (sourceType.isDynamicDim(dim)) {
+      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
+          padOp.getLoc(), padOp.source(), dim));
+    } else {
+      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
+    }
+  }
+  // Strides of InsertSliceOp are all 1.
+  SmallVector<OpFoldResult> strides(sourceType.getRank(),
+                                    rewriter.getIndexAttr(1));
+  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);
+
+  return success();
+}
+
 /// Given an OpFoldResult, return a Value. If the OpFoldResult is an Attribute,
 /// it must be of type Integer.
 static Value asValue(OpBuilder &builder, Location loc, OpFoldResult ofr) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -682,104 +682,15 @@
 /// If there is enough static type information, TransferReadOps and
 /// TransferWriteOps may be generated instead of InsertSliceOps.
 struct GenericPadTensorOpVectorizationPattern
-    : public OpRewritePattern<PadTensorOp> {
-  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(PadTensorOp padOp,
-                                PatternRewriter &rewriter) const final {
-    // Given an OpFoldResult, return an index-typed value.
-    auto getIdxValue = [&](OpFoldResult ofr) {
-      if (auto val = ofr.dyn_cast<Value>())
-        return val;
-      return rewriter.create<ConstantIndexOp>(
-          padOp.getLoc(), getIntFromAttr(ofr.get<Attribute>())).getResult();
-    };
-
-    auto resultType = padOp.getResultType();
-    // Compute size of InitTensorOp. Any combination of static/dynamic is
-    // supported.
-    SmallVector<Value> dynSizes;
-    SmallVector<int64_t> staticSizes;
-    for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
-      if (resultType.isDynamicDim(dim)) {
-        auto srcSize = rewriter.createOrFold<tensor::DimOp>(
-            padOp.getLoc(), padOp.source(), dim);
-        // Add low and high padding value.
-        auto plusLow = rewriter.createOrFold<AddIOp>(
-            padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
-        auto plusHigh = rewriter.createOrFold<AddIOp>(
-            padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
-        dynSizes.push_back(plusHigh);
-      }
-      staticSizes.push_back(resultType.getDimSize(dim));
-    }
-
-    // Init tensor and fill it with padding.
-    Value init = rewriter.create<InitTensorOp>(
-        padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
-    Value fill = tryVectorizeFill(rewriter, padOp, init, dynSizes);
-
-    // Try vectorizing the copy of source.
-    if (tryVectorizeCopy(rewriter, padOp, fill).succeeded())
-      return success();
-
-    // Neither source type nor PadTensorOp result type have static shape. Such
-    // PadTensorOps cannot be vectorized. Generate a InsertSliceOp instead
-    // for copying the PadOp source.
-
-    auto sourceType = padOp.getSourceType();
-    // Compute size of source of PadTensorOp.
-    SmallVector<OpFoldResult> srcSizes;
-    for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
-      if (sourceType.isDynamicDim(dim)) {
-        srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
-            padOp.getLoc(), padOp.source(), dim));
-      } else {
-        srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
-      }
-    }
-    // Strides of InsertSliceOp are all 1.
-    SmallVector<OpFoldResult> strides(sourceType.getRank(),
-                                      rewriter.getIndexAttr(1));
-    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
-        padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);
-
-    return success();
-  }
-
-  /// Vectorize the filling of `dest`. This is possible if the padOp is padding
-  /// with a constant value. Otherwise, generate a tensor::GenerateOp.
-  Value tryVectorizeFill(PatternRewriter &rewriter, PadTensorOp padOp,
-                         Value dest, const SmallVector<Value> &dynSizes) const {
-    // Fill can be vectorized if padValue is a constant. (If there is enough
-    // static type information, the FillOp will be vectorized by another
-    // pattern.)
-    auto padValue = padOp.getConstantPaddingValue();
-    if (padValue)
-      return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
-
-    // Fill could not be vectorized: Lower to tensor::GenerateOp with region.
-    auto generateOp = rewriter.create<tensor::GenerateOp>(
-        padOp.getLoc(), padOp.getResultType(), dynSizes);
-    // Copy region to new op.
-    BlockAndValueMapping bvm;
-    padOp.region().cloneInto(&generateOp.getRegion(), bvm);
-    // Rewrite linalg::YieldOp to tensor::YieldOp.
-    OpBuilder::InsertionGuard guard(rewriter);
-    auto yieldOp = dyn_cast<linalg::YieldOp>(
-        generateOp.getRegion().front().getTerminator());
-    assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
-    assert(yieldOp.values().size() == 1);
-    rewriter.setInsertionPoint(yieldOp);
-    rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
-    return generateOp;
-  }
-
+    : public GeneralizePadTensorOpPattern {
+  GenericPadTensorOpVectorizationPattern(MLIRContext *context,
+                                         PatternBenefit benefit = 1)
+      : GeneralizePadTensorOpPattern(context, tryVectorizeCopy, benefit) {}
   /// Vectorize the copying of a PadTensorOp's source. This is possible if each
   /// dimension size is statically know in the source type or the result type
   /// (or both).
-  LogicalResult tryVectorizeCopy(PatternRewriter &rewriter, PadTensorOp padOp,
-                                 Value dest) const {
+  static LogicalResult tryVectorizeCopy(PatternRewriter &rewriter,
+                                        PadTensorOp padOp, Value dest) {
     auto sourceType = padOp.getSourceType();
     auto resultType = padOp.getResultType();
diff --git a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir b/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-pad-tensor" %s | FileCheck --check-prefix=CHECK %s
+
+// CHECK-LABEL: func @generalize_pad_tensor_static_shape(
+// CHECK-SAME:    %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
+// CHECK:         %[[C0:.*]] = constant 0.000000e+00 : f32
+// CHECK:         %[[INIT:.*]] = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32>
+// CHECK:         %[[FILL:.*]] = linalg.fill(%[[C0]], %[[INIT]]) : f32, tensor<1x32x32x1xf32> -> tensor<1x32x32x1xf32>
+// CHECK:         %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]][0, 2, 2, 0] [1, 28, 28, 1] [1, 1, 1, 1] : tensor<1x28x28x1xf32> into tensor<1x32x32x1xf32>
+// CHECK:         return %[[PADDED]] : tensor<1x32x32x1xf32>
+func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32>
+  return %0 : tensor<1x32x32x1xf32>
+}
+
+// CHECK-LABEL: func @generalize_pad_tensor_dynamic_shape(
+// CHECK-SAME:    %[[IN:.*]]: tensor<4x?x2x?xf32>,
+// CHECK-SAME:    %[[OFFSET:.*]]: index) -> tensor<4x?x?x?xf32> {
+// CHECK:         %[[C0:.*]] = constant 0 : index
+// CHECK:         %[[CST:.*]] = constant 0.000000e+00 : f32
+// CHECK:         %[[C2:.*]] = constant 2 : index
+// CHECK:         %[[C1:.*]] = constant 1 : index
+// CHECK:         %[[C3:.*]] = constant 3 : index
+// CHECK:         %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
+// CHECK:         %[[OUT_DIM2:.*]] = addi %[[OFFSET]], %[[C2]] : index
+// CHECK:         %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
+// CHECK:         %[[OUT_DIM3:.*]] = addi %[[DIM3]], %[[OFFSET]] : index
+// CHECK:         %[[INIT:.*]] = linalg.init_tensor [4, %[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]] : tensor<4x?x?x?xf32>
+// CHECK:         %[[FILL:.*]] = linalg.fill(%[[CST]], %[[INIT]]) : f32, tensor<4x?x?x?xf32> -> tensor<4x?x?x?xf32>
+// CHECK:         %[[DIM1_1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
+// CHECK:         %[[DIM3_1:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
+// CHECK:         %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[OFFSET]], %[[C0]]] [4, %[[DIM1_1]], 2, %[[DIM3_1]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32>
+// CHECK:         return %[[PADDED]] : tensor<4x?x?x?xf32>
+// CHECK:       }
+func @generalize_pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tensor<4x?x?x?xf32> {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+  %out = linalg.pad_tensor %arg0 low[%c0, %c0, %arg1, %c0] high[%c0, %c0, %c0, %arg1] {
+  ^bb0(%gen_arg1: index, %gen_arg2: index, %gen_arg3: index, %gen_arg4: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<4x?x2x?xf32> to tensor<4x?x?x?xf32>
+  return %out : tensor<4x?x?x?xf32>
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN:   -tensor-constant-bufferize -tensor-bufferize -func-bufferize \
+// RUN:   -finalizing-bufferize \
+// RUN:   -convert-linalg-to-loops -convert-scf-to-std -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+
+func @main() {
+  %const = constant dense<[[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]]]> : tensor<1x2x3xf32>
+  %dynamic = tensor.cast %const: tensor<1x2x3xf32> to tensor<1x?x3xf32>
+  %offset = constant 2 : index
+  %cst = constant 2.3 : f32
+  %c0 = constant 0 : index
+  %out = linalg.pad_tensor %dynamic low[%c0, %offset, %c0] high[%c0, %c0, %offset] {
+  ^bb0(%gen_arg1: index, %gen_arg2: index, %gen_arg3: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<1x?x3xf32> to tensor<1x?x?xf32>
+  %unranked = tensor.cast %out: tensor<1x?x?xf32> to tensor<*xf32>
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+
+  // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
+  // CHECK-SAME: rank = 3 offset = 0 sizes = [1, 4, 5] strides = [20, 5, 1] data =
+  // CHECK-NEXT{LITERAL}: [[[2.3, 2.3, 2.3, 2.3, 2.3],
+  // CHECK-NEXT: [2.3, 2.3, 2.3, 2.3, 2.3],
+  // CHECK-NEXT: [1, 2, 3, 2.3, 2.3],
+  // CHECK-NEXT: [2, 3, 4, 2.3, 2.3]]]
+
+  return
+}
+
+func private @print_memref_f32(%ptr : tensor<*xf32>)
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -97,6 +97,10 @@
       *this, "test-transform-pad-tensor",
       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
       llvm::cl::init(false)};
+  Option<bool> testGeneralizePadTensor{
+      *this, "test-generalize-pad-tensor",
+      llvm::cl::desc("Test generalization of pad tensor to fill + insert_slice"),
+      llvm::cl::init(false)};
   Option<bool> testSwapSubTensorPadTensor{
       *this, "test-swap-subtensor-padtensor",
       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
@@ -530,6 +534,12 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
@@ -614,6 +624,8 @@
     return applyLinalgToVectorPatterns(getFunction());
   if (testTransformPadTensor)
     return applyPadTensorToGenericPatterns(getFunction());
+  if (testGeneralizePadTensor)
+    return applyGeneralizePadTensorPatterns(getFunction());
   if (testSwapSubTensorPadTensor)
     return applyExtractSliceOfPadTensorSwapPattern(getFunction());
  if (testAffineMinSCFCanonicalizationPatterns)
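
Note (not part of the patch): the `OptimizeCopyFn` hook added in Transforms.h is the extension point this change introduces; the vectorization pattern passes `tryVectorizeCopy` through it, and the bufferization and test passes register the pattern with no hook at all. As a minimal sketch of how another pass could reuse it, under the assumption that `myOptimizeCopy` and `populateMyPadLoweringPatterns` are hypothetical downstream names (only `GeneralizePadTensorOpPattern` and `OptimizeCopyFn` come from this patch):

// Sketch only: plug a custom copy optimization into the generalized lowering.
// Returning failure() makes matchAndRewrite() fall back to the
// tensor.insert_slice copy shown above.
static LogicalResult myOptimizeCopy(PatternRewriter &rewriter,
                                    PadTensorOp padOp, Value dest) {
  return failure();
}

static void populateMyPadLoweringPatterns(RewritePatternSet &patterns) {
  // Arguments after the context are forwarded to the pattern constructor
  // GeneralizePadTensorOpPattern(context, optimizeCopyFn, benefit).
  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext(),
                                             myOptimizeCopy);
}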