diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -146,9 +146,9 @@
     * low: A list contains the padding along the start of each
           dimension, i.e `low`.
     * high: A list contains the padding along the end of each
-            dimension, i.e. `high`.
-    * packing: whether the padding operation is guaranteed to create a new
-      tensor suitable for packing, i.e. a copy.
+            dimension, i.e. `high`.
+    * nofold: indicates that the operation should not be folded when source and
+      result types are equal.
 
     The result tensor dimensions are `low` + `dim` + `high` along that
     dimension. The number of elements of `low` and `high` must match
@@ -161,10 +161,9 @@
     the rank of the `source` tensor. The value `yield`-ed by the region is
     used as the value of the view at the given position.
 
-    If `packing` is indicated, the padding is guaranteed to produce a new
-    tensor, e.g., to use for packing or promotion to faster memory. Such
-    operations are not optimized away even when the source type has the same
-    static shape.
+    If `nofold` is set, the padding operation will not be folded away even
+    if the source type and the padded type have the same static shape. This can
+    be used, e.g., for packing or promotion to faster memory.
 
     Example 1:
@@ -199,9 +198,9 @@
     Example 4:
 
    ```mlir
-      // Force a padded value to be always exist with `packing`.
+      // Force a padded value to always exist with `nofold`.
       %pad_value = ... : f32
-      %0 = linalg.pad_tensor %arg0 packing low[0, 0] high[0, 0] {
+      %0 = linalg.pad_tensor %arg0 nofold low[0, 0] high[0, 0] {
      ^bb0(%arg1: index, %arg2: index):
        linalg.yield %pad_value : f32
      } : tensor<2x3xf32> to tensor<2x3xf32>
@@ -214,7 +213,7 @@
                        Variadic<Index>:$high,
                        I64ArrayAttr:$static_low,
                        I64ArrayAttr:$static_high,
-                       UnitAttr:$packing);
+                       UnitAttr:$nofold);
 
   let regions = (region SizedRegion<1>:$region);
 
@@ -223,7 +222,7 @@
   // TODO: Remove custom<InferType> when AllTypesMatch supports opt. operands.
   let assemblyFormat = [{
     $source
-    (`packing` $packing^)?
+    (`nofold` $nofold^)?
     `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
     `high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
     $region attr-dict `:` type($source) `to` type($result)
@@ -260,7 +259,7 @@
     // "high" padding (i.e. it adds trailing padding values until the desired
     // size is met).
     static linalg::PadTensorOp createPadHighOp(
-        Type type, Value source, Value pad, bool packing, Location loc,
+        Type type, Value source, Value pad, bool nofold, Location loc,
         OpBuilder & builder);
 
     // Return a PadTensorOp that pads `source to `type` size with `pad` value.
@@ -268,7 +267,7 @@
     // directly. If the type passed is nullptr, it is inferred.
     static linalg::PadTensorOp createPadScalarOp(
         Type type, Value source, Value pad, ArrayRef<OpFoldResult> low,
-        ArrayRef<OpFoldResult> high, bool packing, Location loc,
+        ArrayRef<OpFoldResult> high, bool nofold, Location loc,
         OpBuilder & builder);
 
     // Return the pad value if it is a constant. Return null value otherwise.
@@ -313,17 +312,17 @@
     // Build a PadTensorOp with mixed static and dynamic entries.
     OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$staticLow,
       "ArrayRef<int64_t>":$staticHigh, "ValueRange":$low, "ValueRange":$high,
-      CArg<"bool", "false">:$packing,
+      CArg<"bool", "false">:$nofold,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a PadTensorOp with all dynamic entries.
     OpBuilder<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
-      CArg<"bool", "false">:$packing,
+      CArg<"bool", "false">:$nofold,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a PadTensorOp with mixed static and dynamic entries and custom
     // result type. If the type passed is nullptr, it is inferred.
     OpBuilder<(ins "Type":$resultType, "Value":$source,
       "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
-      CArg<"bool", "false">:$packing,
+      CArg<"bool", "false">:$nofold,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
   ];
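For contrast with Example 4 in the hunk above, here is a minimal sketch of the default behavior that `nofold` opts out of (the surrounding IR is hypothetical): when the attribute is absent and the source and result have the same static shape, `PadTensorOp::fold` replaces the op with its source.

```mlir
// Without `nofold`, this pad is a static no-op: source and result are both
// tensor<2x3xf32>, so folding replaces all uses of %0 with %arg0.
%pad_value = ... : f32
%0 = linalg.pad_tensor %arg0 low[0, 0] high[0, 0] {
^bb0(%arg1: index, %arg2: index):
  linalg.yield %pad_value : f32
} : tensor<2x3xf32> to tensor<2x3xf32>
```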
OpBuilder<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high, - CArg<"bool", "false">:$packing, + CArg<"bool", "false">:$nofold, CArg<"ArrayRef", "{}">:$attrs)>, // Build a PadTensorOp with mixed static and dynamic entries and custom // result type. If the type passed is nullptr, it is inferred. OpBuilder<(ins "Type":$resultType, "Value":$source, "ArrayRef":$low, "ArrayRef":$high, - CArg<"bool", "false">:$packing, + CArg<"bool", "false">:$nofold, CArg<"ArrayRef", "{}">:$attrs)>, ]; diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -87,7 +87,7 @@ return linalg::PadTensorOp::createPadScalarOp( RankedTensorType::get(paddedShape, inputETy), input, padValue, - lowIndices, highIndices, /*packing=*/false, loc, rewriter) + lowIndices, highIndices, /*nofold=*/false, loc, rewriter) .result(); } @@ -2349,7 +2349,7 @@ auto newPadOp = linalg::PadTensorOp::createPadScalarOp( padOp.getType(), input, constant, lowValues, highValues, - /*packing=*/false, loc, rewriter); + /*nofold=*/false, loc, rewriter); rewriter.replaceOp(padOp, newPadOp.getResult()); return success(); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1085,28 +1085,28 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef staticLow, ArrayRef staticHigh, ValueRange low, - ValueRange high, bool packing, + ValueRange high, bool nofold, ArrayRef attrs) { auto sourceType = source.getType().cast(); auto resultType = inferResultType(sourceType, staticLow, staticHigh); build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow), - b.getI64ArrayAttr(staticHigh), packing ? b.getUnitAttr() : UnitAttr()); + b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr()); result.addAttributes(attrs); } void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source, - ValueRange low, ValueRange high, bool packing, + ValueRange low, ValueRange high, bool nofold, ArrayRef attrs) { auto sourceType = source.getType().cast(); unsigned rank = sourceType.getRank(); SmallVector staticVector(rank, ShapedType::kDynamicSize); - build(b, result, source, staticVector, staticVector, low, high, packing, + build(b, result, source, staticVector, staticVector, low, high, nofold, attrs); } void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType, Value source, ArrayRef low, - ArrayRef high, bool packing, + ArrayRef high, bool nofold, ArrayRef attrs) { assert(resultType.isa()); auto sourceType = source.getType().cast(); @@ -1129,17 +1129,17 @@ } build(b, result, resultType, source, dynamicLow, dynamicHigh, b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh), - packing ? b.getUnitAttr() : UnitAttr()); + nofold ? 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1085,28 +1085,28 @@
 
 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
                         ArrayRef<int64_t> staticLow,
                         ArrayRef<int64_t> staticHigh, ValueRange low,
-                        ValueRange high, bool packing,
+                        ValueRange high, bool nofold,
                         ArrayRef<NamedAttribute> attrs) {
   auto sourceType = source.getType().cast<RankedTensorType>();
   auto resultType = inferResultType(sourceType, staticLow, staticHigh);
   build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
-        b.getI64ArrayAttr(staticHigh), packing ? b.getUnitAttr() : UnitAttr());
+        b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());
   result.addAttributes(attrs);
 }
 
 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
-                        ValueRange low, ValueRange high, bool packing,
+                        ValueRange low, ValueRange high, bool nofold,
                         ArrayRef<NamedAttribute> attrs) {
   auto sourceType = source.getType().cast<RankedTensorType>();
   unsigned rank = sourceType.getRank();
   SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
-  build(b, result, source, staticVector, staticVector, low, high, packing,
+  build(b, result, source, staticVector, staticVector, low, high, nofold,
         attrs);
 }
 
 void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
                         Value source, ArrayRef<OpFoldResult> low,
-                        ArrayRef<OpFoldResult> high, bool packing,
+                        ArrayRef<OpFoldResult> high, bool nofold,
                         ArrayRef<NamedAttribute> attrs) {
   assert(resultType.isa<RankedTensorType>());
   auto sourceType = source.getType().cast<RankedTensorType>();
@@ -1129,17 +1129,17 @@
   }
   build(b, result, resultType, source, dynamicLow, dynamicHigh,
         b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
-        packing ? b.getUnitAttr() : UnitAttr());
+        nofold ? b.getUnitAttr() : UnitAttr());
   result.addAttributes(attrs);
 }
 
 PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
                                            ArrayRef<OpFoldResult> low,
                                            ArrayRef<OpFoldResult> high,
-                                           bool packing, Location loc,
+                                           bool nofold, Location loc,
                                            OpBuilder &builder) {
-  auto padTensorOp = builder.create<PadTensorOp>(loc, type, source, low,
-                                                 high, packing);
+  auto padTensorOp =
+      builder.create<PadTensorOp>(loc, type, source, low, high, nofold);
   int rank = padTensorOp.getResultType().getRank();
   SmallVector<Type, 4> blockArgTypes;
   blockArgTypes.assign(rank, builder.getIndexType());
@@ -1153,7 +1153,7 @@
 }
 
 PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
-                                         bool packing, Location loc,
+                                         bool nofold, Location loc,
                                          OpBuilder &builder) {
   SmallVector<OpFoldResult> low, high;
   auto rankedTensorType = type.cast<RankedTensorType>();
@@ -1167,7 +1167,7 @@
     high.push_back(highValue);
     low.push_back(builder.createOrFold<ConstantIndexOp>(loc, 0));
   }
-  return PadTensorOp::createPadScalarOp(type, source, pad, low, high, packing,
+  return PadTensorOp::createPadScalarOp(type, source, pad, low, high, nofold,
                                         loc, builder);
 }
 
@@ -1440,8 +1440,8 @@
 }
 
 namespace {
-// Folds linalg.pad_tensor when padding is static zeros and packing is not
-// requested.
+// Folds linalg.pad_tensor when padding is static zeros and `nofold` is not
+// set.
 struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
 
@@ -1449,7 +1449,7 @@
                                 PatternRewriter &rewriter) const override {
     if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
       return failure();
-    if (padTensorOp.packing())
+    if (padTensorOp.nofold())
       return failure();
     rewriter.replaceOpWithNewOp<tensor::CastOp>(
         padTensorOp, padTensorOp.result().getType(), padTensorOp.source());
@@ -1481,7 +1481,7 @@
     auto newOp = rewriter.create<PadTensorOp>(
         padTensorOp->getLoc(), newResultType, padTensorOp.source(),
         padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
-        padTensorOp.static_high(), padTensorOp.packing());
+        padTensorOp.static_high(), padTensorOp.nofold());
     BlockAndValueMapping mapper;
     padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
 
@@ -1513,7 +1513,7 @@
         padTensorOp.getLoc(), tensorCastOp.dest().getType(),
         padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
         padTensorOp.static_low(), padTensorOp.static_high(),
-        padTensorOp.packing());
+        padTensorOp.nofold());
     replacementOp.region().takeBody(padTensorOp.region());
 
     rewriter.replaceOp(padTensorOp, replacementOp.result());
@@ -1555,7 +1555,7 @@
 
 OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
   if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
-      !packing())
+      !nofold())
     return source();
   return {};
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -182,7 +182,7 @@
       staticSizes, getElementTypeOrSelf(opOperand->get()));
   result = linalg::PadTensorOp::createPadHighOp(
       staticTensorType, opOperand->get(), paddingValue.getValue(),
-      /*packing=*/true, opToPad->getLoc(), rewriter);
+      /*nofold=*/true, opToPad->getLoc(), rewriter);
   return success();
 }
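Conversely, the tiling transform above passes `/*nofold=*/true` to guarantee a copy. A minimal sketch of that promotion-style use (the helper name and identifiers are hypothetical; `staticTensorType` must have a static shape, as `createPadHighOp` asserts):

```cpp
// Sketch: pad a tile up to a static shape. With /*nofold=*/true the
// resulting linalg.pad_tensor survives canonicalization even when the tile
// already has that shape, so a fresh tensor is always materialized, e.g.
// for packing or promotion to faster memory.
static Value padTile(RankedTensorType staticTensorType, Value tile,
                     Value padValue, Location loc, OpBuilder &builder) {
  return linalg::PadTensorOp::createPadHighOp(staticTensorType, tile, padValue,
                                              /*nofold=*/true, loc, builder)
      .result();
}
```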
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -630,14 +630,14 @@
 
 // -----
 
-// CHECK-LABEL: func @pad_tensor_packing_same_static_shape(
+// CHECK-LABEL: func @pad_tensor_nofold_same_static_shape(
 // CHECK-SAME:    %[[ARG0:.*]]: tensor<5x6xf32>
 // CHECK:         %[[PAD:.*]] = linalg.pad_tensor
 // CHECK:         return %[[PAD]]
-func @pad_tensor_packing_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_tensor_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
     -> tensor<5x6xf32> {
   %cst = constant 0.000000e+00 : f32
-  %0 = linalg.pad_tensor %arg0 packing low[%a, 0] high[0, %a] {
+  %0 = linalg.pad_tensor %arg0 nofold low[%a, 0] high[0, %a] {
  ^bb0(%arg1: index, %arg2: index):
    linalg.yield %cst : f32
  } : tensor<5x6xf32> to tensor<5x6xf32>
@@ -937,13 +937,13 @@
 
 // -----
 
-// CHECK-LABEL: func @pad_packing_static_zero(
+// CHECK-LABEL: func @pad_nofold_static_zero(
 // CHECK-SAME:    %[[ARG0:.*]]: tensor<?x?x?xf32>
 // CHECK:         %[[PAD:.*]] = linalg.pad_tensor
 // CHECK:         return %[[PAD]]
-func @pad_packing_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
+func @pad_nofold_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
   %c0 = constant 0 : index
-  %0 = linalg.pad_tensor %arg0 packing low[0, %c0, 0] high[0, 0, %c0] {
+  %0 = linalg.pad_tensor %arg0 nofold low[0, %c0, 0] high[0, 0, %c0] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index):
    linalg.yield %pad_value : f32
  } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
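The complementary case these tests pin down: without `nofold`, the same static-zero padding is rewritten by `FoldStaticZeroPadding` into a plain `tensor.cast`. A before/after sketch (not an exact test case):

```mlir
// Before canonicalization: zero padding in every dimension, no `nofold`.
%0 = linalg.pad_tensor %arg0 low[0, 0, 0] high[0, 0, 0] {
^bb0(%arg1: index, %arg2: index, %arg3: index):
  linalg.yield %pad_value : f32
} : tensor<?x?x?xf32> to tensor<2x3x4xf32>

// After canonicalization: only a cast to the static type remains.
%0 = tensor.cast %arg0 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
```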
diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -20,11 +20,11 @@
 // CHECK-NOT: linalg.matmul {{.*}} tensor<?x?xi8>
 
 // Padding injects static information.
-// CHECK: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+// CHECK: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 // CHECK:   : tensor<?x?xi8> to tensor<2x4xi8>
-// CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+// CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 // CHECK:   : tensor<?x?xi8> to tensor<4x3xi8>
-// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 // CHECK:   : tensor<?x?xi32> to tensor<2x3xi32>
 // CHECK: %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>)
 // CHECK-SAME:   outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32>
@@ -55,7 +55,7 @@
 // CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
 
 // Padding injects static information.
-// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}, %{{.*}}]
+// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] nofold low[%[[C0]], %[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}, %{{.*}}]
 // CHECK:   : tensor<?x?x?xf32> to tensor<2x3x4xf32>
 // CHECK: %[[pD:.*]] = linalg.generic
 // CHECK-SAME:   ins(%[[VAL]] : f32) outs(%[[pC]] : tensor<2x3x4xf32>)
@@ -108,9 +108,9 @@
 // CHECK-1DIM-TILE: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x8xi8> to tensor<?x8xi8>
 // CHECK-1DIM-TILE: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<8x?xi8> to tensor<8x?xi8>
 // CHECK-1DIM-TILE: %[[sTC:.*]] = tensor.extract_slice %[[TC1]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
-// CHECK-1DIM-TILE: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 // CHECK-1DIM-TILE:   : tensor<?x8xi8> to tensor<2x8xi8>
-// CHECK-1DIM-TILE: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 // CHECK-1DIM-TILE:   : tensor<8x?xi8> to tensor<8x3xi8>
 // CHECK-1DIM-TILE: %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)
 // CHECK-1DIM-TILE:   outs(%[[sTC]] : tensor<?x?xi32>) -> tensor<?x?xi32>
@@ -122,7 +122,7 @@
 func @pad_to_same_static_size(%arg0: tensor<2x3x4xf32>, %arg1: f32) -> tensor<2x3x4xf32> {
   // CHECK: %[[c0:.*]] = constant 0 : index
   // CHECK-NOT: scf.for
-  // CHECK: linalg.pad_tensor %{{.*}} packing low[%[[c0]], %[[c0]], %[[c0]]] high[%[[c0]], %[[c0]], %[[c0]]]
+  // CHECK: linalg.pad_tensor %{{.*}} nofold low[%[[c0]], %[[c0]], %[[c0]]] high[%[[c0]], %[[c0]], %[[c0]]]
   // CHECK:   tensor<2x3x4xf32> to tensor<2x3x4xf32>
   %0 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1, d2) -> ()>,
@@ -140,7 +140,7 @@
 func @pad_static_divisible_size(%arg0: tensor<4x6x8xf32>, %arg1: f32) -> tensor<4x6x8xf32> {
   // CHECK: %[[c0:.*]] = constant 0 : index
   // CHECK-COUNT-3: scf.for
-  // CHECK: linalg.pad_tensor %{{.*}} packing low[%[[c0]], %[[c0]], %[[c0]]] high[%[[c0]], %[[c0]], %[[c0]]]
+  // CHECK: linalg.pad_tensor %{{.*}} nofold low[%[[c0]], %[[c0]], %[[c0]]] high[%[[c0]], %[[c0]], %[[c0]]]
   // CHECK:   tensor<2x3x4xf32> to tensor<2x3x4xf32>
   %0 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1, d2) -> ()>,