diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2839,12 +2839,107 @@
   }
 };
 
+/// Folds constant `low`/`high` padding operands of a tensor.pad op into its
+/// static padding attributes and makes the result type as static as the
+/// padding allows. A tensor.cast back to the original result type keeps the
+/// existing users valid.
+struct FoldStaticPadding : public OpRewritePattern<PadOp> {
+  using OpRewritePattern<PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = padTensorOp.getSource();
+    auto inputType = input.getType().dyn_cast<RankedTensorType>();
+    auto oldResultType =
+        padTensorOp.getResult().getType().dyn_cast<RankedTensorType>();
+    if (!inputType || !oldResultType)
+      return failure();
+    ArrayRef<int64_t> inputDims = inputType.getShape();
+    ArrayRef<int64_t> outputDims = oldResultType.getShape();
+    const size_t inputRank = inputDims.size();
+
+    // Start from the padding that is already static. Only the entries equal
+    // to ShapedType::kDynamic are backed by an SSA operand.
+    SmallVector<int64_t> staticLow(padTensorOp.getStaticLow());
+    SmallVector<int64_t> staticHigh(padTensorOp.getStaticHigh());
+
+    // Verify the op is well-formed.
+    if (outputDims.size() != inputRank || staticLow.size() != inputRank ||
+        staticHigh.size() != inputRank)
+      return failure();
+
+    // Replace every dynamic entry of `pads` whose operand is a non-negative
+    // constant by that constant. Operands that stay dynamic are collected in
+    // `dynamicPads` so that the rebuilt op carries exactly one operand per
+    // remaining dynamic entry.
+    auto foldConstantPads = [](SmallVector<int64_t> &pads, ValueRange operands,
+                               SmallVector<Value> &dynamicPads) {
+      size_t nextOperand = 0;
+      for (int64_t &pad : pads) {
+        if (pad != ShapedType::kDynamic)
+          continue;
+        if (nextOperand == operands.size())
+          return failure();
+        Value operand = operands[nextOperand++];
+        APSInt intOp;
+        if (matchPattern(operand, m_ConstantInt(&intOp)) &&
+            intOp.getExtValue() >= 0)
+          pad = intOp.getExtValue();
+        else
+          dynamicPads.push_back(operand);
+      }
+      return success();
+    };
+    SmallVector<Value> newLow, newHigh;
+    if (failed(foldConstantPads(staticLow, padTensorOp.getLow(), newLow)) ||
+        failed(foldConstantPads(staticHigh, padTensorOp.getHigh(), newHigh)))
+      return failure();
+
+    // Calculate the output sizes with the static information. A size stays
+    // untouched when it is already static or when the input size or one of
+    // the paddings is still dynamic.
+    SmallVector<int64_t> newOutDims;
+    newOutDims.reserve(inputRank);
+    for (size_t i = 0; i < inputRank; ++i) {
+      if (outputDims[i] != ShapedType::kDynamic ||
+          inputDims[i] == ShapedType::kDynamic ||
+          staticLow[i] == ShapedType::kDynamic ||
+          staticHigh[i] == ShapedType::kDynamic) {
+        newOutDims.push_back(outputDims[i]);
+        continue;
+      }
+      int64_t size = inputDims[i] + staticLow[i] + staticHigh[i];
+      // Pre-existing negative static padding could make the size negative;
+      // bail out instead of building an invalid tensor type.
+      if (size < 0)
+        return failure();
+      newOutDims.push_back(size);
+    }
+    if (SmallVector<int64_t>(outputDims) == newOutDims)
+      return failure();
+
+    // Rewrite the op using the new static type.
+    auto newResultType =
+        RankedTensorType::get(newOutDims, oldResultType.getElementType());
+    auto newOp = rewriter.create<PadOp>(
+        padTensorOp->getLoc(), newResultType, input, newLow, newHigh,
+        staticLow, staticHigh, padTensorOp.getNofold());
+
+    IRMapping mapper;
+    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
+                                                newOp);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<FoldSourceTensorCast, FoldTargetTensorCast,
-              FoldOrthogonalPaddings>(context);
+              FoldOrthogonalPaddings, FoldStaticPadding>(context);
 }
 
 /// Return the padding value of the PadOp if it constant. In this context,