diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -826,6 +826,7 @@
 //===----------------------------------------------------------------------===//
 
 def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
     SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
   let summary = "tensor pad operation";
   let description = [{
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1812,6 +1812,50 @@
   return RankedTensorType::get(inferredShape, sourceType.getElementType());
 }
 
+/// Reifies the result shape of the pad op: every static result dim is
+/// materialized as an index constant, and every dynamic result dim `d` is
+/// computed as `dim(source, d) + low[d] + high[d]`.
+LogicalResult
+PadOp::reifyResultShapes(OpBuilder &builder,
+                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  Location loc = getLoc();
+  RankedTensorType resultType = getResultType();
+  reifiedReturnShapes.resize(1, SmallVector<Value>(resultType.getRank()));
+
+  // Given an OpFoldResult, return an index-typed value. Static pad amounts
+  // are materialized as index constants.
+  auto getIdxValue = [&](OpFoldResult ofr) -> Value {
+    if (auto val = ofr.dyn_cast<Value>())
+      return val;
+    auto cst = getConstantIntValue(ofr);
+    assert(cst && "expected static pad amount to be an integer attribute");
+    return builder.create<arith::ConstantIndexOp>(loc, *cst).getResult();
+  };
+
+  // The mixed (static/dynamic) pad amounts do not depend on `dim`; fetch
+  // them once instead of once per result dimension.
+  auto lowPad = getMixedLowPad();
+  auto highPad = getMixedHighPad();
+
+  for (auto dim : llvm::seq<int64_t>(0, resultType.getRank())) {
+    if (!resultType.isDynamicDim(dim)) {
+      // Static result dim. No need to compute anything.
+      reifiedReturnShapes[0][dim] = builder.create<arith::ConstantIndexOp>(
+          loc, resultType.getDimSize(dim));
+      continue;
+    }
+
+    // Compute src size + low pad + high pad.
+    Value srcDimSize = builder.createOrFold<tensor::DimOp>(loc, source(), dim);
+    Value lowPadSize = getIdxValue(lowPad[dim]);
+    Value highPadSize = getIdxValue(highPad[dim]);
+    reifiedReturnShapes[0][dim] = builder.createOrFold<arith::AddIOp>(
+        loc, srcDimSize,
+        builder.createOrFold<arith::AddIOp>(loc, lowPadSize, highPadSize));
+  }
+  return success();
+}
+
 void PadOp::build(OpBuilder &b, OperationState &result, Value source,
                   ArrayRef<int64_t> staticLow, ArrayRef<int64_t> staticHigh,
                   ValueRange low, ValueRange high, bool nofold,