diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1054,6 +1054,8 @@ let arguments = (ins PDL_Operation:$target, Variadic:$num_threads, Variadic:$tile_sizes, + Optional:$packed_num_threads, + Optional:$packed_tile_sizes, DefaultValuedOptionalAttr:$static_num_threads, DefaultValuedOptionalAttr:$static_tile_sizes, OptionalAttr:$mapping); @@ -1085,10 +1087,12 @@ let assemblyFormat = [{ $target oilist( - `num_threads` custom($num_threads, - $static_num_threads) | - `tile_sizes` custom($tile_sizes, - $static_tile_sizes)) + `num_threads` custom($packed_num_threads, + $num_threads, + $static_num_threads) | + `tile_sizes` custom($packed_tile_sizes, + $tile_sizes, + $static_tile_sizes)) (`(` `mapping` `=` $mapping^ `)`)? attr-dict }]; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/Transform/Utils/Utils.h b/mlir/include/mlir/Dialect/Transform/Utils/Utils.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/Utils/Utils.h @@ -0,0 +1,52 @@ +//===- Utils.h - Transform dialect utilities --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORMS_UTILS_UTILS_H +#define MLIR_DIALECT_TRANSFORMS_UTILS_UTILS_H + +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" + +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +class OpAsmPrinter; + +namespace transform { +class TransformState; + +/// Printer hook for custom directive in assemblyFormat. +/// +/// custom($packed, $values, $integers) +/// +/// where `values` are variadic Index values, `integers` is an `I64ArrayAttr` +/// and `packed` is a single transform dialect handle whose mapped payload ops +/// have a single Index result and represent the index list. Either `packed` +/// or the other two parameters may be specified. +/// +/// This allows idiomatic printing of mixed value and integer attributes in a +/// list or with a single handle. E.g., `[%arg0, 7, 42, %arg42]` or just `%h`. +void printPackedOrDynamicIndexList(OpAsmPrinter &printer, Operation *op, + Value packed, OperandRange values, + ArrayRef integers); + +/// Parser hook for custom directive in assemblyFormat. +/// +/// custom($packed, $values, $integers) +/// +/// See `printPackedOrDynamicIndexList` for details. 
+ParseResult parsePackedOrDynamicIndexList( + OpAsmParser &parser, std::optional &packed, + SmallVectorImpl &values, + DenseI64ArrayAttr &integers); +} // namespace transform +} // namespace mlir + +#endif // MLIR_DIALECT_TRANSFORMS_UTILS_UTILS_H diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt @@ -18,5 +18,6 @@ MLIRSCFDialect MLIRSideEffectInterfaces MLIRTransformDialect + MLIRTransformDialectUtils MLIRVectorDialect ) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformUtils.h" +#include "mlir/Dialect/Transform/Utils/Utils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" @@ -1523,6 +1524,8 @@ /*target=*/target, /*num_threads=*/ValueRange{}, /*tile_sizes=*/dynamicTileSizes, + /*packed_num_threads=*/Value(), + /*packed_tile_sizes=*/Value(), /*static_num_threads=*/builder.getDenseI64ArrayAttr({}), /*static_tile_sizes=*/staticTileSizesAttr, /*mapping=*/mapping); @@ -1558,38 +1561,70 @@ /*target=*/target, /*num_threads=*/dynamicNumThreads, /*tile_sizes=*/ValueRange{}, + /*packed_num_threads=*/Value(), + /*packed_tile_sizes=*/Value(), /*static_num_threads=*/staticNumThreadsAttr, /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}), /*mapping=*/mapping); } -// Given a list of OpFoldResults that are either index attrs or op -// handles, return a list of OpFoldResults where all op handles are -// replaced with the first (and only) OpResult of that 
payload op. (There -// must be exactly one mapped payload op and it must have exactly one -// index result.) +/// Assuming that `ofr` is an index attr or a transform dialect handle mapped +/// to exactly one op with one index result, return that value. static DiagnosedSilenceableFailure unpackPDLOperations( transform::TransformState &state, TransformOpInterface transformOp, SmallVector &result, ArrayRef ofrs) { for (OpFoldResult ofr : ofrs) { - // Don't try to unpack non-PDL operation. - if (ofr.is() || - !ofr.get().getType().isa()) { + if (ofr.is()) { + if (!ofr.get().isa()) + return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; result.push_back(ofr); continue; } ArrayRef payloadOps = state.getPayloadOps(ofr.get()); - for (Operation *op : payloadOps) { - if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { - DiagnosedSilenceableFailure diag = - transformOp.emitSilenceableError() - << "payload op must have exactly 1 index result"; - diag.attachNote(op->getLoc()) - << "has " << op->getNumResults() << " results"; - return diag; - } - result.push_back(op->getResult(0)); + if (payloadOps.size() != 1) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "handle must be mapped to exactly one payload op"; + diag.attachNote(ofr.get().getLoc()) + << "mapped to " << payloadOps.size() << " payload ops"; + return diag; } + + Operation *op = payloadOps[0]; + if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "payload op must have exactly 1 index result"; + diag.attachNote(op->getLoc()) + << "has " << op->getNumResults() << " results"; + return diag; + } + result.push_back(op->getResult(0)); + } + + return DiagnosedSilenceableFailure::success(); +} + +// Given a single packed handle, append to `result` the first (and only) +// OpResult of every payload op mapped to that handle, in order. (The handle +// may be mapped to any number of 
payload ops, but each of them must have +// exactly one index result; otherwise a silenceable error is emitted and +// unpacking stops.) +static DiagnosedSilenceableFailure +unpackPDLOperations(transform::TransformState &state, + TransformOpInterface transformOp, + SmallVector &result, Value packedHandle) { + ArrayRef payloadOps = state.getPayloadOps(packedHandle); + for (Operation *op : payloadOps) { + if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "payload op must have exactly 1 index result"; + diag.attachNote(op->getLoc()) + << "has " << op->getNumResults() << " results"; + return diag; + } + result.push_back(op->getResult(0)); } return DiagnosedSilenceableFailure::success(); @@ -1604,21 +1639,6 @@ if (targets.empty()) return DiagnosedSilenceableFailure::success(); - // getMixedNumThreads are OpFoldResults[index attributes or PDL operation]. - // Convert to OpFoldResults[index attributes or payload op]. - SmallVector numThreads; - DiagnosedSilenceableFailure status = - unpackPDLOperations(state, transformOp, numThreads, mixedNumThreads); - if (!status.succeeded()) - return status; - - // getMixedTileSizes are OpFoldResults[index attributes or PDL operation]. - // Convert to OpFoldResults[index attributes or payload op]. - SmallVector tileSizes; - status = unpackPDLOperations(state, transformOp, tileSizes, mixedTileSizes); - if (!status.succeeded()) - return status; - // Transform all targets one by one. 
for (Operation *target : targets) { auto tilableOp = dyn_cast(target); @@ -1633,10 +1653,10 @@ FailureOr tilingResult = failure(); if (!mixedNumThreads.empty()) { tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp, - numThreads, mapping); + mixedNumThreads, mapping); } else { tilingResult = linalg::tileToForeachThreadOpUsingTileSizes( - rewriter, tilableOp, tileSizes, mapping); + rewriter, tilableOp, mixedTileSizes, mapping); } if (failed(tilingResult)) @@ -1653,16 +1673,35 @@ transform::TransformResults &transformResults, transform::TransformState &state) { IRRewriter rewriter(getContext()); + auto transformOp = cast(getOperation()); ArrayRef targets = state.getPayloadOps(getTarget()); // Result payload ops. SmallVector tileOps; SmallVector tiledOps; + // Unpack handles. + SmallVector mixedNumThreads; + DiagnosedSilenceableFailure status = + getPackedNumThreads() + ? unpackPDLOperations(state, transformOp, mixedNumThreads, + getPackedNumThreads()) + : unpackPDLOperations(state, transformOp, mixedNumThreads, + getMixedNumThreads()); + if (!status.succeeded()) + return status; + SmallVector mixedTileSizes; + status = getPackedTileSizes() + ? 
unpackPDLOperations(state, transformOp, mixedTileSizes, + getPackedTileSizes()) + : unpackPDLOperations(state, transformOp, mixedTileSizes, + getMixedTileSizes()); + if (!status.succeeded()) + return status; + DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl( - rewriter, state, cast(getOperation()), targets, - getMixedNumThreads(), getMixedTileSizes(), getMapping(), tileOps, - tiledOps); + rewriter, state, transformOp, targets, mixedNumThreads, mixedTileSizes, + getMapping(), tileOps, tiledOps); if (!diag.succeeded()) { transformResults.set(getForeachThreadOp().cast(), {}); @@ -1695,8 +1734,19 @@ } LogicalResult TileToForeachThreadOp::verify() { - if (getMixedNumThreads().empty() == getMixedTileSizes().empty()) - return emitOpError("either num_threads or tile_sizes must be specified"); + int numThreadsSpec = static_cast(!getMixedNumThreads().empty()) + + static_cast(getPackedNumThreads() != Value()); + if (numThreadsSpec > 1) + return emitOpError( + "num_threads and packed_num_threads are mutually exclusive"); + int tileSizesSpec = static_cast(!getMixedTileSizes().empty()) + + static_cast(getPackedTileSizes() != Value()); + if (tileSizesSpec > 1) + return emitOpError( + "tile_sizes and packed_tile_sizes are mutually exclusive"); + if (numThreadsSpec == 0 && tileSizesSpec == 0) + return emitOpError( + "either (packed_)num_threads or (packed_)tile_sizes must be specified"); return success(); } diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt --- a/mlir/lib/Dialect/Transform/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/Transform/Utils/CMakeLists.txt b/mlir/lib/Dialect/Transform/Utils/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Transform/Utils/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_dialect_library(MLIRTransformDialectUtils + Utils.cpp + 
ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Transform + LINK_LIBS PUBLIC + MLIRDialectUtils + MLIRIR + MLIRSupport + MLIRTransformDialect +) diff --git a/mlir/lib/Dialect/Transform/Utils/Utils.cpp b/mlir/lib/Dialect/Transform/Utils/Utils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Transform/Utils/Utils.cpp @@ -0,0 +1,44 @@ +//===- Utils.cpp - Transform dialect utilities ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/Utils/Utils.h" + +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/ViewLikeInterface.h" + +using namespace mlir; +using namespace mlir::transform; + +void transform::printPackedOrDynamicIndexList(OpAsmPrinter &printer, + Operation *op, Value packed, + OperandRange values, + ArrayRef integers) { + if (packed) { + assert(values.empty() && integers.empty() && "expected no values/integers"); + printer << packed; + return; + } + printDynamicIndexList(printer, op, values, integers); +} + +ParseResult transform::parsePackedOrDynamicIndexList( + OpAsmParser &parser, std::optional &packed, + SmallVectorImpl &values, + DenseI64ArrayAttr &integers) { + OpAsmParser::UnresolvedOperand packedOperand; + if (parser.parseOptionalOperand(packedOperand).has_value()) { + packed.emplace(std::move(packedOperand)); + integers = parser.getBuilder().getDenseI64ArrayAttr({}); + return success(); + } + return parseDynamicIndexList(parser, values, integers); +} diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir 
b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir --- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir +++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir @@ -78,7 +78,7 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %sz = transform.structured.match ops{["test.dummy"]} in %arg1 - %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%sz] + %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes %sz } // ----- diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8219,6 +8219,7 @@ ":SideEffectInterfaces", ":TilingInterface", ":TransformDialect", + ":TransformDialectUtils", ":TransformUtils", "//llvm:Support", ], @@ -9096,6 +9097,21 @@ ], ) +cc_library( + name = "TransformDialectUtils", + srcs = ["lib/Dialect/Transform/Utils/Utils.cpp"], + hdrs = ["include/mlir/Dialect/Transform/Utils/Utils.h"], + includes = ["include"], + deps = [ + ":DialectUtils", + ":IR", + ":Support", + ":TransformDialect", + ":ViewLikeInterface", + "//llvm:Support", + ], +) + td_library( name = "ComplexOpsTdFiles", srcs = [