diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,6 +1312,10 @@
     const FrozenRewritePatternSet &stage2Patterns,
     function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
 
+//===----------------------------------------------------------------------===//
+// tensor.pad patterns
+//===----------------------------------------------------------------------===//
+
 /// Rewrite extract_slice(pad_tensor(x)) into pad_tensor(extract_slice(x)).
 struct ExtractSliceOfPadTensorSwapPattern
     : public OpRewritePattern<tensor::ExtractSliceOp> {
@@ -1338,6 +1342,12 @@
   ControlFn controlFn;
 };
 
+/// Populates patterns to make tensor.pad result shape static if possible.
+/// This can be used after ExtractSliceOfPadTensorSwapPattern to expose static
+/// information for further transformations like vectorization.
+void populateConcretizePadResultShapePatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
 //===----------------------------------------------------------------------===//
 // Helper classes for type list expansion.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -935,6 +935,113 @@
   return success();
 }
 
+/// Gets the given `attrOrValue` as an index value by creating constant ops
+/// for attributes.
+static Value getAsIndexValue(OpFoldResult attrOrValue, OpBuilder &builder,
+                             Location loc) {
+  IntegerAttr attr;
+  if (Value val = attrOrValue.dyn_cast<Value>()) {
+    if (val.getType().isIndex())
+      return val;
+    matchPattern(val, m_Constant(&attr));
+  } else {
+    attr = attrOrValue.get<Attribute>().cast<IntegerAttr>();
+  }
+  return builder.createOrFold<arith::ConstantIndexOp>(
+      loc, attr.getValue().getSExtValue());
+}
+
+namespace {
+/// Concretizes tensor.pad op's result shape if its source op implements
+/// OffsetSizeAndStrideOpInterface. For example, pad(extract_slice).
+struct ConcretizePadResultShape final
+    : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    // If the result shape is already static, then nothing to do.
+    if (padOp.getResultType().hasStaticShape())
+      return failure();
+
+    int rank = padOp.getResultType().getRank();
+    SmallVector<int64_t> staticShape;
+    staticShape.reserve(rank);
+
+    auto sourceIfxOp = dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(
+        padOp.source().getDefiningOp());
+    if (!sourceIfxOp)
+      return failure();
+
+    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> source = sourceIfxOp.getMixedSizes();
+    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
+
+    MLIRContext *context = padOp.getContext();
+    Location loc = padOp.getLoc();
+
+    AffineExpr sym0, sym1, sym2;
+    bindSymbols(context, sym0, sym1, sym2);
+    auto addMap = AffineMap::get(0, 3, {sym0 + sym1 + sym2}, context);
+
+    SmallVector<Value> valueSizes;
+    for (int dimIndex = 0; dimIndex < rank; ++dimIndex) {
+      valueSizes.clear();
+      valueSizes.push_back(getAsIndexValue(lowPad[dimIndex], rewriter, loc));
+      valueSizes.push_back(getAsIndexValue(source[dimIndex], rewriter, loc));
+      valueSizes.push_back(getAsIndexValue(highPad[dimIndex], rewriter, loc));
+
+      // The pad op's result shape is low padding + source size + high padding.
+      // Try to see if we can get a constant number by composing and
+      // canonicalizing the result. We use affine mechanisms here because
+      // generating arithmetic add ops over dim ops won't work, given they are
+      // SSA values that would need invoking other patterns to simplify. We
+      // cannot invoke patterns in patterns.
+      AffineMap map = addMap;
+      fullyComposeAffineMapAndOperands(&map, &valueSizes);
+      canonicalizeMapAndOperands(&map, &valueSizes);
+
+      auto cstExpr = map.getResult(0).dyn_cast<AffineConstantExpr>();
+      // Specially handle the case where we have both dimensions and symbols
+      // and they map to the same value, e.g.:
+      //   affine_map<(d0, s0) -> (d0 - s0 + 4)>(%v, %v).
+      // Due to the restrictions over dimensions and symbols, the above won't
+      // simplify. Try to change dimensions for symbols for such cases.
+      if (!cstExpr && llvm::is_splat(valueSizes)) {
+        int numDims = map.getNumDims();
+        int numSyms = map.getNumSymbols();
+        DenseMap<AffineExpr, AffineExpr> dimToSymMap;
+        for (int i = 0; i < numDims; ++i) {
+          dimToSymMap[rewriter.getAffineDimExpr(i)] =
+              rewriter.getAffineSymbolExpr(numSyms + i);
+        }
+        map = map.replace(dimToSymMap, /*numResultDims=*/0,
+                          /*numResultSyms=*/numDims + numSyms);
+
+        canonicalizeMapAndOperands(&map, &valueSizes);
+        cstExpr = map.getResult(0).dyn_cast<AffineConstantExpr>();
+      }
+      if (!cstExpr)
+        return failure();
+
+      staticShape.push_back(cstExpr.getValue());
+    }
+
+    auto resultType = RankedTensorType::get(
+        staticShape, padOp.getResultType().getElementType(),
+        padOp.getResultType().getEncoding());
+
+    rewriter.updateRootInPlace(padOp,
+                               [&]() { padOp.result().setType(resultType); });
+    return success();
+  }
+};
+} // namespace
+
+void linalg::populateConcretizePadResultShapePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<ConcretizePadResultShape>(patterns.getContext(), benefit);
+}
+
 namespace {
 // The following are patterns for downscaling convolution ops with size-1
 // window dimensions.
diff --git a/mlir/test/Dialect/Linalg/concretize-pad-result-shape.mlir b/mlir/test/Dialect/Linalg/concretize-pad-result-shape.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/concretize-pad-result-shape.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-concretize-pad-result-shape -allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK-LABEL: func @only_high_pad
+func @only_high_pad(%tensor: tensor<1x224x224x3xf32>, %arg0: index, %arg1: index) {
+  %cst = arith.constant 0.0 : f32
+  %0 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
+  %1 = affine.min affine_map<(d0) -> (d0 * 2 + 3, 224)>(%arg0)
+  %2 = affine.apply affine_map<(d0, d1) -> (d0 - d1 * 2)>(%1, %arg0)
+  %3 = affine.apply affine_map<(d0, d1) -> (-d0 + d1 * 2 + 3)>(%1, %arg0)
+  %4 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
+  %5 = affine.min affine_map<(d0) -> (d0 * 2 + 9, 224)>(%arg1)
+  %6 = affine.apply affine_map<(d0, d1) -> (d0 - d1 * 2)>(%5, %arg1)
+  %7 = affine.apply affine_map<(d0, d1) -> (-d0 + d1 * 2 + 9)>(%5, %arg1)
+  %8 = tensor.extract_slice %tensor[0, %0, %4, 0][1, %2, %6, 3][1, 1, 1, 1] : tensor<1x224x224x3xf32> to tensor<1x?x?x3xf32>
+  // CHECK: tensor.pad
+  %pad = tensor.pad %8 low[0, 0, 0, 0] high[0, %3, %7, 0] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    // CHECK: tensor.yield
+    tensor.yield %cst : f32
+  // CHECK-NEXT: tensor<1x?x?x3xf32> to tensor<1x3x9x3xf32>
+  } : tensor<1x?x?x3xf32> to tensor<1x?x?x3xf32>
+  "dialect.use"(%pad) : (tensor<1x?x?x3xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @both_low_and_high_pad
+func @both_low_and_high_pad(%tensor: tensor<1x56x56x144xf32>, %arg0: index, %arg1: index, %arg2: index) {
+  %cst = arith.constant 0.0 : f32
+  %0 = affine.max affine_map<(d0) -> (0, -d0 + 1)>(%arg0)
+  %1 = affine.max affine_map<(d0) -> (d0 - 1, 0)>(%arg0)
+  %2 = affine.min affine_map<(d0) -> (d0, 56)>(%1)
+  %3 = affine.max affine_map<(d0) -> (d0 + 3, 0)>(%arg0)
+  %4 = affine.min affine_map<(d0) -> (d0, 56)>(%3)
+  %5 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%4, %2)
+  %6 = affine.apply affine_map<(d0, d1, d2) -> (-d0 - d1 + d2 + 4)>(%0, %4, %2)
+  %7 = affine.max affine_map<(d0) -> (0, -d0 + 1)>(%arg1)
+  %8 = affine.max affine_map<(d0) -> (d0 - 1, 0)>(%arg1)
+  %9 = affine.min affine_map<(d0) -> (d0, 56)>(%8)
+  %10 = affine.max affine_map<(d0) -> (d0 + 3, 0)>(%arg1)
+  %11 = affine.min affine_map<(d0) -> (d0, 56)>(%10)
+  %12 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%11, %9)
+  %13 = affine.apply affine_map<(d0, d1, d2) -> (-d0 - d1 + d2 + 4)>(%7, %11, %9)
+  %14 = tensor.extract_slice %tensor[0, %2, %9, %arg2][1, %5, %12, 16][1, 1, 1, 1] : tensor<1x56x56x144xf32> to tensor<1x?x?x16xf32>
+  // CHECK: tensor.pad
+  %pad = tensor.pad %14 low[0, %0, %7, 0] high[0, %6, %13, 0] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index): // no predecessors
+    // CHECK: tensor.yield
+    tensor.yield %cst : f32
+  // CHECK-NEXT: tensor<1x?x?x16xf32> to tensor<1x4x4x16xf32>
+  } : tensor<1x?x?x16xf32> to tensor<1x?x?x16xf32>
+  "dialect.use"(%pad) : (tensor<1x?x?x16xf32>) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -103,6 +103,11 @@
       *this, "test-generalize-pad-tensor",
       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
       llvm::cl::init(false)};
+  Option<bool> testConcretizePadResultShape{
+      *this, "test-concretize-pad-result-shape",
+      llvm::cl::desc(
+          "Test patterns to make tensor.pad result shape static when possible"),
+      llvm::cl::init(false)};
   Option<bool> testSwapSubTensorPadTensor{
       *this, "test-swap-subtensor-padtensor",
       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
@@ -564,6 +569,12 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyConcretizeTensorPadResultShapePatterns(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateConcretizePadResultShapePatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
@@ -712,6 +723,8 @@
     return applyPadTensorToGenericPatterns(getOperation());
   if (testGeneralizePadTensor)
     return applyGeneralizePadTensorPatterns(getOperation());
+  if (testConcretizePadResultShape)
+    return applyConcretizeTensorPadResultShapePatterns(getOperation());
  if (testSwapSubTensorPadTensor)
    return applyExtractSliceOfPadTensorSwapPattern(getOperation());
  if (testTiledLoopPeeling.hasValue())
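
For reviewers, here is a rough sketch of how a downstream pass could chain ExtractSliceOfPadTensorSwapPattern with the new populate function, so that pads produced by tiling get static result shapes before vectorization. This is illustration only, not part of the patch: the pass name `ConcretizePadShapePass` and its flag are hypothetical, and it assumes the same FuncOp-era API used in TestLinalgTransforms.cpp above.

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
/// Hypothetical pass: swaps extract_slice(tensor.pad) into
/// tensor.pad(extract_slice) and then concretizes the pad result shape.
struct ConcretizePadShapePass
    : public PassWrapper<ConcretizePadShapePass, OperationPass<FuncOp>> {
  StringRef getArgument() const final { return "concretize-pad-shape"; }

  void runOnOperation() override {
    FuncOp funcOp = getOperation();
    RewritePatternSet patterns(funcOp.getContext());
    // Swapping first turns the pad source into an extract_slice, i.e. an op
    // implementing OffsetSizeAndStrideOpInterface, which the concretization
    // pattern requires.
    patterns.add<linalg::ExtractSliceOfPadTensorSwapPattern>(
        funcOp.getContext());
    linalg::populateConcretizePadResultShapePatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```

Such a pass would be registered with `PassRegistration<ConcretizePadShapePass>()` in a tool, mirroring what the `test-concretize-pad-result-shape` flag above does inside the test pass.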