diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1298,6 +1298,12 @@
       "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
       CArg<"bool", "false">:$nofold,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a PadOp with constant padding, mixed static and dynamic entries
+    // and custom result type. If the type passed is nullptr, it is inferred.
+    OpBuilder<(ins "Type":$resultType, "Value":$source,
+      "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
+      "Value":$constantPadValue, CArg<"bool", "false">:$nofold,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
 
   let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2407,7 +2407,6 @@
                   Value source, ArrayRef<OpFoldResult> low,
                   ArrayRef<OpFoldResult> high, bool nofold,
                   ArrayRef<NamedAttribute> attrs) {
-  assert(resultType.isa<RankedTensorType>());
   auto sourceType = source.getType().cast<RankedTensorType>();
   SmallVector<Value, 4> dynamicLow, dynamicHigh;
   SmallVector<int64_t, 4> staticLow, staticHigh;
@@ -2422,12 +2421,37 @@
   if (!resultType) {
     resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
   }
+  // Validate only after inference: callers may pass a null `resultType`.
+  assert(resultType.isa<RankedTensorType>());
   build(b, result, resultType, source, dynamicLow, dynamicHigh,
         b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
         nofold ? b.getUnitAttr() : UnitAttr());
   result.addAttributes(attrs);
 }
 
+/// Builds a PadOp whose region yields `constantPadValue`. `low` and `high`
+/// may mix static and dynamic entries; a null `resultType` is inferred by
+/// the builder this overload delegates to.
+void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                  Value source, ArrayRef<OpFoldResult> low,
+                  ArrayRef<OpFoldResult> high, Value constantPadValue,
+                  bool nofold, ArrayRef<NamedAttribute> attrs) {
+  build(b, result, resultType, source, low, high, nofold, attrs);
+
+  // Populate the op's region with a single block that yields the pad value.
+  // The block takes one index argument per source dimension.
+  Region *region = result.regions[0].get();
+  int64_t sourceRank = source.getType().cast<RankedTensorType>().getRank();
+  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
+  SmallVector<Location> blockArgLocs(sourceRank, result.location);
+
+  // `createBlock` moves the builder's insertion point into the new block; the
+  // guard restores the caller's insertion point when this function returns.
+  OpBuilder::InsertionGuard guard(b);
+  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
+  b.create<tensor::YieldOp>(result.location, constantPadValue);
+}
+
 llvm::SmallBitVector PadOp::getPaddedDims() {
   llvm::SmallBitVector paddedDims(getSourceType().getRank());
   auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {