diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -376,8 +376,10 @@
 using TileSizeComputationFunction =
     std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
 
+/// Specify the padding value for an OpOperand. This should be a function of
+/// both the operation and the operand type.
 using PaddingValueComputationFunction =
-    std::function<Value(OpBuilder &, Operation *)>;
+    std::function<Value(OpBuilder &, OpOperand &)>;
 
 struct LinalgTilingOptions {
   /// Computation function that returns the tile sizes for each operation.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1373,10 +1373,13 @@
     return verifyYield(op, cast<LinalgOp>(parentOp));
 
   if (auto padTensorOp = dyn_cast<linalg::PadTensorOp>(parentOp)) {
-    return success(
-        op.getNumOperands() == 1 &&
-        op.getOperand(0).getType() ==
-            padTensorOp.getType().cast<ShapedType>().getElementType());
+    if (op.getNumOperands() != 1)
+      return op.emitOpError("expected single yield operand (got ")
+             << op->getNumOperands() << ")";
+    if (op.getOperand(0).getType() !=
+        padTensorOp.getType().cast<ShapedType>().getElementType())
+      return op.emitOpError("expected yield type to match shape element type");
+    return success();
   }
 
   return op.emitOpError("expected parent op with LinalgOp interface");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -127,13 +127,13 @@
 /// created PadTensorOp.
 /// Return failure if the operand cannot be padded to a static shape.
 static LogicalResult padOperandToSmallestStaticBoundingBox(
-    PatternRewriter &rewriter, linalg::LinalgOp opToPad, Value operand,
+    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand &operand,
     const LinalgTilingOptions &options, Value &result) {
-  auto tensorType = operand.getType().cast<RankedTensorType>();
+  auto tensorType = operand.get().getType().cast<RankedTensorType>();
   // Already static shape, no need to pad.
   if (tensorType.hasStaticShape())
     return success();
-  auto subtensor = operand.getDefiningOp<SubTensorOp>();
+  auto subtensor = operand.get().getDefiningOp<SubTensorOp>();
   // Not a subtensor, cannot construct a static bounding box.
   if (!subtensor)
     return failure();
@@ -152,11 +152,11 @@
         opToPad, "No constant bounding box can be found for padding");
     staticSizes.push_back(indexAttr.getInt());
   }
-  Value pad = options.paddingValueComputationFunction(rewriter, opToPad);
+  Value pad = options.paddingValueComputationFunction(rewriter, operand);
   auto staticTensorType =
       RankedTensorType::get(staticSizes, tensorType.getElementType());
-  result = linalg::PadTensorOp::createPadHighOp(staticTensorType, operand, pad,
-                                                opToPad->getLoc(), rewriter);
+  result = linalg::PadTensorOp::createPadHighOp(
+      staticTensorType, operand.get(), pad, opToPad->getLoc(), rewriter);
   return success();
 }
 
@@ -180,26 +180,26 @@
   // Set IP after op because we also take the dims of the original output.
   rewriter.setInsertionPointAfter(opToPad);
   // Make a copy of the shaped operands and update it.
-  SmallVector<Value, 4> operands = opToPad.getShapedOperands();
-  for (Value &v : operands) {
+  SmallVector<Value> newOperands;
+  newOperands.reserve(opToPad.getNumShapedOperands());
+  for (OpOperand &operand : opToPad.getShapedOpOperands()) {
     Value paddedOperand;
     // If padding was requested but the shape cannot be bounded statically then
     // the pattern fails to apply.
-    if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, v,
+    if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, operand,
                                                      options, paddedOperand))) {
       return failure();
     }
-    // Update v if we indeed got a padded operand.
-    v = paddedOperand ? paddedOperand : v;
+    newOperands.push_back(paddedOperand ? paddedOperand : operand.get());
   }
 
   // Clone `opToPad` to operate on the statically padded shapes.
   auto resultTensorTypes =
-      ValueRange(operands).take_back(opToPad.getNumOutputs()).getTypes();
+      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
   ValueRange otherOperands = opToPad.getAssumedNonShapedOperands();
-  operands.append(otherOperands.begin(), otherOperands.end());
+  newOperands.append(otherOperands.begin(), otherOperands.end());
   linalg::LinalgOp paddedOp =
-      opToPad.clone(rewriter, loc, resultTensorTypes, operands);
+      opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
 
   // Recover the subtensor out of the new static results. This keeps the
   // original linalg op around because it uses the dims of the original results.
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -646,6 +646,28 @@
 
 // -----
 
+func @pad_num_yields(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
+  // expected-error @+3 {{op expected single yield operand (got 2)}}
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+    ^bb0(%arg2: index, %arg3: index):  // no predecessors
+      linalg.yield %arg1, %arg1 : i32, i32
+  } : tensor<?x4xi32> to tensor<?x9xi32>
+  return %0 : tensor<?x9xi32>
+}
+
+// -----
+
+func @pad_yield_type(%arg0: tensor<?x4xi32>, %arg1: i8) -> tensor<?x9xi32> {
+  // expected-error @+3 {{op expected yield type to match shape element type}}
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+    ^bb0(%arg2: index, %arg3: index):  // no predecessors
+      linalg.yield %arg1 : i8
+  } : tensor<?x4xi32> to tensor<?x9xi32>
+  return %0 : tensor<?x9xi32>
+}
+
+// -----
+
 func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32)
 {
   %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -1,41 +1,41 @@
 // RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-and-pad-pattern -canonicalize | FileCheck %s
 
 // CHECK-LABEL: func @matmul_tensors(
-// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
-// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
-// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
+// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xi8>
+// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
 func @matmul_tensors(
-  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
-    -> tensor<?x?xf32> {
-// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xf32>) {
-// CHECK:   %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK:   %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK:   %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+  %arg0: tensor<?x?xi8>, %arg1: tensor<?x?xi8>, %arg2: tensor<?x?xi32>)
+    -> tensor<?x?xi32> {
+// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xi32>) {
+// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xi32>) {
+// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xi32>) {
+// CHECK:   %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x?xi8> to tensor<?x?xi8>
+// CHECK:   %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<?x?xi8> to tensor<?x?xi8>
+// CHECK:   %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
 
 // Dynamic op has been canonicalized away.
-// CHECK-NOT: linalg.matmul {{.*}} tensor<?x?xf32>
+// CHECK-NOT: linalg.matmul {{.*}} tensor<?x?xi8>
 
 // Padding injects static information.
 // CHECK: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
-// CHECK:   : tensor<?x?xf32> to tensor<2x4xf32>
+// CHECK:   : tensor<?x?xi8> to tensor<2x4xi8>
 // CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
-// CHECK:   : tensor<?x?xf32> to tensor<4x3xf32>
+// CHECK:   : tensor<?x?xi8> to tensor<4x3xi8>
 // CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
-// CHECK:   : tensor<?x?xf32> to tensor<2x3xf32>
-// CHECK: %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x4xf32>, tensor<4x3xf32>)
-// CHECK-SAME:                       outs(%[[pC]] : tensor<2x3xf32>) -> tensor<2x3xf32>
-// CHECK: %[[sTD:.*]] = subtensor %[[pD]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<2x3xf32> to tensor<?x?xf32>
-// CHECK: %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
-// CHECK: scf.yield %[[TD]] : tensor<?x?xf32>
-// CHECK: scf.yield %[[TD2]] : tensor<?x?xf32>
-// CHECK: scf.yield %[[TD1]] : tensor<?x?xf32>
-  %0 = linalg.matmul {__internal_linalg_transform__ = "tile-and-pad"}
-      ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
-     outs(%arg2: tensor<?x?xf32>)
-    -> tensor<?x?xf32>
+// CHECK:   : tensor<?x?xi32> to tensor<2x3xi32>
+// CHECK: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>)
+// CHECK-SAME:                                 outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32>
+// CHECK: %[[sTD:.*]] = subtensor %[[pD]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<2x3xi32> to tensor<?x?xi32>
+// CHECK: %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xi32> into tensor<?x?xi32>
+// CHECK: scf.yield %[[TD]] : tensor<?x?xi32>
+// CHECK: scf.yield %[[TD2]] : tensor<?x?xi32>
+// CHECK: scf.yield %[[TD1]] : tensor<?x?xi32>
+  %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile-and-pad"}
+      ins(%arg0, %arg1: tensor<?x?xi8>, tensor<?x?xi8>)
+     outs(%arg2: tensor<?x?xi32>)
+    -> tensor<?x?xi32>
 
-// CHECK: return %[[TD0]] : tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
+// CHECK: return %[[TD0]] : tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
 }
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -508,9 +508,9 @@
 
 // For now, just assume it is the zero of type.
 // In the future, it should be the zero of type + op.
-static Value getNeutralOfLinalgOp(OpBuilder &b, Operation *op) {
-  auto t = op->getResult(0).getType().cast<ShapedType>().getElementType();
-  return b.create<ConstantOp>(op->getLoc(), t, b.getZeroAttr(t));
+static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
+  auto t = getElementTypeOrSelf(op.get().getType());
+  return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
 }
 
 static void applyTileAndPadPattern(FuncOp funcOp) {
@@ -520,7 +520,7 @@
       linalg::LinalgTilingOptions()
           .setTileSizes({2, 3, 4})
          .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
-  tilingPattern.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>>(
+  tilingPattern.insert<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>>(
       context, linalgTilingOptions,
       linalg::LinalgTransformationFilter(
           Identifier::get("tile-and-pad", context)));
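
Usage sketch (not part of the patch): with the new signature, the padding value callback receives the OpOperand being padded rather than the whole operation, so a client can choose the value per operand. The snippet below is a hypothetical client callback; the helper name getPaddingValueForOperand is illustrative, while getElementTypeOrSelf, OpOperand::get, OpOperand::getOwner, and LinalgTilingOptions::setPaddingValueComputationFunction are the APIs already exercised above.

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/Dialect/StandardOps/IR/Ops.h"
  #include "mlir/IR/TypeUtilities.h"

  using namespace mlir;

  // Hypothetical per-operand padding value callback.
  static Value getPaddingValueForOperand(OpBuilder &b, OpOperand &operand) {
    // Element type of this particular operand, e.g. i8 for the inputs of
    // linalg.matmul_i8_i8_i32 and i32 for its output.
    Type t = getElementTypeOrSelf(operand.get().getType());
    Location loc = operand.getOwner()->getLoc();
    // Zero is used as the neutral value here; a client could also switch on
    // operand.getOperandNumber() to pad different operands differently.
    return b.create<ConstantOp>(loc, t, b.getZeroAttr(t));
  }

  // Registration mirrors applyTileAndPadPattern above:
  //   linalg::LinalgTilingOptions()
  //       .setTileSizes({2, 3, 4})
  //       .setPaddingValueComputationFunction(getPaddingValueForOperand);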