Index: mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
===================================================================
--- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2846,12 +2846,109 @@
   }
 };
 
+// Canonicalizes a tensor.pad whose low/high padding amounts are (partly)
+// produced by constants: rebuilds the pad with a more static result type and
+// casts back to the original (less static) type for existing users.
+struct FoldStaticPadding : public OpRewritePattern<PadOp> {
+  using OpRewritePattern<PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = padTensorOp.getSource();
+    if (!input.getType().isa<RankedTensorType>())
+      return failure();
+    auto inputDims = input.getType().cast<RankedTensorType>().getShape();
+    auto inputRank = inputDims.size();
+
+    if (!padTensorOp.getResult().getType().isa<RankedTensorType>())
+      return failure();
+    auto outputDims =
+        padTensorOp.getResult().getType().cast<RankedTensorType>().getShape();
+
+    // Extract the static info from the high and low operands.
+    SmallVector<int64_t> constOperandsLow;
+    for (auto operand : padTensorOp.getLow()) {
+      APSInt intOp;
+      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
+        constOperandsLow.push_back(ShapedType::kDynamic);
+        continue;
+      }
+      constOperandsLow.push_back(intOp.getExtValue());
+    }
+    SmallVector<int64_t> constOperandsHigh;
+    for (auto operand : padTensorOp.getHigh()) {
+      APSInt intOp;
+      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
+        constOperandsHigh.push_back(ShapedType::kDynamic);
+        continue;
+      }
+      constOperandsHigh.push_back(intOp.getExtValue());
+    }
+
+    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
+    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
+
+    // Verify the op is well-formed.
+    if (inputDims.size() != outputDims.size() ||
+        inputDims.size() != constLow.size() ||
+        inputDims.size() != constHigh.size())
+      return failure();
+
+    // Fill in any dynamic static-attr entries from the matched constant
+    // operands (operand lists only contain the dynamic entries, in order).
+    auto lowCount = 0;
+    auto highCount = 0;
+    for (size_t i = 0; i < inputRank; i++) {
+      if (constLow[i] == ShapedType::kDynamic)
+        constLow[i] = constOperandsLow[lowCount++];
+      if (constHigh[i] == ShapedType::kDynamic)
+        constHigh[i] = constOperandsHigh[highCount++];
+    }
+
+    auto staticLow = ArrayRef<int64_t>(constLow);
+    auto staticHigh = ArrayRef<int64_t>(constHigh);
+
+    // Calculate the output sizes with the static information.
+    SmallVector<int64_t> newOutDims;
+    for (size_t i = 0; i < inputRank; i++) {
+      if (outputDims[i] == ShapedType::kDynamic) {
+        newOutDims.push_back(
+            (staticLow[i] == ShapedType::kDynamic ||
+             staticHigh[i] == ShapedType::kDynamic ||
+             inputDims[i] == ShapedType::kDynamic
+                 ? ShapedType::kDynamic
+                 : inputDims[i] + staticLow[i] + staticHigh[i]));
+      } else {
+        newOutDims.push_back(outputDims[i]);
+      }
+    }
+
+    // Bail out if nothing was refined (avoids infinite pattern re-application)
+    // or if no dimension became static.
+    if (SmallVector<int64_t>(outputDims) == newOutDims ||
+        llvm::all_of(newOutDims,
+                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
+      return failure();
+
+    // Rewrite the op using the new static type.
+    auto newResultType = RankedTensorType::get(
+        newOutDims, padTensorOp.getType().getElementType());
+    auto newOp = rewriter.create<PadOp>(
+        padTensorOp->getLoc(), newResultType, input, padTensorOp.getLow(),
+        padTensorOp.getHigh(), staticLow, staticHigh, padTensorOp.getNofold());
+
+    IRMapping mapper;
+    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
+    // Cast back to the ORIGINAL result type so existing users of the padded
+    // value still type-check (casting to newResultType would be an identity
+    // cast and would hand users a value of the wrong type).
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        padTensorOp, padTensorOp.getResult().getType(), newOp);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
-              FoldOrthogonalPaddings>(context);
+              FoldOrthogonalPaddings, FoldStaticPadding>(context);
 }
 
 /// Return the padding value of the PadOp if it constant.
Index: mlir/test/Dialect/Tensor/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Tensor/canonicalize.mlir
+++ mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1111,6 +1111,29 @@
 
 // -----
 
+// Check that a pad with constant low/high operands is folded into a pad with
+// a static middle dimension (64 + 4 + 4 = 72), followed by an implicit cast
+// back to the original all-dynamic type consumed by the collapse_shape.
+// CHECK-LABEL: func @pad_fold_static(
+// CHECK-SAME:    %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?xf32> {
+// CHECK:         %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[PADDING:.*]] = arith.constant 4 : index
+// CHECK:         %[[PADDED:.*]] = tensor.pad %[[INPUT]]
+// CHECK-SAME:      low[0, 4, 1, 1] high[0, 4, 1, 1] {
+// CHECK:         ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK:           tensor.yield %[[CST]] : f32
+// CHECK:         } : tensor<?x64x?x?xf32> to tensor<?x72x?x?xf32>
+func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>)
+    -> tensor<?xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %padding = arith.constant 4 : index
+  %padded = tensor.pad %arg0 low[0, %padding, 1, 1] high[0, %padding, 1, 1] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+      tensor.yield %cst: f32
+  } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+  %result = tensor.collapse_shape %padded [[0, 1, 2, 3]] : tensor<?x?x?x?xf32> into tensor<?xf32>
+  return %result : tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @pad_nofold_same_static_shape(
 // CHECK-SAME:    %[[ARG0:.*]]: tensor<5x6xf32>
 // CHECK:         %[[PAD:.*]] = tensor.pad