diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -396,23 +396,31 @@ def TileOp : Op, - FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> { + DeclareOpInterfaceMethods]> { let description = [{ - Indicates that the given `target` op should be tiled with the options - provided as attributes. This transform generates a loop nest with a smaller - ("tiled") target operation in its body. Currently limited to LinalgOps. - - `sizes` are the tile sizes. A tile size of `0` indicates that the - respective dimension should not be tiled. No loop will be generated for such - dimensions. If all tile sizes are `0`, this transform is effectively a - no-op. + Indicates that the given `target` op should be tiled with the given sizes. + This transform generates a loop nest with a smaller ("tiled") target + operation in its body. Currently limited to LinalgOps. + + Tile sizes may be known at transformation time, in which case they are + expected to be provided in the `static_sizes` attribute, or not, in which + case the tile value must be computed by the payload IR and the handle to the + operation computing it must be provided through `dynamic_sizes`. When the + sizes are not known statically, the corresponding entry in the + `static_sizes` attribute must be set to `ShapedType::kDynamicSize`. Only + the dynamic sizes must be provided in `dynamic_sizes`, i.e., there should + be as many handles as `ShapedType::kDynamicSize` values in the + `static_sizes` attribute. A static size of `0` indicates that the dimension + should not be tiled. No loop will be generated for such dimensions. If all + tile sizes are `0`, this transform is effectively a no-op. 
This op returns handles to the tiled op (in the generated loop nest) and the generated loops. The number of loops is the number of non-zero tile sizes. }]; let arguments = (ins PDL_Operation:$target, - DefaultValuedAttr:$sizes, + Variadic:$dynamic_sizes, + DefaultValuedAttr:$static_sizes, DefaultValuedAttr:$interchange); let results = (outs PDL_Operation:$tiled_linalg_op, Variadic:$loops); diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -210,6 +210,11 @@ LinalgOp op, ValueRange operands, ValueRange results); +/// Turns an OpFoldResult into a value, creating an index-typed constant if +/// necessary. +Value materializeOpFoldResult(ImplicitLocOpBuilder &builder, + OpFoldResult opFoldResult); + /// Creates an extract_slice/subview op for a single `valueToTile` with /// `builder`. This new operation extracts a tile of `valueToTile`, starting /// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck` diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/PDL/IR/PDL.h" @@ -103,16 +104,10 @@ /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. 
static LogicalResult -applyTilingToAll(Operation *transformOp, Value target, - ArrayRef tileSizes, +applyTilingToAll(Operation *transformOp, ArrayRef payloadOps, + unsigned numLoops, transform::TransformResults &transformResults, - transform::TransformState &state, function_ref(LinalgOp)> applyFn) { - // Number of loops: Number of tiles sizes that are not zero. - size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0); - // All payload ops. These should all be LinalgOps for now. - ArrayRef payloadOps = state.getPayloadOps(target); - SmallVector tiledLinalgOps; SmallVector> loopOps(numLoops); for (unsigned int i = 0; i < numLoops; ++i) @@ -178,8 +173,9 @@ fusionOptions.tileInterchange = extractI64Array(getTileInterchange()); LogicalResult result = applyTilingToAll( - getOperation(), getTarget(), fusionOptions.tileSizes, transformResults, - state, [&](LinalgOp linalgOp) -> FailureOr { + getOperation(), state.getPayloadOps(getTarget()), + fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0), + transformResults, [&](LinalgOp linalgOp) -> FailureOr { LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); SimpleRewriter rewriter(getContext()); rewriter.setInsertionPoint(linalgOp); @@ -194,8 +190,7 @@ tileLoopNest->getLoopOps().end()}; return tiledLinalgOp; }); - return failed(result) ? 
DiagnosedSilenceableFailure::definiteFailure() - : DiagnosedSilenceableFailure::success(); + return DiagnosedSilenceableFailure(result); } ParseResult transform::FuseOp::parse(OpAsmParser &parser, @@ -603,32 +598,124 @@ transform::TileOp::apply(TransformResults &transformResults, TransformState &state) { LinalgTilingOptions tilingOptions; - SmallVector tileSizes = extractI64Array(getSizes()); + SmallVector tileSizes = extractI64Array(getStaticSizes()); + + ArrayRef targets = state.getPayloadOps(getTarget()); + SmallVector> dynamicSizeProducers; + dynamicSizeProducers.reserve(getDynamicSizes().size()); + for (Value dynamicSizeProducerHandle : getDynamicSizes()) { + dynamicSizeProducers.push_back( + state.getPayloadOps(dynamicSizeProducerHandle)); + + if (dynamicSizeProducers.back().size() != targets.size()) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() + << "expected as many dynamic size-producing operations (" + << dynamicSizeProducers.back().size() << ") as target ops (" + << targets.size() << ")"; + diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; + return diag; + } - if (!tileSizes.empty()) - tilingOptions.setTileSizes(tileSizes); - tilingOptions.setInterchange(extractUIntArray(getInterchange())); - LinalgTilingPattern pattern(getContext(), tilingOptions); + for (Operation *op : dynamicSizeProducers.back()) { + if (op->getNumResults() == 1 && + op->getResult(0).getType().isa()) + continue; + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "expected sizes to be produced by ops " + "with a single index-type result"; + diag.attachNote(op->getLoc()) << "size producer op"; + diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; + return diag; + } + } - LogicalResult result = applyTilingToAll( - getOperation(), getTarget(), tileSizes, transformResults, state, - [&](LinalgOp linalgOp) { - SimpleRewriter rewriter(linalgOp.getContext()); - return pattern.returningMatchAndRewrite(linalgOp, 
rewriter); + unsigned numLoops = tileSizes.size() - llvm::count(tileSizes, 0); + SmallVector tiled; + SmallVector, 4> loops; + loops.resize(numLoops); + for (auto &en : llvm::enumerate(targets)) { + auto linalgOp = dyn_cast(en.value()); + if (!linalgOp) { + DiagnosedSilenceableFailure diag = emitSilenceableError() + << "only linalg ops are supported"; + diag.attachNote(en.value()->getLoc()) << "target op"; + return diag; + } + + unsigned index = en.index(); + if (!tileSizes.empty()) { + tilingOptions.setTileSizeComputationFunction([&, index](OpBuilder &b, + Operation *) { + SmallVector sizes; + sizes.reserve(tileSizes.size()); + unsigned dynamicIdx = 0; + for (int64_t size : tileSizes) { + if (size == ShapedType::kDynamicSize) { + sizes.push_back( + dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); + } else { + sizes.push_back(b.create(getLoc(), size)); + } + } + return sizes; }); - return DiagnosedSilenceableFailure(result); + } + + tilingOptions.setInterchange(extractUIntArray(getInterchange())); + LinalgTilingPattern pattern(getContext(), tilingOptions); + SimpleRewriter rewriter(linalgOp.getContext()); + FailureOr tiledOp = + pattern.returningMatchAndRewrite(linalgOp, rewriter); + if (failed(tiledOp)) + return DiagnosedSilenceableFailure::definiteFailure(); + + tiled.push_back(tiledOp->op); + for (const auto &en2 : llvm::enumerate(tiledOp->loops)) + loops[en2.index()].push_back(en2.value()); + } + + transformResults.set(getTiledLinalgOp().cast(), tiled); + for (const auto &en : llvm::enumerate(loops)) + transformResults.set(getLoops()[en.index()].cast(), en.value()); + + return DiagnosedSilenceableFailure::success(); } ParseResult transform::TileOp::parse(OpAsmParser &parser, OperationState &result) { - return parseTileLikeOp(parser, result, - TileOp::getSizesAttrName(result.name).getValue()); + OpAsmParser::UnresolvedOperand target; + SmallVector dynamicSizes; + ArrayAttr staticSizes; + auto pdlOperationType = 
pdl::OperationType::get(parser.getContext()); + if (parser.parseOperand(target) || + parser.resolveOperand(target, pdlOperationType, result.operands) || + parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) || + parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || + parser.parseOptionalAttrDict(result.attributes)) + return ParseResult::failure(); + + result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); + size_t numExpectedLoops = + staticSizes.size() - llvm::count(extractI64Array(staticSizes), 0); + result.addTypes(SmallVector(numExpectedLoops + 1, pdlOperationType)); + return success(); } void TileOp::print(OpAsmPrinter &p) { - p << ' '; - p << getTarget(); - p.printOptionalAttrDict((*this)->getAttrs()); + p << ' ' << getTarget(); + printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(), + getStaticSizes()); + p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); +} + +void transform::TileOp::getEffects( + SmallVectorImpl &effects) { + consumesHandle(getTarget(), effects); + onlyReadsHandle(getDynamicSizes(), effects); + producesHandle(getTiledLinalgOp(), effects); + producesHandle(getLoops(), effects); + modifiesPayload(effects); } //===----------------------------------------------------------------------===// @@ -678,6 +765,7 @@ LinalgTransformDialectExtension> { public: LinalgTransformDialectExtension() { + declareDependentDialect(); declareDependentDialect(); declareDependentDialect(); declareDependentDialect(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -15,16 +15,6 @@ using namespace mlir; using namespace mlir::linalg; -/// Turns an OpFoldResult into a value, creating an index-typed constant if -/// necessary. 
-static Value materializeOpFoldResult(ImplicitLocOpBuilder &builder, - OpFoldResult opFoldResult) { - if (opFoldResult.is()) - return opFoldResult.get(); - auto attr = opFoldResult.get().cast(); - return builder.create(attr.getValue().getSExtValue()); -} - /// Extract the slices of `operands` supplied to the given operation `op` such /// that they are sufficient to execute the op for the subset of its iteration /// space defined by `splitIterationSpace`. The subset is a part of the original diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -993,6 +993,14 @@ return tensorResults; } +Value materializeOpFoldResult(ImplicitLocOpBuilder &builder, + OpFoldResult opFoldResult) { + if (auto value = opFoldResult.dyn_cast()) + return value; + auto attr = opFoldResult.get().cast(); + return builder.create(attr.getValue().getSExtValue()); +} + SmallVector makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp, ArrayRef valuesToTile, diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -191,18 +191,40 @@ def __init__(self, target: Union[Operation, Value], *, - sizes: OptionalIntList = None, + sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, + Value]], ArrayAttr]] = None, interchange: OptionalIntList = None, loc=None, ip=None): pdl_operation_type = pdl.OperationType.get() - sizes_attr = _get_int_array_attr(sizes) + i64_type = IntegerType.get_signless(64) + + if sizes is None: + sizes = [] + + static_sizes = [] + dynamic_sizes = [] + if isinstance(sizes, ArrayAttr): + sizes_attr = sizes + else: + for size in sizes: + if isinstance(size, int): + static_sizes.append(IntegerAttr.get(i64_type, size)) + elif 
isinstance(size, IntegerAttr): + static_sizes.append(size) + else: + static_sizes.append( + IntegerAttr.get(i64_type, ShapedType._get_dynamic_size())) + dynamic_sizes.append(_get_op_result_or_value(size)) + sizes_attr = ArrayAttr.get(static_sizes) + num_loops = sum( v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) super().__init__( pdl_operation_type, [pdl_operation_type] * num_loops, _get_op_result_or_value(target), - sizes=sizes_attr, + dynamic_sizes=dynamic_sizes, + static_sizes=sizes_attr, interchange=_get_int_array_attr(interchange) if interchange else None, loc=loc, ip=ip) diff --git a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir --- a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir @@ -23,7 +23,7 @@ transform.sequence %arg0 { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @pdl_target in %arg1 - %1, %loops = transform.structured.tile %0 {sizes = [10, 0, 0]} + %1, %loops = transform.structured.tile %0 [10, 0, 0] %2 = transform.structured.scalarize %1 } } diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir --- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir @@ -1,11 +1,11 @@ -// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s +// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): sequence %arg0 { ^bb0(%arg1: !pdl.operation): %0 = pdl_match @pdl_target in %arg1 - %1, %loops:3 = transform.structured.tile %0 {sizes = [4, 4, 4]} + %1, %loops:3 = transform.structured.tile %0 [4, 4, 4] } pdl.pattern @pdl_target : benefit(1) { @@ -44,3 +44,58 @@ return %0 : tensor<128x128xf32> } +// ----- + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 
= pdl_match @pdl_target in %arg1 + %1 = pdl_match @func_call in %arg1 + %2, %loops:3 = transform.structured.tile %0 [%1, %1, 4] + } + + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %0 with "transform.dialect" + } + pdl.pattern @func_call : benefit(1) { + %args = operands + %results = types + %0 = operation "func.call"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %0 with "transform.dialect" + } +} + +func.func private @get_dynamic_tile_size() -> index + +// CHECK-LABEL: func @tile_linalg_matmul_dynamic( +// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<128x128xf32> +// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<128x128xf32> +// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<128x128xf32> +// CHECK-SAME: -> tensor<128x128xf32> { +func.func @tile_linalg_matmul_dynamic( + %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) + -> tensor<128x128xf32> { +// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<128x128xf32>) { +// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<128x128xf32>) { +// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<128x128xf32>) { +// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<128x128xf32> to tensor +// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<128x128xf32> to tensor<4x?xf32> +// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<128x128xf32> to tensor +// CHECK: %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor, tensor<4x?xf32>) +// CHECK-SAME: outs(%[[sTC]] : tensor) -> tensor +// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor into tensor<128x128xf32> +// CHECK: scf.yield %[[TD]] : tensor<128x128xf32> +// CHECK: scf.yield %[[TD2]] : 
tensor<128x128xf32> +// CHECK: scf.yield %[[TD1]] : tensor<128x128xf32> + %sz = func.call @get_dynamic_tile_size() : () -> index + %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) + outs(%arg2: tensor<128x128xf32>) + -> tensor<128x128xf32> + +// CHECK: return %[[TD0]] : tensor<128x128xf32> + return %0 : tensor<128x128xf32> +} diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir --- a/mlir/test/Dialect/Linalg/transform-ops.mlir +++ b/mlir/test/Dialect/Linalg/transform-ops.mlir @@ -3,7 +3,7 @@ transform.sequence { ^bb1(%arg0: !pdl.operation): // CHECK %{{.*}}, %{{.*}}:2 = transform.structured.tile - %0, %1:2 = transform.structured.tile %arg0 { sizes = [2, 0, 3] } + %0, %1:2 = transform.structured.tile %arg0 [2, 0, 3] } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/selective-targeting.mlir b/mlir/test/Dialect/Transform/selective-targeting.mlir --- a/mlir/test/Dialect/Transform/selective-targeting.mlir +++ b/mlir/test/Dialect/Transform/selective-targeting.mlir @@ -77,7 +77,7 @@ transform.sequence %arg0 { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @pdl_target_attrA in %arg1 - transform.structured.tile %0 {sizes = [4, 4, 4]} + transform.structured.tile %0 [4, 4, 4] %1 = pdl_match @pdl_target_attrC in %arg1 %2 = transform.get_closest_isolated_parent %1 transform.structured.vectorize %2 diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -105,9 +105,8 @@ transform.YieldOp() # CHECK-LABEL: TEST: testTileCompact # CHECK: transform.sequence - # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile - # CHECK-DAG: interchange = [0, 1] - # CHECK-DAG: sizes = [4, 8] + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 
8] + # CHECK: interchange = [0, 1] @run @@ -122,9 +121,8 @@ transform.YieldOp() # CHECK-LABEL: TEST: testTileAttributes # CHECK: transform.sequence - # CHECK: structured.tile - # CHECK-DAG: interchange = [0, 1] - # CHECK-DAG: sizes = [4, 8] + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8] + # CHECK: interchange = [0, 1] @run @@ -136,9 +134,24 @@ transform.YieldOp() # CHECK-LABEL: TEST: testTileZero # CHECK: transform.sequence - # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile - # CHECK-DAG: interchange = [0, 1, 2, 3] - # CHECK-DAG: sizes = [4, 0, 2, 0] + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0] + # CHECK: interchange = [0, 1, 2, 3] + + +@run +def testTileDynamic(): + with_pdl = transform.WithPDLPatternsOp() + with InsertionPoint(with_pdl.body): + sequence = transform.SequenceOp(with_pdl.bodyTarget) + with InsertionPoint(sequence.body): + m1 = transform.PDLMatchOp(sequence.bodyTarget, "first") + m2 = transform.PDLMatchOp(sequence.bodyTarget, "second") + structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0]) + transform.YieldOp() + # CHECK-LABEL: TEST: testTileDynamic + # CHECK: %[[FIRST:.+]] = pdl_match + # CHECK: %[[SECOND:.+]] = pdl_match + # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0] @run diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7448,6 +7448,7 @@ ], includes = ["include"], deps = [ + ":ArithmeticDialect", ":IR", ":LinalgDialect", ":LinalgTransformOpsIncGen",