diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -380,6 +380,36 @@ let hasFolder = 1; } +def Linalg_SimplePadOp : Linalg_Op<"simple_pad", [NoSideEffect]> { + let summary = "TODO: replace with pad_tensors when ready."; + + let description = [{ + `linalg.simple_pad` is a tmp placeholder for padding and packing on tensors. + Its semantics are to pad a partially dynamic tensor to a fully static tensor + where the static sizes are assumed to be greater than or equal to the + dynamic sizes. The op performs "high" padding (i.e. it adds trailing padding + values until the desired size is met). + }]; + + let arguments = (ins AnyRankedTensor:$tensor, AnyType:$padding); + let results = (outs AnyRankedTensor:$result); + + // TODO: verify all static result, some dynamic input, static shapes match, + // element types match, ranks match etc. Use pad_tensors when ready.
+ + let extraClassDeclaration = [{ + RankedTensorType getSourceType() { + return tensor().getType().cast(); } + RankedTensorType getResultType() { + return getResult().getType().cast(); } + }]; + + let assemblyFormat = [{ + $tensor `pad` $padding attr-dict `:` + type($tensor) `to` type($result) `pad` type($padding) + }]; +} + def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>, Arguments<(ins Variadic:$values)> { let summary = "Linalg yield operation"; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -345,6 +345,9 @@ using TileSizeComputationFunction = std::function(OpBuilder &, Operation *)>; +using PaddingValueComputationFunction = + std::function; + struct LinalgTilingOptions { /// Computation function that returns the tile sizes for each operation. /// Delayed construction of constant tile sizes should occur to interoperate @@ -393,6 +396,18 @@ distribution = std::move(distributionOptions); return *this; } + + /// Computation function that returns a padding value to use when padding to + /// force static sizes. When `paddingValueComputationFunction` is set, padding + /// operations are introduced, that guarantee the underlying op is statically + /// shaped and can thus be vectorized. + PaddingValueComputationFunction paddingValueComputationFunction = nullptr; + + LinalgTilingOptions & + setPaddingValueComputationFunction(PaddingValueComputationFunction fun) { + paddingValueComputationFunction = std::move(fun); + return *this; + } }; /// Canonicalization patterns relevant to apply after tiling patterns. These are @@ -403,6 +418,11 @@ void populateLinalgTilingCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx); +/// Base pattern that applies the tiling transformation specified by `options`.
+/// Abort and return failure in 2 cases: +/// 1. if the tiling specification is invalid and tiling fails to occur. +/// 2. if tiling occurs but `options.paddingValueComputationFunction` is set +/// and some operand shape cannot be bounded statically. struct LinalgBaseTilingPattern : public RewritePattern { // Entry point to match any LinalgOp OpInterface. LinalgBaseTilingPattern(LinalgTilingOptions options, diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -14,6 +14,7 @@ #define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -108,6 +108,28 @@ return $_op.sizes(); }] >, + InterfaceMethod< + /*desc=*/[{ + Return a vector of all the static or dynamic sizes of the op. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getMixedSizes", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + SmallVector res; + std::array ranks = $_op.getArrayAttrRanks(); + unsigned numDynamic = 0; + unsigned count = ranks[getSizeOperandGroupPosition()]; + for (unsigned idx = 0; idx < count; ++idx) { + if (isDynamicSize(idx)) + res.push_back($_op.sizes()[numDynamic++]); + else + res.push_back($_op.static_sizes()[idx]); + } + return res; + }] + >, InterfaceMethod< /*desc=*/[{ Return the dynamic stride operands.
@@ -359,6 +381,9 @@ ]; let extraClassDeclaration = [{ + static unsigned getOffsetOperandGroupPosition() { return 0; } + static unsigned getSizeOperandGroupPosition() { return 1; } + static unsigned getStrideOperandGroupPosition() { return 2; } static StringRef getStaticOffsetsAttrName() { return "static_offsets"; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -25,6 +25,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include @@ -105,6 +106,118 @@ return *this; } +/// Try to compute a static bounding box for `operand`. +/// Return success if either: +/// 1. The operand is already statically shaped, `result` is left unchanged. +/// 2. The operand is (partially) dynamic, `result` is the result of a freshly +/// created SimplePadOp. +/// Return failure if the operand cannot be padded to a static shape. +static LogicalResult padOperandToSmallestStaticBoundingBox( + PatternRewriter &rewriter, linalg::LinalgOp opToPad, Value operand, + const LinalgTilingOptions &options, Value &result) { + auto tensorType = operand.getType().cast(); + // Already static shape, no need to pad. + if (tensorType.hasStaticShape()) + return success(); + auto subtensor = operand.getDefiningOp(); + // Not a subtensor, cannot construct a static bounding box. + if (!subtensor) + return rewriter.notifyMatchFailure(opToPad, "operand is not a subtensor"); + SmallVector staticSizes; + staticSizes.reserve(tensorType.getRank()); + auto shapedOp = + cast(subtensor.getOperation()); + for (auto size : shapedOp.getMixedSizes()) { + auto indexAttr = size.is() + ? size.get().dyn_cast() + : linalg::getSmallestBoundingIndex(size.get()); + // SmallestBoundingIndex must exist for all sizes.
+ // For now return an error if we can't find it. + if (!indexAttr) + return rewriter.notifyMatchFailure( + opToPad, "No constant bounding box can be found for padding"); + staticSizes.push_back(indexAttr.getInt()); + } + Value pad = options.paddingValueComputationFunction(rewriter, opToPad); + auto staticTensorType = + RankedTensorType::get(staticSizes, tensorType.getElementType()); + result = rewriter.create(opToPad->getLoc(), + staticTensorType, operand, pad); + return success(); +} + +// Try to create a static bounding box around each operand of `res.op`. +// If successful, `res.op` is rewritten in static form with padded operands. +// `res.op` is updated to the cloned static form of the op on success. +static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter, + TiledLinalgOp &res, + const LinalgTilingOptions &options) { + LinalgOp opToPad = res.op; + Location loc = opToPad->getLoc(); + + // If the op is fully static, it does not need padding. + // TODO: there are cases where we may still want to pad to larger sizes. + if (llvm::all_of(opToPad.getShapedOperands(), [](Value v) { + return v.getType().cast().hasStaticShape(); + })) + return success(); + + OpBuilder::InsertionGuard g(rewriter); + // Set IP after op because we also take the dims of the original output. + rewriter.setInsertionPointAfter(opToPad); + // Make a copy of the shaped operands and update it. + SmallVector operands = opToPad.getShapedOperands(); + for (Value &v : operands) { + Value paddedOperand; + // If padding was requested but the shape cannot be bounded statically then + // the pattern fails to apply. + if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, v, + options, paddedOperand))) { + return failure(); + } + // Update v if we indeed got a padded operand. + v = paddedOperand ? paddedOperand : v; + } + + // Clone `opToPad` on padded shapes; copy types before `operands` can grow.
+ auto resultTensorTypes = llvm::to_vector<4>( + ValueRange(operands).take_back(opToPad.getNumOutputs()).getTypes()); + ValueRange otherOperands = opToPad.getAssumedNonShapedOperands(); + operands.append(otherOperands.begin(), otherOperands.end()); + linalg::LinalgOp paddedOp = + opToPad.clone(rewriter, loc, resultTensorTypes, operands); + + // Recover the subtensor out of the new static results. This keeps the + // original linalg op around because it uses the dims of the original results. + // This later folds away. + SmallVector paddedSubviewResults; + paddedSubviewResults.reserve(opToPad->getNumResults()); + Value zero = rewriter.create(loc, 0); + Value one = rewriter.create(loc, 1); + llvm::SetVector newUsersOfOpToPad; + for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) { + auto rank = std::get<0>(it).getType().cast().getRank(); + SmallVector offsets(rank, zero); + auto sizes = llvm::to_vector<4>( + llvm::map_range(llvm::seq(0, rank), [&](unsigned d) -> Value { + auto dimOp = rewriter.create(loc, std::get<0>(it), d); + newUsersOfOpToPad.insert(dimOp); + return dimOp; + })); + SmallVector strides(rank, one); + paddedSubviewResults.push_back(rewriter.create( + loc, std::get<1>(it), offsets, sizes, strides)); + } + // Replace the transient `opToPad` locally, except for uses that we just + // created for the purpose of extracting the dims. + rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) { + return !newUsersOfOpToPad.contains(opOp.getOwner()); + }); + + res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults}; + return success(); +} + /// Linalg base tiling pattern. mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( StringRef opName, MLIRContext *context, LinalgTilingOptions options, @@ -130,11 +243,34 @@ if (!res) return failure(); - // Return relevant information to derived pattern. - result = *res; + // Setup RAII guard to return properly.
+ bool succeeded = true; + LinalgOp tiledOp = res->op; + auto guard = llvm::make_scope_exit([&]() { + if (!succeeded) + return; + // Return relevant information to derived pattern. + result = *res; + // Replace marker on both tiledOp and tiledAndPaddedOp, if necessary. + marker.replaceLinalgMarker(rewriter, tiledOp); + if (tiledOp != res->op) + marker.replaceLinalgMarker(rewriter, res->op); + }); + + // Consider padding on the fly only if the op has tensor semantics. + if (!options.paddingValueComputationFunction || + !linalgOp.hasTensorSemantics()) + return success(); + + // Try to pad on the fly by rewriting res->op as a padded op. + if (failed(rewriteAsPaddedOp(rewriter, *res, options))) { + // Set so RAII guard does not propagate TiledLinalgOp to `result`. + succeeded = false; + return failure(); + } - // New marker if specified. - marker.replaceLinalgMarker(rewriter, res->op.getOperation()); + // Do not perform replacement of `linalgOp`, let the derived patterns + // do this as they see fit, from the resulting TiledLinalgOp. return success(); } diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1411,13 +1411,20 @@ return Value{*dynExtents}; } + // The size at the given index is now known to be a dynamic size. + unsigned unsignedIndex = index.getValue().getZExtValue(); + + if (auto subtensor = dyn_cast_or_null(definingOp)) { + assert(subtensor.isDynamicSize(unsignedIndex) && + "Expected dynamic subtensor size"); + return subtensor.getDynamicSize(unsignedIndex); + } + // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. auto memrefType = argTy.dyn_cast(); if (!memrefType) return {}; - // The size at the given index is now known to be a dynamic size of a memref.
- unsigned unsignedIndex = index.getValue().getZExtValue(); if (auto alloc = dyn_cast_or_null(definingOp)) return *(alloc.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(unsignedIndex)); diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -753,3 +753,13 @@ // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)> // CHECK: func @legal_collapsing_reshape_dynamic_memref // CHECK: linalg.reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]] + +// ----- + +// TODO: this op should disappear once pad_tensors is available. +// CHECK-LABEL: func @simple_pad +func @simple_pad(%0: tensor, %pad: f32) { +// CHECK: linalg.simple_pad %{{.+}} pad %{{.+}}: tensor to tensor<8x4x8xf32> + %1 = linalg.simple_pad %0 pad %pad: tensor to tensor<8x4x8xf32> pad f32 + return +} diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-and-pad-pattern -canonicalize | FileCheck %s + +// CHECK-LABEL: func @matmul_tensors( +// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor +// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor +// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor) -> tensor { +func @matmul_tensors( + %arg0: tensor, %arg1: tensor, %arg2: tensor) + -> tensor { +// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor) { +// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor) { +// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor) { +// CHECK: %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor to tensor +// CHECK: %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor to tensor +// 
CHECK: %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor to tensor + +// Dynamic op has been canonicalized away. +// CHECK-NOT: linalg.matmul {{.*}} tensor + +// Padding injects static information. +// CHECK: %[[pA:.*]] = linalg.simple_pad %[[sTA]] pad %{{.*}} : tensor to tensor<2x4xf32> pad f32 +// CHECK: %[[pB:.*]] = linalg.simple_pad %[[sTB]] pad %{{.*}} : tensor to tensor<4x3xf32> pad f32 +// CHECK: %[[pC:.*]] = linalg.simple_pad %[[sTC]] pad %{{.*}} : tensor to tensor<2x3xf32> pad f32 +// CHECK: %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x4xf32>, tensor<4x3xf32>) +// CHECK-SAME: outs(%[[pC]] : tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK: %[[sTD:.*]] = subtensor %[[pD]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<2x3xf32> to tensor +// CHECK: %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}] : tensor into tensor +// CHECK: scf.yield %[[TD]] : tensor +// CHECK: scf.yield %[[TD2]] : tensor +// CHECK: scf.yield %[[TD1]] : tensor + %0 = linalg.matmul {__internal_linalg_transform__ = "tile-and-pad"} + ins(%arg0, %arg1: tensor, tensor) + outs(%arg2: tensor) + -> tensor + +// CHECK: return %[[TD0]] : tensor + return %0 : tensor +} diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -79,6 +79,9 @@ *this, "test-affine-min-scf-canonicalization-patterns", llvm::cl::desc("Test affine-min + scf canonicalization patterns."), llvm::cl::init(false)}; + Option testTileAndPadPattern{ + *this, "test-tile-and-pad-pattern", + llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)}; }; } // end anonymous namespace @@ -487,6 +490,34 @@ applyOpPatternsAndFold(minOp, frozenPatterns); }); } + +// For now, just assume it is the zero of type. +// In the future, it should be the zero of type + op.
+static Value getNeutralOfLinalgOp(OpBuilder &b, Operation *op) { + auto t = op->getResult(0).getType().cast().getElementType(); + return b.create(op->getLoc(), t, b.getZeroAttr(t)); +} + +static void applyTileAndPadPattern(FuncOp funcOp) { + MLIRContext *context = funcOp.getContext(); + OwningRewritePatternList tilingPattern; + auto linalgTilingOptions = + linalg::LinalgTilingOptions() + .setTileSizes({2, 3, 4}) + .setPaddingValueComputationFunction(getNeutralOfLinalgOp); + tilingPattern.insert>( + context, linalgTilingOptions, + linalg::LinalgMarker(Identifier::get("tile-and-pad", context))); + FrozenRewritePatternList frozenPatterns(std::move(tilingPattern)); + applyPatternsAndFoldGreedily(funcOp, frozenPatterns); + + // Reuse the frozen patterns (`tilingPattern` is moved-from) to fold the + // AffineMinOps locally and avoid more general folding on the rest of the IR. + funcOp.walk([&frozenPatterns](AffineMinOp minOp) { + applyOpPatternsAndFold(minOp, frozenPatterns); + }); +} + /// Apply transformations specified as patterns. void TestLinalgTransforms::runOnFunction() { auto lambda = [&](void *) { @@ -520,6 +551,8 @@ return applyLinalgToVectorPatterns(getFunction()); if (testAffineMinSCFCanonicalizationPatterns) return applyAffineMinSCFCanonicalizationPatterns(getFunction()); + if (testTileAndPadPattern) + return applyTileAndPadPattern(getFunction()); } namespace mlir {