diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -32,6 +32,127 @@
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
+def Linalg_PadOp : Linalg_Op<"padded_view",
+    [DeclareOpInterfaceMethods<ViewLikeOpInterface>,
+     SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = "memref padded view operation";
+  let description = [{
+    The pad operation provides a view of the underlying buffer with
+    padding. For example
+
+    ```mlir
+    %pad_value = ... : f32
+    %1 = linalg.padded_view %0[1, 2] [2, 3] {
+      ^bb0(%arg0 : index, %arg1 : index):
+        linalg.yield %pad_value : f32
+    } : memref<?x?xf32> to memref<?x?xf32>
+    ```
+
+    The first `[ ]` contains the padding along the start of each
+    dimension, i.e. `low_padding`. The second `[ ]` contains the padding
+    along the end of each dimension, i.e. `high_padding`. The result
+    memref dimensions are `low_padding` + `dim` + `high_padding` along
+    that dimension. The number of elements of low_padding and
+    high_padding must match the rank of the input memref (which is
+    also the rank of the output memref).
+
+    The region of the `padded_view` operation returns the value to use
+    for the padding. The arguments of the region represent the index
+    of the source being accessed. There should be as many arguments as
+    the rank of the `source` memref. The value `yield`-ed by the
+    region is used as the value of the view at the given position.
+
+    Note that the operation is returning a "view" of the underlying
+    buffer and not a newly allocated buffer. So using the padded_view
+    as the result buffer of an operation results in the update of the
+    `source` memref.
+
+    When you read from the view, if you are not in the padding, you get
+    the value from the source buffer. If you are in the padding you get
+    the padded value specified using the region of the operation.
+
+    When you write to the view, if you are not in the padding, you
+    write to the source buffer. If you are in the padding the write is
+    ignored.
+
+    A valid lowering of `padded_view` is to fold it with its
+    consumers. For example,
+
+    ```mlir
+    %pad_value = ... : f32
+    %1 = linalg.padded_view %0[1, 2] [2, 3] {
+      ^bb0(%arg0 : index, %arg1 : index):
+        linalg.yield %pad_value : f32
+    } : memref<?x?xf32> to memref<?x?xf32>
+    scf.for %iv0 = ... {
+      scf.for %iv1 = ... {
+        %2 = load %1[%iv0, %iv1] : memref<?x?xf32>
+      }
+    }
+    ```
+
+    can be folded to
+
+    ```mlir
+    %pad_value = .... : f32
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %c2 = constant 2 : index
+    %d0 = dim %0, %c0
+    %d1 = dim %0, %c1
+    scf.for %iv0 = ... {
+      scf.for %iv1 = ... {
+        %a = subi %iv0, %c1 : index
+        %b = subi %iv1, %c2 : index
+        %cond1_y = cmpi "slt", %a, %d0 : index
+        %cond2_y = cmpi "sge", %a, %c0 : index
+        %cond_y = and %cond1_y, %cond2_y : i1
+        %cond1_x = cmpi "slt", %b, %d1 : index
+        %cond2_x = cmpi "sge", %b, %c0 : index
+        %cond_x = and %cond1_x, %cond2_x : i1
+        %cond = and %cond_x, %cond_y : i1
+        %2 = scf.if %cond -> (f32) {
+          %t = load %0[%a, %b] : memref<?x?xf32>
+          scf.yield %t : f32
+        } else {
+          scf.yield %pad_value : f32
+        }
+      }
+    }
+    ```
+
+    Another valid lowering is to create an allocation for the result
+    of the pad and copy the `source` memref into a subview of the
+    allocation.
+
+    ```
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %c2 = constant 2 : index
+    %c3 = constant 3 : index
+    %d0 = dim %0, %c0
+    %d1 = dim %0, %c1
+    %t1 = addi %c1, %d0 : index
+    %t2 = addi %t1, %c2 : index
+    %t3 = addi %c2, %d1 : index
+    %t4 = addi %t3, %c3 : index
+    %1 = alloc(%t2, %t4) : memref<?x?xf32>
+    %2 = subview %1[%c1, %c2][%d0, %d1][1, 1] : memref<?x?xf32> into memref<?x?xf32>
+    linalg.copy(%0, %2)
+    ```
+
+    For the result, the copy would have to be done after all its uses.
+ }]; + + let arguments = (ins + AnyMemRef:$source, + I64ElementsAttr:$lowPadding, + I64ElementsAttr:$highPadding); + + let regions = (region AnyRegion:$region); + + let results = (outs AnyMemRef:$result); +} + def Linalg_RangeOp : Linalg_Op<"range", [NoSideEffect]>, Arguments<(ins Index:$min, Index:$max, Index:$step)>, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -450,6 +450,21 @@ return success(); } +template <> +LogicalResult BlockArgsVerifier::verify(PadOp op, Block &block) { + unsigned rank = op.getType().cast().getRank(); + if (block.getNumArguments() != rank) + return op.emitError("expected the block to have ") << rank << " arguments"; + + // Note: the number and type of yield values are checked in the YieldOp. + for (auto en : llvm::enumerate(block.getArgumentTypes())) { + if (!en.value().isIndex()) + return op.emitOpError("expected block argument ") + << (en.index() + 1) << " to be an index"; + } + return success(); +} + template struct AnnotationsVerifier { static LogicalResult verify(GenericOpType op) { return success(); } @@ -550,6 +565,137 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); } +//===----------------------------------------------------------------------===// +// PadOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(PadOp op) { + ShapedType sourceType = op.source().getType().cast(); + // Check that the rank of the padding attributes is 1. 
+  ShapedType lowPaddingType = op.lowPadding().getType();
+  ShapedType highPaddingType = op.highPadding().getType();
+  if (lowPaddingType.getRank() != 1)
+    return op.emitError("expected lowPadding to be of rank 1");
+  if (highPaddingType.getRank() != 1)
+    return op.emitError("expected highPadding to be of rank 1");
+  // One padding pair is required per source dimension.
+  if (sourceType.getRank() != lowPaddingType.getShape()[0])
+    return op.emitError("expected size of lowPadding to match rank of source");
+  if (sourceType.getRank() != highPaddingType.getShape()[0])
+    return op.emitError("expected size of highPadding to match rank of source");
+
+  ShapedType resultType = op.result().getType().cast<MemRefType>();
+  if (sourceType.getRank() != resultType.getRank())
+    return op.emitError("expected result rank to match the rank of the source");
+
+  // Check that the shape matches. For static shapes, the output shape must be
+  // lowPadding + dim + highPadding in every dimension.
+  auto lowPadding = op.lowPadding().getValues<int64_t>();
+  auto highPadding = op.highPadding().getValues<int64_t>();
+  // Every padding amount must be non-negative: all_of, not any_of, since a
+  // single negative entry makes the op invalid.
+  auto isNonNegative = [](int64_t v) { return v >= 0; };
+  if (!llvm::all_of(lowPadding, isNonNegative))
+    return op.emitError("low padding values cannot be negative");
+  if (!llvm::all_of(highPadding, isNonNegative))
+    return op.emitError("high padding values cannot be negative");
+
+  auto sourceShape = sourceType.getShape();
+  auto resultShape = resultType.getShape();
+  for (auto en : llvm::enumerate(llvm::zip(lowPadding, highPadding))) {
+    int dim = en.index();
+    // Dynamic dimensions cannot be checked statically.
+    if (sourceShape[dim] == ShapedType::kDynamicSize ||
+        resultShape[dim] == ShapedType::kDynamicSize)
+      continue;
+    int64_t low = std::get<0>(en.value());
+    int64_t high = std::get<1>(en.value());
+    int64_t expectedShape = sourceShape[dim] + low + high;
+    if (expectedShape != resultShape[dim]) {
+      return op.emitError("expected output shape to be (")
+             << low << " + " << sourceShape[dim] << " + " << high
+             << ") = " << expectedShape;
+    }
+  }
+
+  // The padding value is produced by a single-block region (see the
+  // block-argument checks in BlockArgsVerifier<PadOp>).
+  auto &region = op.region();
+  if (!llvm::hasSingleElement(region))
+    return op.emitOpError("expected region with 1 block");
+  if (failed(BlockArgsVerifier<PadOp>::verify(op, region.front())))
+    return failure();
+
+  return success();
+}
+
+Value PadOp::getViewSource() { return source(); }
+
+static ParseResult parseListOfOperands(OpAsmParser &parser,
+                                       OperationState &result,
+                                       StringRef attrName) {
+  if (failed(parser.parseLSquare()))
+    return failure();
+  if (succeeded(parser.parseOptionalRSquare()))
+    return success();
+
+  SmallVector<int64_t, 4> attrVals;
+  while (true) {
+    OpAsmParser::OperandType operand;
+    IntegerAttr attr;
+    if (failed(parser.parseAttribute(attr)))
+      return parser.emitError(parser.getNameLoc()) << "expected integer";
+    attrVals.push_back(attr.getInt());
+    if (succeeded(parser.parseOptionalComma()))
+      continue;
+    if (failed(parser.parseRSquare()))
+      return failure();
+    break;
+  }
+  auto arrayAttr = parser.getBuilder().getI64TensorAttr(attrVals);
+  result.addAttribute(attrName, arrayAttr);
+  return success();
+}
+
+static ParseResult parsePadOp(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::OperandType baseInfo;
+  SmallVector<OpAsmParser::OperandType, 8> operands;
+  SmallVector<Type, 8> types;
+  if (parser.parseOperand(baseInfo))
+    return failure();
+
+  if (parseListOfOperands(parser, result, "lowPadding") ||
+      parseListOfOperands(parser, result, "highPadding"))
+    return failure();
+
+  SmallVector<OpAsmParser::OperandType, 4> regionOperands;
+  std::unique_ptr<Region> region = std::make_unique<Region>();
+  SmallVector<Type, 4> operandTypes, regionTypes;
+  if (parser.parseRegion(*region, regionOperands, regionTypes))
+    return failure();
+  result.addRegion(std::move(region));
+
+  Type srcType, dstType;
+  if (parser.parseColonType(srcType) || parser.parseKeywordType("to", dstType))
+    return failure();
+
+  if (parser.resolveOperand(baseInfo, srcType, result.operands))
+    return failure();
+
+  return parser.addTypeToList(dstType, result.types);
+}
+
+static void printListOfOperands(OpAsmPrinter &p, DenseIntElementsAttr attr) {
+  p << '[';
+  llvm::interleaveComma(attr, p, [&](APInt a) { p << a.getSExtValue();
}); + p << ']'; +} + +static void print(OpAsmPrinter &p, PadOp op) { + p << op->getName().getStringRef() << ' '; + p << op.source(); + printListOfOperands(p, op.lowPadding()); + p << ' '; + printListOfOperands(p, op.highPadding()); + p.printRegion(op.region()); + p << " : " << op.source().getType() << " to " << op.getType(); +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// @@ -1146,6 +1292,12 @@ if (auto linalgOp = dyn_cast(parentOp)) return verifyYield(op, cast(parentOp)); + if (auto padOp = dyn_cast(parentOp)) { + return success(op.getNumOperands() == 1 && + op.getOperand(0).getType() == + padOp.getType().cast().getElementType()); + } + return op.emitOpError("expected parent op with LinalgOp interface"); } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -430,3 +430,58 @@ -> tensor return } + +// ----- + +func @pad_mismatch_rank(%arg0: memref, %arg1: f32) -> memref { + // expected-error @+1 {{expected size of lowPadding to match rank of source}} + %0 = linalg.padded_view %arg0[1, 2] [2, 2, 3] { + ^bb0(%arg2: index, %arg3: index): // no predecessors + linalg.yield %arg1 : f32 + } : memref to memref + return %0 : memref +} + +// ----- + +func @pad_mismatch_rank(%arg0: memref, %arg1: f32) -> memref { + // expected-error @+1 {{expected size of highPadding to match rank of source}} + %0 = linalg.padded_view %arg0[2, 1, 2] [2, 3] { + ^bb0(%arg2: index, %arg3: index): // no predecessors + linalg.yield %arg1 : f32 + } : memref to memref + return %0 : memref +} + +// ----- + +func @pad_result_type(%arg0: memref, %arg1: i32) -> memref { + // expected-error @+1 {{expected output shape to be (2 + 4 + 3) = 9}} + %0 = linalg.padded_view %arg0[1, 2] [2, 3] { + ^bb0(%arg2: index, %arg3: index): // no predecessors + 
linalg.yield %arg1 : i32 + } : memref to memref + return %0 : memref +} + +// ----- + +func @pad_block_args(%arg0: memref, %arg1: i32) -> memref { + // expected-error @+1 {{op expected block argument 1 to be an index}} + %0 = linalg.padded_view %arg0[1, 2] [2, 3] { + ^bb0(%arg2: i32, %arg3: i32): // no predecessors + linalg.yield %arg1 : i32 + } : memref to memref + return %0 : memref +} + +// ----- + +func @pad_mismatch_rank(%arg0: memref, %arg1: i32) -> memref { + // expected-error @+1 {{expected result rank to match the rank of the source}} + %0 = linalg.padded_view %arg0[1, 2] [2, 3] { + ^bb0(%arg2: index, %arg3: index): // no predecessors + linalg.yield %arg1 : i32 + } : memref to memref + return %0 : memref +} diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -6,6 +6,32 @@ // Test that we can lower all the way to LLVM without crashing, don't check results here. // DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1 +func @pad_dynamic(%arg0: memref, %pad_value: f32) -> memref { + %0 = linalg.padded_view %arg0[1, 2] [2, 3] { + ^bb0(%arg1 : index, %arg2 : index): + linalg.yield %pad_value : f32 + } : memref to memref + return %0 : memref +} +// CHECK-LABEL: func @pad_dynamic +// CHECK: linalg.padded_view %{{.*}}[1, 2] [2, 3] +// CHECK: : memref to memref + +// ----- + +func @pad_static(%arg0: memref<3x4xf32>, %pad_value: f32) -> memref<6x9xf32> { + %0 = linalg.padded_view %arg0[1, 2] [2, 3] { + ^bb0(%arg1 : index, %arg2 : index): + linalg.yield %pad_value : f32 + } : memref<3x4xf32> to memref<6x9xf32> + return %0 : memref<6x9xf32> +} +// CHECK-LABEL: func @pad_static +// CHECK: linalg.padded_view %{{.*}}[1, 2] [2, 3] +// CHECK: : memref<3x4xf32> to memref<6x9xf32> + +// ----- + func @range(%arg0: index, %arg1: index, %arg2: index) { %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range return