diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -224,6 +224,13 @@
     // size is met).
     static linalg::PadTensorOp createPadHighOp(
         Type type, Value source, Value pad, Location loc, OpBuilder & builder);
+
+    // Return a PadTensorOp that pads `source` by the `low` and `high` amounts
+    // with the scalar `pad` value: a block is created whose body directly
+    // yields `pad`. If `type` is nullptr, the result type is inferred.
+    static linalg::PadTensorOp createPadScalarOp(
+        Type type, Value source, Value pad, ArrayRef<Value> low,
+        ArrayRef<Value> high, Location loc, OpBuilder & builder);
   }];
 
   let builders = [
@@ -234,7 +241,7 @@
     // Build a PadTensorOp with all dynamic entries.
     OpBuilderDAG<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
-    // Build a PadTensorOp with with mixed static and dynamic entries and custom
+    // Build a PadTensorOp with mixed static and dynamic entries and custom
     // result type. If the type passed is nullptr, it is inferred.
     OpBuilderDAG<(ins "Type":$resultType, "Value":$source,
       "ArrayRef<Value>":$low, "ArrayRef<Value>":$high,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -780,6 +780,24 @@
         b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh));
 }
 
+PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
+                                           ArrayRef<Value> low,
+                                           ArrayRef<Value> high,
+                                           Location loc, OpBuilder &builder) {
+  auto padTensorOp =
+      builder.create<linalg::PadTensorOp>(loc, type, source, low, high);
+  int rank = padTensorOp.getResultType().getRank();
+  SmallVector<Type, 4> blockArgTypes;
+  blockArgTypes.assign(rank, builder.getIndexType());
+  auto &region = padTensorOp.region();
+  // `builder.createBlock` changes the insertion point within the block. Create
+  // a guard to reset the insertion point of the builder after it is destroyed.
+  OpBuilder::InsertionGuard guard(builder);
+  builder.createBlock(&region, region.end(), blockArgTypes);
+  builder.create<linalg::YieldOp>(loc, pad);
+  return padTensorOp;
+}
+
 PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
                                          Location loc, OpBuilder &builder) {
   SmallVector<Value, 4> low, high;
@@ -794,17 +812,8 @@
     high.push_back(highValue);
     low.push_back(builder.createOrFold<ConstantIndexOp>(loc, 0));
   }
-  auto padTensorOp =
-      builder.create<linalg::PadTensorOp>(loc, type, source, low, high);
-  SmallVector<Type, 4> blockArgTypes;
-  blockArgTypes.assign(rank, builder.getIndexType());
-  auto &region = padTensorOp.region();
-  // `builder.createBlock` changes the insertion point within the block. Create
-  // a guard to reset the insertion point of the builder after it is destroyed.
-  OpBuilder::InsertionGuard guard(builder);
-  builder.createBlock(&region, region.end(), blockArgTypes);
-  builder.create<linalg::YieldOp>(loc, pad);
-  return padTensorOp;
+  return PadTensorOp::createPadScalarOp(type, source, pad, low, high, loc,
+                                        builder);
 }
 
 //===----------------------------------------------------------------------===//