diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2846,12 +2846,85 @@
   }
 };
 
+struct FoldStaticPadding : public OpRewritePattern<PadOp> {
+  using OpRewritePattern<PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = padTensorOp.getSource();
+    if (!input.getType().isa<RankedTensorType>())
+      return failure();
+    auto inputDims = input.getType().cast<RankedTensorType>().getShape();
+    auto inputRank = inputDims.size();
+
+    if (!padTensorOp.getResult().getType().isa<RankedTensorType>())
+      return failure();
+    auto outputDims =
+        padTensorOp.getResult().getType().cast<RankedTensorType>().getShape();
+
+    // Extract the static info from the high and low operands.
+    SmallVector<int64_t> constOperands;
+    for (auto operand : padTensorOp.getOperands()) {
+      // Skip the source tensor operand; only the low/high pad sizes are
+      // index-typed scalars.
+      if (operand.getType().isa<RankedTensorType>())
+        continue;
+      APSInt intOp;
+      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
+        constOperands.push_back(ShapedType::kDynamic);
+        continue;
+      }
+      constOperands.push_back(intOp.getExtValue());
+    }
+
+    // Verify the op is well-formed.
+    if (inputDims.size() != outputDims.size() ||
+        inputDims.size() != (constOperands.size() / 2))
+      return failure();
+
+    auto staticLow = ArrayRef<int64_t>(constOperands.begin(),
+                                       constOperands.begin() + inputRank);
+    auto staticHigh = ArrayRef<int64_t>(constOperands.begin() + inputRank,
+                                        constOperands.end());
+
+    // Calculate the output sizes with the static information.
+    SmallVector<int64_t> newOutDims;
+    for (size_t i = 0; i < inputRank; i++) {
+      if (outputDims[i] == ShapedType::kDynamic) {
+        newOutDims.push_back(
+            (staticLow[i] < 0 || staticHigh[i] < 0
+                 ? ShapedType::kDynamic
+                 : inputDims[i] + staticLow[i] + staticHigh[i]));
+      } else {
+        newOutDims.push_back(outputDims[i]);
+      }
+    }
+
+    // Bail out if nothing became more static.
+    if (SmallVector<int64_t>(outputDims) == newOutDims ||
+        llvm::all_of(newOutDims,
+                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
+      return failure();
+
+    // Rewrite the op using the new static type.
+    auto newResultType = RankedTensorType::get(
+        newOutDims, padTensorOp.getType().getElementType());
+    auto newOp = rewriter.create<PadOp>(
+        padTensorOp->getLoc(), newResultType, input, padTensorOp.getLow(),
+        padTensorOp.getHigh(), staticLow, staticHigh, padTensorOp.getNofold());
+
+    IRMapping mapper;
+    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, newResultType,
+                                                newOp);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
-              FoldOrthogonalPaddings>(context);
+              FoldOrthogonalPaddings, FoldStaticPadding>(context);
 }
 
 /// Return the padding value of the PadOp if it constant. In this context,