diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -117,6 +117,75 @@
   let hasCanonicalizer = 1;
 }
 
+def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
+    [SameVariadicOperandSize, SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = "tensor pad operation";
+  let description = [{
+    `linalg.pad_tensor` is an operation that pads the `source` tensor
+    with the given `low` and `high` padding configuration.
+
+    Example 1:
+
+    ```mlir
+      %pad_value = ... : f32
+      %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+      ^bb0(%arg1 : index, %arg2 : index):
+        linalg.yield %pad_value : f32
+      } : tensor<?x?xf32> to tensor<?x?xf32>
+    ```
+
+    Example 2:
+
+    ```mlir
+      %pad_value = ... : f32
+      %0 = linalg.pad_tensor %arg0 low[2, %arg1, 3, 3] high[3, 3, %arg1, 2] {
+      ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index):
+        linalg.yield %pad_value : f32
+      } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+    ```
+
+    The first list contains the padding along the start of each
+    dimension, i.e. `low`. The second list contains the padding along
+    the end of each dimension, i.e. `high`. The result tensor size along
+    each dimension is `low` + `dim` + `high`, where `dim` is the source
+    size along that dimension. The number of elements of `low` and
+    `high` must match the rank of the input tensor (which is also the
+    rank of the output tensor). Each padding amount can be either a
+    constant or a dynamic value.
+
+    The region of the `pad_tensor` operation returns the value to use
+    for the padding. The arguments of the region represent the index
+    of the source being accessed. There should be as many arguments as
+    the rank of the `source` tensor. The value `yield`-ed by the
+    region is used as the value of the view at the given position.
+  }];
+
+  let arguments = (ins
+    AnyTensor:$source,
+    Variadic<Index>:$low,
+    Variadic<Index>:$high,
+    I64ArrayAttr:$static_low,
+    I64ArrayAttr:$static_high);
+
+  let regions = (region AnyRegion:$region);
+
+  let results = (outs AnyTensor:$result);
+
+  let extraClassDeclaration = [{
+    static StringRef getStaticLowAttrName() {
+      return "static_low";
+    }
+
+    static StringRef getStaticHighAttrName() {
+      return "static_high";
+    }
+
+    // Infer the shape of the result tensor given the source type and the
+    // static low/high padding amounts.
+    static RankedTensorType inferResultType(RankedTensorType sourceType,
+                                            ArrayRef<int64_t> staticLow,
+                                            ArrayRef<int64_t> staticHigh);
+  }];
+}
+
 def Linalg_RangeOp :
     Linalg_Op<"range", [NoSideEffect]>,
     Arguments<(ins Index:$min, Index:$max, Index:$step)>,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -915,6 +915,124 @@
       ReplaceStaticShapeDims>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// PadTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
+static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+  return llvm::to_vector<4>(
+      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+        return a.cast<IntegerAttr>().getInt();
+      }));
+}
+
+static LogicalResult verify(PadTensorOp op) {
+  auto sourceType = op.source().getType().cast<RankedTensorType>();
+  auto resultType = op.result().getType().cast<RankedTensorType>();
+  auto expectedType = PadTensorOp::inferResultType(
+      sourceType, extractFromI64ArrayAttr(op.static_low()),
+      extractFromI64ArrayAttr(op.static_high()));
+  if (resultType != expectedType) {
+    return op.emitError("specified type ")
+           << resultType << " does not match the inferred type "
+           << expectedType;
+  }
+
+  auto &region = op.region();
+  if (!llvm::hasSingleElement(region))
+    return op.emitOpError("expected region with 1 block");
+  unsigned rank = resultType.getRank();
+  Block &block = region.front();
+  if (block.getNumArguments() != rank)
+    return op.emitError("expected the block to have ") << rank << " arguments";
+
+  // Note: the number and type of yield values are checked in the YieldOp.
+  for (auto en : llvm::enumerate(block.getArgumentTypes())) {
+    if (!en.value().isIndex())
+      return op.emitOpError("expected block argument ")
+             << (en.index() + 1) << " to be an index";
+  }
+
+  return success();
+}
+
+RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
+                                              ArrayRef<int64_t> staticLow,
+                                              ArrayRef<int64_t> staticHigh) {
+  unsigned rank = sourceType.getRank();
+  assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
+  assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
+
+  SmallVector<int64_t, 4> resultShape;
+  for (auto i : llvm::seq<unsigned>(0, rank)) {
+    if (sourceType.isDynamicDim(i) ||
+        staticLow[i] == ShapedType::kDynamicSize ||
+        staticHigh[i] == ShapedType::kDynamicSize) {
+      resultShape.push_back(ShapedType::kDynamicSize);
+    } else {
+      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
+      resultShape.push_back(size);
+    }
+  }
+
+  return RankedTensorType::get(resultShape, sourceType.getElementType());
+}
+
+static ParseResult parsePadTensorOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  OpAsmParser::OperandType baseInfo;
+  SmallVector<OpAsmParser::OperandType, 8> operands;
+  SmallVector<Type, 8> types;
+  if (parser.parseOperand(baseInfo))
+    return failure();
+
+  IndexType indexType = parser.getBuilder().getIndexType();
+  SmallVector<OpAsmParser::OperandType, 4> lowPadding, highPadding;
+  if (parser.parseKeyword("low") ||
+      parseListOfOperandsOrIntegers(parser, result,
+                                    PadTensorOp::getStaticLowAttrName(),
+                                    ShapedType::kDynamicSize, lowPadding))
+    return failure();
+  if (parser.parseKeyword("high") ||
+      parseListOfOperandsOrIntegers(parser, result,
+                                    PadTensorOp::getStaticHighAttrName(),
+                                    ShapedType::kDynamicSize, highPadding))
+    return failure();
+
+  SmallVector<OpAsmParser::OperandType, 8> regionOperands;
+  std::unique_ptr<Region> region = std::make_unique<Region>();
+  SmallVector<Type, 8> operandTypes, regionTypes;
+  if (parser.parseRegion(*region, regionOperands, regionTypes))
+    return failure();
+  result.addRegion(std::move(region));
+
+  Type srcType, dstType;
+  if (parser.parseColonType(srcType) || parser.parseKeywordType("to", dstType))
+    return failure();
+
+  if (parser.addTypeToList(dstType, result.types))
+    return failure();
+
+  return failure(
+      parser.resolveOperand(baseInfo, srcType, result.operands) ||
+      parser.resolveOperands(lowPadding, indexType, result.operands) ||
+      parser.resolveOperands(highPadding, indexType, result.operands));
+}
+
+static void print(OpAsmPrinter &p, PadTensorOp op) {
+  p << op->getName().getStringRef() << ' ';
+  p << op.source();
+  p << " low";
+  printListOfOperandsOrIntegers(p, op.low(), op.static_low(),
+                                ShapedType::isDynamic);
+  p << " high";
+  printListOfOperandsOrIntegers(p, op.high(), op.static_high(),
+                                ShapedType::isDynamic);
+  p.printRegion(op.region());
+  p << " : " << op.source().getType() << " to " << op.getType();
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
@@ -1557,6 +1675,13 @@
   if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
     return verifyYield(op, cast<LinalgOp>(parentOp));
 
+  if (auto padTensorOp = dyn_cast<PadTensorOp>(parentOp)) {
+    return success(
+        op.getNumOperands() == 1 &&
+        op.getOperand(0).getType() ==
+            padTensorOp.getType().cast<ShapedType>().getElementType());
+  }
+
   return op.emitOpError("expected parent op with LinalgOp interface");
 }
 
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -617,3 +617,25 @@
      memref into memref
   return %0 : memref
 }
+
+// -----
+
+func @pad_result_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32) -> tensor<?x?x?x8xf32> {
+  // expected-error @+1 {{specified type 'tensor<?x?x?x8xf32>' does not match the inferred type 'tensor<?x?x?x9xi32>'}}
+  %0 = linalg.pad_tensor %arg0 low[1, %arg1, 2, 2] high[1, 2, %arg1, 3] {
+    ^bb0(%arg3: index, %arg4: index):  // no predecessors
+      linalg.yield %arg2 : i32
+  } : tensor<?x2x3x4xi32> to tensor<?x?x?x8xf32>
+  return %0 : tensor<?x?x?x8xf32>
+}
+
+// -----
+
+func @pad_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
+  // expected-error @+1 {{op expected block argument 1 to be an index}}
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+    ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
+      linalg.yield %arg1 : i32
+  } : tensor<?x4xi32> to tensor<?x9xi32>
+  return %0 : tensor<?x9xi32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -5,6 +5,39 @@
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
 
+func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
+                  %pad_value: f32) -> tensor<6x?x?x?xf32> {
+  %0 = linalg.pad_tensor %arg0 low[2, %low, 3, 3] high[3, 3, %high, 2] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+      linalg.yield %pad_value : f32
+    } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+  return %0 : tensor<6x?x?x?xf32>
+}
+// CHECK-LABEL: func @pad_dynamic
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[LOW:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[HIGH:[a-zA-Z0-9_]*]]
+// CHECK:         linalg.pad_tensor %[[ARG0]]
+// CHECK-SAME:      low[2, %[[LOW]], 3, 3]
+// CHECK-SAME:      high[3, 3, %[[HIGH]], 2]
+// CHECK:           : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+
+// -----
+
+func @pad_static(%arg0: tensor<3x4xf32>, %pad_value: f32) -> tensor<6x9xf32> {
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+    ^bb0(%arg1 : index, %arg2 : index):
+      linalg.yield %pad_value : f32
+    } : tensor<3x4xf32> to tensor<6x9xf32>
+  return %0 : tensor<6x9xf32>
+}
+// CHECK-LABEL: func @pad_static
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK:         linalg.pad_tensor %[[ARG0]] low[1, 2] high[2, 3]
+// CHECK:           : tensor<3x4xf32> to tensor<6x9xf32>
+
+// -----
+
 func @range(%arg0: index, %arg1: index, %arg2: index) {
   %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
   return
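
As a worked illustration of the shape rule stated in the op documentation (result size along each dimension is `low` + `dim` + `high`), here is a small usage sketch that mirrors the `pad_static` round-trip test above. The function and value names are illustrative only and are not part of the patch.

```mlir
// Hypothetical usage sketch: pads a 3x4 tensor.
// Dim 0: 1 (low) + 3 (source) + 2 (high) = 6
// Dim 1: 2 (low) + 4 (source) + 3 (high) = 9
func @pad_example(%arg0: tensor<3x4xf32>, %cst: f32) -> tensor<6x9xf32> {
  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
  ^bb0(%i: index, %j: index):  // indices of the padded position being filled
    linalg.yield %cst : f32    // every padded element takes this value
  } : tensor<3x4xf32> to tensor<6x9xf32>
  return %0 : tensor<6x9xf32>
}
```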
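
The `inferResultType` helper marks a result dimension as dynamic whenever the source dimension or either padding amount along it is dynamic; otherwise it folds the static sizes. A sketch of what that means at the IR level, again with illustrative names:

```mlir
// Hypothetical sketch: dim 0 stays static because the source size and both
// paddings are static (2 + 4 + 1 = 7); dim 1 becomes dynamic because the
// low padding %l is only known at runtime.
func @pad_mixed(%arg0: tensor<4x5xf32>, %l: index, %cst: f32) -> tensor<7x?xf32> {
  %0 = linalg.pad_tensor %arg0 low[2, %l] high[1, 3] {
  ^bb0(%i: index, %j: index):
    linalg.yield %cst : f32
  } : tensor<4x5xf32> to tensor<7x?xf32>
  return %0 : tensor<7x?xf32>
}
```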
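
The extension to the `linalg.yield` verifier only accepts a single yielded operand whose type equals the padded tensor's element type. A hypothetical fragment that the new check rejects (names illustrative, not taken from the patch's tests):

```mlir
// Rejected: the yielded value is i32 while the result element type is f32.
func @pad_bad_yield(%arg0: tensor<3x4xf32>, %c: i32) -> tensor<6x9xf32> {
  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
  ^bb0(%i: index, %j: index):
    linalg.yield %c : i32
  } : tensor<3x4xf32> to tensor<6x9xf32>
  return %0 : tensor<6x9xf32>
}
```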