diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -146,6 +146,7 @@
       dimension, i.e `low`.
     * high: A list contains the padding along the end of each
       dimension, i.e. `high`.
+    * output: An optional output operand.
 
     The result tensor dimensions are `low` + `dim` + `high` along that
     dimension. The number of elements of `low` and `high` must match
@@ -194,16 +195,21 @@
     Variadic<Index>:$low,
     Variadic<Index>:$high,
     I64ArrayAttr:$static_low,
-    I64ArrayAttr:$static_high);
+    I64ArrayAttr:$static_high,
+    Optional<AnyRankedTensor>:$output);
 
   let regions = (region SizedRegion<1>:$region);
 
   let results = (outs AnyTensor:$result);
 
+  // TODO: Remove custom<InferType> when AllTypesMatch supports opt. operands.
   let assemblyFormat = [{
-    $source `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
+    $source
+    `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
     `high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
+    (`into` $output^ )?
     $region attr-dict `:` type($source) `to` type($result)
+    custom<InferType>(ref($output), type($output), ref(type($result)))
  }];
 
   let extraClassDeclaration = [{
@@ -292,7 +298,12 @@
     // result type. If the type passed is nullptr, it is inferred.
     OpBuilder<(ins "Type":$resultType, "Value":$source,
       "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a PadTensorOp with mixed static and dynamic entries and custom
+    // result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$source,
+      "ArrayRef<Value>":$low, "ArrayRef<Value>":$high, "ArrayAttr":$staticLow,
+      "ArrayAttr":$staticHigh)>
   ];
 
   let hasCanonicalizer = 1;
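Taken together, the tablegen changes add an optional `$output` operand, print it behind an `into` keyword, and recover its type from the result type via the `custom<InferType>` directive rather than spelling it out a second time. A minimal sketch of the resulting textual form, mirroring the static shapes used in the roundtrip test below (function and value names are illustrative, not taken from the patch):

```mlir
// Sketch: pad a 3x4 tensor into a pre-allocated 6x9 output tensor.
// Note that no type is printed after `into %output`; it is inferred
// from the result type following `to`.
func @pad_into_sketch(%source: tensor<3x4xf32>, %output: tensor<6x9xf32>,
                      %cst: f32) -> tensor<6x9xf32> {
  %padded = linalg.pad_tensor %source low[1, 2] high[2, 3] into %output {
  ^bb0(%i: index, %j: index):  // (row, col) position being padded
    linalg.yield %cst : f32
  } : tensor<3x4xf32> to tensor<6x9xf32>
  return %padded : tensor<6x9xf32>
}
```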
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -855,6 +855,19 @@
 // PadTensorOp
 //===----------------------------------------------------------------------===//
 
+// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
+// supports optional types.
+void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
+                    Type typeToInfer, Type typeToInferFrom) {}
+
+ParseResult parseInferType(OpAsmParser &parser,
+                           Optional<OpAsmParser::OperandType> optOperand,
+                           Type &typeToInfer, Type typeToInferFrom) {
+  if (optOperand)
+    typeToInfer = typeToInferFrom;
+  return success();
+}
+
 static LogicalResult verify(PadTensorOp op) {
   auto sourceType = op.source().getType().cast<RankedTensorType>();
   auto resultType = op.result().getType().cast<RankedTensorType>();
@@ -870,6 +883,9 @@
            << resultType << " does not match the inferred type "
           << expectedType;
   }
+  if (op.output() && op.output().getType() != op.getResultType()) {
+    return op.emitError("expected that output operand type equals result type");
+  }
 
   auto &region = op.region();
   unsigned rank = resultType.getRank();
@@ -916,7 +932,7 @@
   auto sourceType = source.getType().cast<RankedTensorType>();
   auto resultType = inferResultType(sourceType, staticLow, staticHigh);
   build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
-        b.getI64ArrayAttr(staticHigh));
+        b.getI64ArrayAttr(staticHigh), /*output=*/Value());
   result.addAttributes(attrs);
 }
 
@@ -953,7 +969,15 @@
         PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
   }
   build(b, result, resultType, source, dynamicLow, dynamicHigh,
-        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh));
+        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
+        /*output=*/Value());
+}
+
+void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                        Value source, ArrayRef<Value> low, ArrayRef<Value> high,
+                        ArrayAttr staticLow, ArrayAttr staticHigh) {
+  build(b, result, resultType, source, low, high, staticLow, staticHigh,
+        /*output=*/{});
 }
 
 PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
@@ -1038,11 +1062,25 @@
   }
 };
 
+// Fold tensor.dim(pad_tensor(%input, %output)) to tensor.dim(%output).
+struct FoldToDimOfOutputOperand : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto padTensorOp = dimOp.source().getDefiningOp<PadTensorOp>();
+    if (!padTensorOp || !padTensorOp.output())
+      return failure();
+    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, padTensorOp.output(),
+                                               dimOp.index());
+    return success();
+  }
+};
 } // namespace
 
 void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<FoldStaticZeroPadding>(context);
+  results.add<FoldStaticZeroPadding, FoldToDimOfOutputOperand>(context);
 }
 
 /// Return the padding value of the PadTensorOp if it constant. In this context,
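Because the verifier now guarantees that the output operand and the result have the same type, `FoldToDimOfOutputOperand` can redirect any `tensor.dim` query on the padded result to the output operand itself. A before/after sketch of the rewrite, mirroring the canonicalization test below (value and function names are illustrative):

```mlir
// Before canonicalization: the dim is taken from the pad_tensor result.
func @before(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
             %pad_value: f32) -> index {
  %c0 = constant 0 : index
  %0 = linalg.pad_tensor %arg0 low[2, 3] high[4, 5] into %arg1 {
  ^bb0(%i: index, %j: index):
    linalg.yield %pad_value : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  %r = tensor.dim %0, %c0 : tensor<?x?xf32>
  return %r : index
}

// After canonicalization: the dim is answered by the output operand, and
// the now-unused pad_tensor op becomes dead.
func @after(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
            %pad_value: f32) -> index {
  %c0 = constant 0 : index
  %r = tensor.dim %arg1, %c0 : tensor<?x?xf32>
  return %r : index
}
```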
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -902,3 +902,21 @@
   %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x?xf32> to tensor<2xf32>
   return %r: tensor<2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @dim_of_pad_tensor(
+//  CHECK-SAME:     %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>
+//       CHECK:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[RESULT:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+//       CHECK:   return %[[RESULT]]
+func @dim_of_pad_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                        %pad_value: f32) -> index {
+  %c0 = constant 0 : index
+  %0 = linalg.pad_tensor %arg0 low[2, 3] high[4, 5] into %arg1 {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %pad_value : f32
+  } : tensor<?x?xf32> to tensor<?x?xf32>
+  %r = tensor.dim %0, %c0 : tensor<?x?xf32>
+  return %r : index
+}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -584,6 +584,18 @@
 
 // -----
 
+// expected-note@+1 {{prior use here}}
+func @pad_output_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xi32> {
+  // expected-error @+1 {{use of value '%output' expects different type than prior uses: 'tensor<?x?x?x?xi32>' vs 'tensor<?x?x?x?xf32>'}}
+  %0 = linalg.pad_tensor %arg0 low[1, 1, 1, 1] high[2, 2, 2, 2] into %output {
+    ^bb0(%arg3: index, %arg4: index):  // no predecessors
+      linalg.yield %arg2 : i32
+  } : tensor<?x2x3x4xi32> to tensor<?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?xi32>
+}
+
+// -----
+
 func @pad_number_of_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
   // expected-error @+1 {{expected the block to have 2 arguments}}
   %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -51,6 +51,24 @@
 
 // -----
 
+func @pad_static_with_output(%arg0: tensor<3x4xf32>,
+                             %out_tensor : tensor<6x9xf32>,
+                             %pad_value: f32)
+    -> tensor<6x9xf32> {
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] into %out_tensor {
+    ^bb0(%arg1 : index, %arg2 : index):
+      linalg.yield %pad_value : f32
+  } : tensor<3x4xf32> to tensor<6x9xf32>
+  return %0 : tensor<6x9xf32>
+}
+// CHECK-LABEL: func @pad_static_with_output
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]*]]: tensor<3x4xf32>,
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]*]]: tensor<6x9xf32>,
+//       CHECK:   linalg.pad_tensor %[[ARG0]] low[1, 2] high[2, 3] into %[[ARG1]]
+//       CHECK:    : tensor<3x4xf32> to tensor<6x9xf32>
+
+// -----
+
 func @pad_asymmetrical(%arg0: tensor<2x3xf32>, %ub0: index, %ub1: index,
                        %pad_value: f32) -> tensor<?x?xf32> {
   %0 = linalg.pad_tensor %arg0 low[0, 0] high[%ub0, %ub1] {
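For reference, a hedged sketch of how the new builder overload added in LinalgOps.cpp might be called from C++; the surrounding variables (`b`, `loc`, `resultType`, `source`, the dynamic low/high values, and the static padding arrays) are assumed to exist at the call site and are not part of the patch:

```cpp
// Hypothetical call site. The new overload threads the static padding
// through as ArrayAttr and leaves the optional output operand unset.
Value padded = b.create<linalg::PadTensorOp>(
    loc, resultType, source, /*low=*/dynamicLow, /*high=*/dynamicHigh,
    b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh));
```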