diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -25,7 +25,9 @@ /// TransformState. class TransformOptions { public: - TransformOptions() {} + TransformOptions() = default; + TransformOptions(const TransformOptions &) = default; + TransformOptions &operator=(const TransformOptions &) = default; /// Requests computationally expensive checks of the transform and payload IR /// well-formedness to be performed before each transformation. In particular, @@ -200,7 +202,8 @@ assert(res.second && "the region scope is already present"); (void)res; #if LLVM_ENABLE_ABI_BREAKING_CHECKS - assert(state.regionStack.back()->isProperAncestor(®ion) && + assert(((state.regionStack.size() == 1 && !state.regionStack.back()) || + state.regionStack.back()->isProperAncestor(®ion)) && "scope started at a non-nested region"); state.regionStack.push_back(®ion); #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h @@ -0,0 +1,164 @@ +//===- TransformInterpreterPassBase.h ---------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Base class with shared implementation for transform dialect interpreter +// passes. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERPASSBASE_H +#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERPASSBASE_H + +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include + +namespace mlir { +struct LogicalResult; +class MLIRContext; +class ModuleOp; +class Operation; +template +class OwningOpRef; +class Region; + +namespace transform { +namespace detail { +/// Template-free implementation of TransformInterpreterPassBase::initialize. +LogicalResult +interpreterBaseInitializeImpl(MLIRContext *context, StringRef transformFileName, + std::shared_ptr> &module); + +/// Template-free implementation of +/// TransformInterpreterPassBase::runOnOperation. +LogicalResult interpreterBaseRunOnOperationImpl( + Operation *target, StringRef passName, + const std::shared_ptr> &sharedTransformModule, + ArrayRef> extraMappings, + const TransformOptions &options, + const Pass::Option &transformFileName, + const Pass::Option &debugPayloadRootTag, + const Pass::Option &debugTransformRootTag, + StringRef binaryName); +} // namespace detail + +/// Base class for transform dialect interpreter passes that can consume and +/// dump transform dialect scripts in separate files. The pass is controlled by +/// three string options: +/// +/// - transformFileName: if non-empty, the name of the file containing the +/// transform script. If empty, `debugTransformRootTag` is considered or the +/// pass root operation must contain a single top-level transform op that +/// will be interpreted. +/// - debugPayloadRootTag: if non-empty, the value of the attribute named +/// `kTransformDialectTagAttrName` indicating the single op that is +/// considered the payload root of the transform interpreter; otherwise, the +/// root operation of the pass is used. 
+///   - debugTransformRootTag: if non-empty, the value of the attribute named
+///     `kTransformDialectTagAttrName` indicating the single top-level transform
+///     op contained in the payload root to be used as the entry point by the
+///     transform interpreter; mutually exclusive with `transformFileName`.
+///
+/// The pass runs the transform dialect interpreter as directed by the options.
+/// It also provides the mechanism to dump reproducers into stderr
+/// (-debug-only=transform-dialect-dump-repro) or into a temporary file
+/// (-debug-only=transform-dialect-save-repro) that can be used with this
+/// pass in a standalone mode.
+///
+/// Concrete passes must derive from this class instead of their generated base
+/// class (or PassWrapper), and supply themselves and the generated base class
+/// as template arguments. They are *not* expected to implement `initialize`
+/// or `runOnOperation`. They *are* expected to call the copy constructor of
+/// this class in their copy constructors, short of which the file-based
+/// transform dialect script injection facility will become nonoperational.
+///
+/// Concrete passes may implement the `runBeforeInterpreter` and
+/// `runAfterInterpreter` to customize the behavior of the pass.
+template typename GeneratedBase> +class TransformInterpreterPassBase : public GeneratedBase { +public: + explicit TransformInterpreterPassBase( + const TransformOptions &options = TransformOptions()) + : options(options) {} + + TransformInterpreterPassBase(const TransformInterpreterPassBase &pass) { + sharedTransformModule = pass.sharedTransformModule; + options = pass.options; + } + + static StringLiteral getBinaryName() { return "mlir-opt"; } + + LogicalResult initialize(MLIRContext *context) override { + +#define REQUIRE_PASS_OPTION(NAME) \ + static_assert( \ + std::is_same_v< \ + std::remove_reference_t().NAME)>, \ + Pass::Option>, \ + "required " #NAME " string pass option is missing") + + REQUIRE_PASS_OPTION(transformFileName); + REQUIRE_PASS_OPTION(debugPayloadRootTag); + REQUIRE_PASS_OPTION(debugTransformRootTag); + +#undef REQUIRE_PASS_OPTION + + StringRef transformFileName = + static_cast(this)->transformFileName; + return detail::interpreterBaseInitializeImpl(context, transformFileName, + sharedTransformModule); + } + + /// Hook for passes to run additional logic in the pass before the + /// interpreter. If failure is returned, the pass fails and the interpreter is + /// not run. + LogicalResult runBeforeInterpreter(Operation *) { return success(); } + + /// Hook for passes to run additional logic in the pass after the interpreter. + /// Only runs if everything succeeded before. If failure is returned, the pass + /// fails. 
+ LogicalResult runAfterInterpreter(Operation *) { return success(); } + + void runOnOperation() override { + auto *pass = static_cast(this); + Operation *op = pass->getOperation(); + StringRef binaryName = Concrete::getBinaryName(); + if (failed(pass->runBeforeInterpreter(op)) || + failed(detail::interpreterBaseRunOnOperationImpl( + op, pass->getArgument(), sharedTransformModule, + /*extraMappings=*/{}, options, pass->transformFileName, + pass->debugPayloadRootTag, pass->debugTransformRootTag, + binaryName)) || + failed(pass->runAfterInterpreter(op))) { + return pass->signalPassFailure(); + } + } + +protected: + /// Transform interpreter options. + TransformOptions options; + + /// Returns a read-only reference to shared transform module. + const std::shared_ptr> & + getSharedTransformModule() const { + return sharedTransformModule; + } + +private: + /// The separate transform module to be used for transformations, shared + /// across multiple instances of the pass if it is applied in parallel to + /// avoid potentially expensive cloning. MUST NOT be modified after the pass + /// has been initialized. 
+ std::shared_ptr> sharedTransformModule = nullptr; +}; + +} // namespace transform +} // namespace mlir + +#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERPASSBASE_H diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -7,9 +7,11 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/OwningOpRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" @@ -828,6 +830,25 @@ } #endif // NDEBUG + // If the transform dialect may use PDL which may modify the IR, clone it + // before use to avoid concurrent modification in case this is being called + // from pass instances running concurrently with a shared transform script. 
+ auto *pdlDialect = + transform->getContext()->getLoadedDialect(); + bool hasPDL = transform + .walk([pdlDialect](Operation *op) { + if (op->getDialect() == pdlDialect) + return WalkResult::interrupt(); + return WalkResult::advance(); + }) + .wasInterrupted(); + + OwningOpRef owningCopy; + if (hasPDL) { + owningCopy = OwningOpRef(transform->clone()); + transform = owningCopy.get(); + } + TransformState state(transform->getParentRegion(), payloadRoot, extraMapping, options); return state.applyTransform(transform).checkAndReport(); diff --git a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRTransformDialectTransforms CheckUses.cpp + TransformInterpreterPassBase.cpp DEPENDS MLIRTransformDialectTransformsIncGen diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -0,0 +1,349 @@ +//===- TransformInterpreterPassBase.cpp -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Base class with shared implementation for transform dialect interpreter +// passes. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/FileUtilities.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Mutex.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +#define DEBUG_TYPE "transform-dialect-interpreter" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") +#define DEBUG_TYPE_DUMP_STDERR "transform-dialect-dump-repro" +#define DEBUG_TYPE_DUMP_FILE "transform-dialect-save-repro" + +/// Name of the attribute used for targeting the transform dialect interpreter +/// at specific operations. +constexpr static llvm::StringLiteral kTransformDialectTagAttrName = + "transform.target_tag"; +/// Value of the attribute indicating the root payload operation. +constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue = + "payload_root"; +/// Value of the attribute indicating the container of transform operations +/// (containing the top-level transform operation). +constexpr static llvm::StringLiteral + kTransformDialectTagTransformContainerValue = "transform_container"; + +/// Utility to parse the content of a `transformFileName` MLIR file containing +/// a transform dialect specification. 
+/// If `transformFileName` is empty, nothing is parsed, `transformModule` is
+/// left untouched and success is returned. Otherwise returns failure when the
+/// file cannot be opened (the diagnostic carries the reason reported by the
+/// file system layer) or when parsing does not produce a module.
+static LogicalResult
+parseTransformModuleFromFile(MLIRContext *context,
+                             llvm::StringRef transformFileName,
+                             OwningOpRef<ModuleOp> &transformModule) {
+  if (transformFileName.empty()) {
+    LLVM_DEBUG(
+        DBGS() << "no transform file name specified, assuming the transform "
+                  "module is embedded in the IR next to the top-level\n");
+    return success();
+  }
+  // Parse transformFileName content into a ModuleOp.
+  std::string errorMessage;
+  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
+  if (!memoryBuffer) {
+    // This is a failure to open the file, not to parse it: say so and forward
+    // the reason collected by openInputFile instead of dropping it.
+    return emitError(FileLineColLoc::get(
+               StringAttr::get(context, transformFileName), 0, 0))
+           << "failed to open transform file: " << errorMessage;
+  }
+  // Tell sourceMgr about this buffer, the parser will pick it up.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+  transformModule =
+      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
+  // A null module means the parse failed. Do not report success: the caller
+  // would otherwise treat the file as absent and silently fall back to looking
+  // for a transform embedded in the payload IR.
+  if (!transformModule)
+    return failure();
+  return success();
+}
+
+/// Finds the single top-level transform operation with `root` as ancestor.
+/// Reports an error and returns nullptr if there is more than one such
+/// operation or if no such operation is found.
+/// Transform ops nested in the first transform op encountered are not
+/// candidates. `filenameOption` is the name of the pass option mentioned in
+/// the note attached to the "not found" error.
+static Operation *findTopLevelTransform(Operation *root,
+                                        StringRef filenameOption) {
+  ::mlir::transform::TransformOpInterface topLevelTransform = nullptr;
+  // The walk must be pre-order: WalkResult::skip() only prunes the ops nested
+  // in the current one when the parent is visited before its children. In a
+  // post-order walk the ops nested in the top-level transform op would be
+  // visited first and reported as duplicates.
+  WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
+      [&](::mlir::transform::TransformOpInterface transformOp) {
+        if (!topLevelTransform) {
+          // First transform op found: remember it and do not descend into it.
+          topLevelTransform = transformOp;
+          return WalkResult::skip();
+        }
+        // Any further transform op outside of the first one is an error.
+        auto diag = transformOp.emitError()
+                    << "more than one top-level transform op";
+        diag.attachNote(topLevelTransform.getLoc())
+            << "previous top-level transform op";
+        return WalkResult::interrupt();
+      });
+  if (walkResult.wasInterrupted())
+    return nullptr;
+  if (!topLevelTransform) {
+    auto diag = root->emitError()
+                << "could not find a nested top-level transform op";
+    diag.attachNote() << "use the '" << filenameOption
+                      << "' option to provide transform as external file";
+    return nullptr;
+  }
+  return topLevelTransform;
+}
+
+/// Finds the operation in the IR rooted at `root` that has a string attribute
+/// named `tagKey` with the value `tagValue`. Reports an error and returns
+/// nullptr if there is no such operation or more than one such operation.
+static Operation *findOpWithTag(Operation *root, StringRef tagKey,
+                                StringRef tagValue) {
+  Operation *found = nullptr;
+  // Pre-order visits an operation before the operations nested in it, so the
+  // op reported as "first operation" below is the enclosing one when tagged
+  // ops are nested in each other.
+  WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
+      [tagKey, tagValue, &found, root](Operation *op) {
+        // Only string attributes with the exact requested value match.
+        auto attr = op->getAttrOfType<StringAttr>(tagKey);
+        if (!attr || attr.getValue() != tagValue)
+          return WalkResult::advance();
+
+        if (found) {
+          InFlightDiagnostic diag = root->emitError()
+                                    << "more than one operation with " << tagKey
+                                    << "=\"" << tagValue << "\" attribute";
+          diag.attachNote(found->getLoc()) << "first operation";
+          diag.attachNote(op->getLoc()) << "other operation";
+          return WalkResult::interrupt();
+        }
+
+        found = op;
+        return WalkResult::advance();
+      });
+  if (walkResult.wasInterrupted())
+    return nullptr;
+
+  if (!found) {
+    root->emitError() << "could not find the operation with " << tagKey << "=\""
+                      << tagValue << "\" attribute";
+  }
+  return found;
+}
+
+/// Returns the ancestor of `target` that doesn't have a parent. Returns
+/// `target` itself if it has no parent.
+static Operation *getRootOperation(Operation *target) {
+  Operation *root = target;
+  while (root->getParentOp())
+    root = root->getParentOp();
+  return root;
+}
+
+/// Prints to `os` the CLI command running the repro: `binaryName` followed by
+/// a pass pipeline anchored on `rootOpName` that runs `passName` with the two
+/// tag options. An option that is empty is printed with the default tag value
+/// that the debug actions attach to the IR. Returns `os`.
+// TODO: make binary name optional by querying LLVM command line API for the
+// name of the current binary.
+static llvm::raw_ostream &
+printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
+               const Pass::Option<std::string> &debugPayloadRootTag,
+               const Pass::Option<std::string> &debugTransformRootTag,
+               StringRef binaryName) {
+  os << llvm::formatv(
+      "{6} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}})\"", rootOpName,
+      passName, debugPayloadRootTag.getArgStr(),
+      debugPayloadRootTag.empty()
+          ? StringRef(kTransformDialectTagPayloadRootValue)
+          : debugPayloadRootTag,
+      debugTransformRootTag.getArgStr(),
+      debugTransformRootTag.empty()
+          ? StringRef(kTransformDialectTagTransformContainerValue)
+          : debugTransformRootTag,
+      binaryName);
+  return os;
+}
+
+/// Prints the IR rooted at `root` to `os` and appends `transform` if it is
+/// not nested in `root`. Returns `os`.
+// NOTE(review): unlike its siblings this helper is not `static`; confirm it is
+// not referenced from another translation unit before restricting its linkage.
+llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os, Operation *root,
+                                       Operation *transform) {
+  root->print(os);
+  if (!root->isAncestor(transform))
+    transform->print(os);
+  return os;
+}
+
+/// Saves the payload and the transform IR into a temporary file and reports
+/// the file name, together with the command line to rerun the pass on it, to
+/// `os`. If the file cannot be created or preserved, a message is written to
+/// `os` instead and nothing else is reported.
+void saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
+                         Operation *transform, StringRef passName,
+                         const Pass::Option<std::string> &debugPayloadRootTag,
+                         const Pass::Option<std::string> &debugTransformRootTag,
+                         StringRef binaryName) {
+  using llvm::sys::fs::TempFile;
+  Operation *root = getRootOperation(target);
+
+  SmallVector<char, 128> tmpPath;
+  llvm::sys::path::system_temp_directory(/*erasedOnReboot=*/true, tmpPath);
+  llvm::sys::path::append(tmpPath, "transform_dialect_%%%%%%.mlir");
+  llvm::Expected<TempFile> tempFile = TempFile::create(tmpPath);
+  if (!tempFile) {
+    // An Expected in the error state must have its error taken before it is
+    // destroyed, otherwise builds with ABI-breaking checks abort.
+    llvm::consumeError(tempFile.takeError());
+    os << "could not open temporary file to save the repro\n";
+    return;
+  }
+
+  llvm::raw_fd_ostream fout(tempFile->FD, /*shouldClose=*/false);
+  printModuleForRepro(fout, root, transform);
+  fout.flush();
+  // keep() invalidates the TempFile, so remember the name first.
+  std::string filename = tempFile->TmpName;
+
+  if (llvm::Error keepError = tempFile->keep()) {
+    // Same as above: a failed Error must be consumed explicitly.
+    llvm::consumeError(std::move(keepError));
+    os << "could not preserve the temporary file with the repro\n";
+    return;
+  }
+
+  os << "=== Transform Interpreter Repro ===\n";
+  printReproCall(os, root->getName().getStringRef(), passName,
+                 debugPayloadRootTag, debugTransformRootTag, binaryName)
+      << " " << filename << "\n";
+  os << "===================================\n";
+}
+
+// Optionally perform debug actions requested by the user to dump IR and a
+// repro to stderr and/or a file.
+static void performOptionalDebugActions( + Operation *target, Operation *transform, StringRef passName, + const Pass::Option &debugPayloadRootTag, + const Pass::Option &debugTransformRootTag, + StringRef binaryName) { + MLIRContext *context = target->getContext(); + + // If we are not planning to print, bail early. + bool hasDebugFlags = false; + DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, { hasDebugFlags = true; }); + DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, { hasDebugFlags = true; }); + if (!hasDebugFlags) + return; + + // We will be mutating the IR to set attributes. If this is running + // concurrently on several parts of a container or using a shared transform + // script, this would create a race. Bail in multithreaded mode and require + // the user to disable threading to dump repros. + static llvm::sys::SmartMutex dbgStreamMutex; + if (target->getContext()->isMultithreadingEnabled()) { + llvm::sys::SmartScopedLock lock(dbgStreamMutex); + llvm::dbgs() << "=======================================================\n"; + llvm::dbgs() << "| Transform reproducers cannot be produced |\n"; + llvm::dbgs() << "| in multi-threaded mode! |\n"; + llvm::dbgs() << "=======================================================\n"; + return; + } + + Operation *root = getRootOperation(target); + + // Add temporary debug / repro attributes, these must never leak out. 
+ if (debugPayloadRootTag.empty()) { + target->setAttr( + kTransformDialectTagAttrName, + StringAttr::get(context, kTransformDialectTagPayloadRootValue)); + } + if (debugTransformRootTag.empty()) { + transform->setAttr( + kTransformDialectTagAttrName, + StringAttr::get(context, kTransformDialectTagTransformContainerValue)); + } + + DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, { + llvm::dbgs() << "=== Transform Interpreter Repro ===\n"; + printReproCall(llvm::dbgs() << "cat <getName().getStringRef(), passName, + debugPayloadRootTag, debugTransformRootTag, binaryName) + << "\n"; + printModuleForRepro(llvm::dbgs(), root, transform); + llvm::dbgs() << "\nEOF\n"; + llvm::dbgs() << "===================================\n"; + }); + DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, { + saveReproToTempFile(llvm::dbgs(), target, transform, passName, + debugPayloadRootTag, debugTransformRootTag, binaryName); + }); +} + +LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( + Operation *target, StringRef passName, + const std::shared_ptr> &sharedTransformModule, + ArrayRef> extraMappings, + const TransformOptions &options, + const Pass::Option &transformFileName, + const Pass::Option &debugPayloadRootTag, + const Pass::Option &debugTransformRootTag, + StringRef binaryName) { + + // Step 1 + // ------ + // If debugPayloadRootTag was passed, then we are in user-specified selection + // of the transformed IR. This corresponds to REPL debug mode. Otherwise, just + // apply to `target`. + Operation *payloadRoot = target; + if (!debugPayloadRootTag.empty()) { + payloadRoot = findOpWithTag(target, kTransformDialectTagAttrName, + debugPayloadRootTag); + if (!payloadRoot) + return failure(); + } + + // Step 2 + // ------ + // If a shared transform was specified separately, use it. Otherwise, the + // transform is embedded in the payload IR. If debugTransformRootTag was + // passed, then we are in user-specified selection of the transforming IR. + // This corresponds to REPL debug mode. 
+ bool sharedTransform = (sharedTransformModule && *sharedTransformModule); + Operation *transformContainer = + sharedTransform ? sharedTransformModule->get() : target; + Operation *transformRoot = + debugTransformRootTag.empty() + ? findTopLevelTransform(transformContainer, + transformFileName.getArgStr()) + : findOpWithTag(transformContainer, kTransformDialectTagAttrName, + debugTransformRootTag); + if (!transformRoot) + return failure(); + + if (!transformRoot->hasTrait()) { + return emitError(transformRoot->getLoc()) + << "expected the transform entry point to be a top-level transform " + "op"; + } + + // Step 3 + // ------ + // Optionally perform debug actions requested by the user to dump IR and a + // repro to stderr and/or a file. + performOptionalDebugActions(target, transformRoot, passName, + debugPayloadRootTag, debugTransformRootTag, + binaryName); + + // Step 4 + // ------ + // Apply the transform to the IR + return applyTransforms(payloadRoot, cast(transformRoot), + extraMappings, options); +} + +LogicalResult transform::detail::interpreterBaseInitializeImpl( + MLIRContext *context, StringRef transformFileName, + std::shared_ptr> &module) { + OwningOpRef parsed; + if (failed(parseTransformModuleFromFile(context, transformFileName, parsed))) + return failure(); + + module = std::make_shared>(std::move(parsed)); + return success(); +} diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir @@ -0,0 +1,108 @@ +// RUN: mlir-opt -split-input-file --test-transform-dialect-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-unpack" %s | FileCheck %s + +func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> { + %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 
: tensor<1x1x4x8x8x32xf32> -> tensor<1x1x128x64xf32> + return %0 : tensor<1x1x128x64xf32> +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 32, 8] +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)> +// CHECK: func.func @KCRSsr_to_KCRS +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] = +// CHECK: %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] = +// CHECK: %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]]) +// CHECK: %[[IN_S:.+]] = affine.apply #[[MAP1]](%[[S]]) +// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]] +// CHECK-SAME: [0, 0, %[[IN_R]], %[[IN_S]], 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] +// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]] +// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x8x32xf32> to tensor<8x32xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[TILE]] +// CHECK-SAME: outs(%[[EMPTY]] +// CHECK-SAME: permutation = [1, 0] +// CHECK: %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}} + +// ----- + +func.func @unpack_and_extract_slice(%arg0: tensor<2x8x8x2xf32>, %arg1: tensor<13x15xf32>) -> tensor<13x15xf32> { + %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : tensor<2x8x8x2xf32> -> tensor<13x15xf32> + return %0 : tensor<13x15xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (-d0 + 13, 8)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 15, 2)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 floordiv 2)> +// CHECK: func.func 
@unpack_and_extract_slice +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %{{.+}} = scf.for %[[I:[a-zA-Z0-9]+]] = +// CHECK: %[[OUT_I_SZ:.+]] = affine.min #[[MAP0]](%[[I]]) +// CHECK: %{{.+}} = scf.for %[[J:[a-zA-Z0-9]+]] = +// CHECK: %[[OUT_J_SZ:.+]] = affine.min #[[MAP1]](%[[J]]) +// CHECK: %[[IN_I:.+]] = affine.apply #[[MAP2]](%[[I]]) +// CHECK: %[[IN_J:.+]] = affine.apply #[[MAP3]](%[[J]]) +// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]] +// CHECK-SAME: [%[[IN_I]], %[[IN_J]], 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] +// CHECK: %[[ITER_SLICE:.+]] = tensor.extract_slice %{{[a-zA-Z0-9]+}} +// CHECK-SAME: [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] +// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]] +// CHECK-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<1x1x8x2xf32> to tensor<8x2xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[TILE]] : tensor<8x2xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x2xf32>) +// CHECK-SAME: permutation = [0, 1] +// CHECK: %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TRANSP]] +// CHECK-SAME: [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1] +// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[UNPACK_TILE]] into %[[ITER_SLICE]] +// CHECK-SAME: [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1] +// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[INSERT1]] into %{{[a-zA-Z0-9]+}} +// CHECK-SAME: [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1] + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1, %loops:2 = transform.structured.tile_to_scf_for %0 [8, 2] +} + +// ----- + +func.func @CKkc_to_KC(%arg0: tensor<32x4x32x8xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> { + %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] 
into %arg1 : tensor<32x4x32x8xf32> -> tensor<128x256xf32> + return %0 : tensor<128x256xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)> +// CHECK: func.func @CKkc_to_KC +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %{{.+}} = scf.for %[[K:[a-zA-Z0-9]+]] = +// CHECK: %{{.+}} = scf.for %[[C:[a-zA-Z0-9]+]] = +// CHECK: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]]) +// CHECK: %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]]) +// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]] +// CHECK-SAME: [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] +// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]] +// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[TILE]] +// CHECK-SAME: outs(%[[EMPTY]] +// CHECK-SAME: permutation = [0, 1] +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}} +// CHECK-SAME: [%[[K]], %[[C]]] [32, 8] [1, 1] + + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1, %loops:2 = transform.structured.tile_to_scf_for %0 [32, 8] +} diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir --- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir +++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir @@ -55,114 +55,3 @@ // They have the same type, so the insert_slice op is folded // away. 
// CHECK: return %[[TRANSP]] - -// ----- - -// RUN: mlir-opt -split-input-file --test-transform-dialect-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-unpack" %s | FileCheck %s --check-prefix=CHECK-TRANS - -func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> { - %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x4x8x8x32xf32> -> tensor<1x1x128x64xf32> - return %0 : tensor<1x1x128x64xf32> -} - -transform.sequence failures(propagate) { - ^bb0(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 32, 8] -} -// CHECK-TRANS-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)> -// CHECK-TRANS-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)> -// CHECK-TRANS: func.func @KCRSsr_to_KCRS -// CHECK-TRANS-SAME: %[[SRC:[a-zA-Z0-9]+]] -// CHECK-TRANS-SAME: %[[DEST:[a-zA-Z0-9]+]] -// CHECK-TRANS: %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] = -// CHECK-TRANS: %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] = -// CHECK-TRANS: %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]]) -// CHECK-TRANS: %[[IN_S:.+]] = affine.apply #[[MAP1]](%[[S]]) -// CHECK-TRANS: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]] -// CHECK-TRANS-SAME: [0, 0, %[[IN_R]], %[[IN_S]], 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] -// CHECK-TRANS: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]] -// CHECK-TRANS-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x8x32xf32> to tensor<8x32xf32> -// CHECK-TRANS: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32> -// CHECK-TRANS: %[[TRANSP:.+]] = linalg.transpose -// CHECK-TRANS-SAME: ins(%[[TILE]] -// CHECK-TRANS-SAME: outs(%[[EMPTY]] -// CHECK-TRANS-SAME: permutation = [1, 0] -// CHECK-TRANS: %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}} - -// ----- - -func.func 
@unpack_and_extract_slice(%arg0: tensor<2x8x8x2xf32>, %arg1: tensor<13x15xf32>) -> tensor<13x15xf32> { - %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : tensor<2x8x8x2xf32> -> tensor<13x15xf32> - return %0 : tensor<13x15xf32> -} -// CHECK-TRANS-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (-d0 + 13, 8)> -// CHECK-TRANS-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 15, 2)> -// CHECK-TRANS-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)> -// CHECK-TRANS-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 floordiv 2)> -// CHECK-TRANS: func.func @unpack_and_extract_slice -// CHECK-TRANS-SAME: %[[SRC:[a-zA-Z0-9]+]] -// CHECK-TRANS-SAME: %[[DEST:[a-zA-Z0-9]+]] -// CHECK-TRANS: %{{.+}} = scf.for %[[I:[a-zA-Z0-9]+]] = -// CHECK-TRANS: %[[OUT_I_SZ:.+]] = affine.min #[[MAP0]](%[[I]]) -// CHECK-TRANS: %{{.+}} = scf.for %[[J:[a-zA-Z0-9]+]] = -// CHECK-TRANS: %[[OUT_J_SZ:.+]] = affine.min #[[MAP1]](%[[J]]) -// CHECK-TRANS: %[[IN_I:.+]] = affine.apply #[[MAP2]](%[[I]]) -// CHECK-TRANS: %[[IN_J:.+]] = affine.apply #[[MAP3]](%[[J]]) -// CHECK-TRANS: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]] -// CHECK-TRANS-SAME: [%[[IN_I]], %[[IN_J]], 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] -// CHECK-TRANS: %[[ITER_SLICE:.+]] = tensor.extract_slice %{{[a-zA-Z0-9]+}} -// CHECK-TRANS-SAME: [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] -// CHECK-TRANS: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]] -// CHECK-TRANS-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<1x1x8x2xf32> to tensor<8x2xf32> -// CHECK-TRANS: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32> -// CHECK-TRANS: %[[TRANSP:.+]] = linalg.transpose -// CHECK-TRANS-SAME: ins(%[[TILE]] : tensor<8x2xf32>) -// CHECK-TRANS-SAME: outs(%[[EMPTY]] : tensor<8x2xf32>) -// CHECK-TRANS-SAME: permutation = [0, 1] -// CHECK-TRANS: %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TRANSP]] -// CHECK-TRANS-SAME: [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1] -// CHECK-TRANS: %[[INSERT1:.+]] = tensor.insert_slice 
%[[UNPACK_TILE]] into %[[ITER_SLICE]] -// CHECK-TRANS-SAME: [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1] -// CHECK-TRANS: %[[INSERT2:.+]] = tensor.insert_slice %[[INSERT1]] into %{{[a-zA-Z0-9]+}} -// CHECK-TRANS-SAME: [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1] - -transform.sequence failures(propagate) { - ^bb0(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %1, %loops:2 = transform.structured.tile_to_scf_for %0 [8, 2] -} - -// ----- - -func.func @CKkc_to_KC(%arg0: tensor<32x4x32x8xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> { - %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x4x32x8xf32> -> tensor<128x256xf32> - return %0 : tensor<128x256xf32> -} -// CHECK-TRANS-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)> -// CHECK-TRANS-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)> -// CHECK-TRANS: func.func @CKkc_to_KC -// CHECK-TRANS-SAME: %[[SRC:[a-zA-Z0-9]+]] -// CHECK-TRANS-SAME: %[[DEST:[a-zA-Z0-9]+]] -// CHECK-TRANS: %{{.+}} = scf.for %[[K:[a-zA-Z0-9]+]] = -// CHECK-TRANS: %{{.+}} = scf.for %[[C:[a-zA-Z0-9]+]] = -// CHECK-TRANS: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]]) -// CHECK-TRANS: %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]]) -// CHECK-TRANS: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]] -// CHECK-TRANS-SAME: [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] -// CHECK-TRANS: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]] -// CHECK-TRANS-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32> -// CHECK-TRANS: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32> -// CHECK-TRANS: %[[TRANSP:.+]] = linalg.transpose -// CHECK-TRANS-SAME: ins(%[[TILE]] -// CHECK-TRANS-SAME: outs(%[[EMPTY]] -// CHECK-TRANS-SAME: permutation = [0, 1] -// CHECK-TRANS: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}} -// 
CHECK-TRANS-SAME: [%[[K]], %[[C]]] [32, 8] [1, 1] - - -transform.sequence failures(propagate) { - ^bb0(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %1, %loops:2 = transform.structured.tile_to_scf_for %0 [32, 8] -} diff --git a/mlir/test/Dialect/Linalg/generalize-tesnor-pack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tesnor-pack-tile.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/generalize-tesnor-pack-tile.mlir @@ -0,0 +1,103 @@ +// RUN: mlir-opt -split-input-file --test-transform-dialect-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-pack" %s | FileCheck %s + +func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8x32xf32>) -> tensor<1x1x4x8x8x32xf32> { + %0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x128x64xf32> -> tensor<1x1x4x8x8x32xf32> + return %0 : tensor<1x1x4x8x8x32xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 32)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 * -8 + 64, 8)> +// CHECK: func.func @KCRS_to_KCRSsr +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] = +// CHECK: %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] = +// CHECK: %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]]) +// CHECK: %[[IN_R_SZ:.+]] = affine.min #[[MAP1]](%[[R]]) +// CHECK: %[[IN_S:.+]] = affine.apply #[[MAP2]](%[[S]]) +// CHECK: %[[IN_S_SZ:.+]] = affine.min #[[MAP3]](%[[S]]) +// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]] +// CHECK-SAME: [0, 0, %[[IN_R]], %[[IN_S]]] [1, 1, %[[IN_R_SZ]], %[[IN_S_SZ]]] [1, 1, 1, 1] +// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]] +// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : 
tensor<1x1x?x?xf32> to tensor<32x8xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[TILE]] +// CHECK-SAME: outs(%[[EMPTY]] +// CHECK-SAME: permutation = [1, 0] +// CHECK: %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 1, 1] +} + +// ----- + +func.func @pad_and_pack(%arg0: tensor<13x15xf32>, %arg1: tensor<2x8x8x2xf32>, %arg2: f32) -> tensor<2x8x8x2xf32> { + %0 = tensor.pack %arg0 padding_value(%arg2 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : tensor<13x15xf32> -> tensor<2x8x8x2xf32> + return %0 : tensor<2x8x8x2xf32> +} +// CHECK: func.func @pad_and_pack +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]] +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[SRC_SLICE]] = tensor.extract_slice %[[SRC]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[SRC_SLICE]] +// CHECK: tensor.yield %[[PAD_VAL]] +// CHECK: } : tensor to tensor<8x2xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[PAD]] : tensor<8x2xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x2xf32>) +// CHECK-SAME: permutation = [0, 1] +// CHECK: %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1, %loops:2 = transform.structured.tile_to_scf_for %0 [1, 1] +} + +// ----- + + +func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>) -> tensor<32x4x32x8xf32> { + %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = 
[0, 1] inner_tiles = [32, 8] into %arg1 : tensor<128x256xf32> -> tensor<32x4x32x8xf32> + return %0 : tensor<32x4x32x8xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 32)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 * -8 + 256, 8)> +// CHECK: func.func @KC_to_CKkc +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %{{.+}} = scf.for %[[C:[a-zA-Z0-9]+]] = +// CHECK: %{{.+}} = scf.for %[[K:[a-zA-Z0-9]+]] = +// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]]) +// CHECK-DAG: %[[IN_K_SZ:.+]] = affine.min #[[MAP1]](%[[K]]) +// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]]) +// CHECK-DAG: %[[IN_C_SZ:.+]] = affine.min #[[MAP3]](%[[C]]) +// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]] +// CHECK-SAME: [%[[IN_K]], %[[IN_C]]] [%[[IN_K_SZ]], %[[IN_C_SZ]]] [1, 1] +// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]] +// CHECK-SAME: [0, 0] [32, 8] [1, 1] : tensor to tensor<32x8xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[TILE]] +// CHECK-SAME: outs(%[[EMPTY]] +// CHECK-SAME: permutation = [0, 1] +// CHECK: %[[SUB_ITER:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}} +// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<32x8xf32> into tensor<1x1x32x8xf32> +// CHECK: %{{.+}} = tensor.insert_slice %[[SUB_ITER]] into %{{[a-zA-Z0-9]+}} +// CHECK-SAME: [%[[C]], %[[K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> into tensor<32x4x32x8xf32> +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1, %loops:2 = transform.structured.tile_to_scf_for %0 [1, 1] +} diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir 
b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -1,7 +1,5 @@ // RUN: mlir-opt %s -test-transform-dialect-interpreter -test-linalg-transform-patterns=test-patterns -split-input-file | FileCheck %s -// ----- - func.func @dot(%x: memref>, %y: memref>, %v: memref) { diff --git a/mlir/test/Dialect/Transform/test-interpreter-debug.mlir b/mlir/test/Dialect/Transform/test-interpreter-debug.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-debug.mlir @@ -0,0 +1,68 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{debug-payload-root-tag=payload debug-transform-root-tag=transform})" \ +// RUN: --allow-unregistered-dialect --split-input-file --verify-diagnostics + +// expected-error @below {{could not find the operation with transform.target_tag="payload" attribute}} +module { + transform.sequence failures(suppress) { + ^bb0(%arg0: !transform.any_op): + } +} + +// ----- + +// expected-error @below {{could not find the operation with transform.target_tag="transform" attribute}} +module { + transform.sequence failures(suppress) { + ^bb0(%arg0: !transform.any_op): + } + + module attributes {transform.target_tag="payload"} {} +} + +// ----- + +// expected-error @below {{more than one operation with transform.target_tag="transform" attribute}} +module { + // expected-note @below {{first operation}} + transform.sequence failures(propagate) attributes {transform.target_tag="transform"} { + ^bb0(%arg0: !transform.any_op): + } + + // expected-note @below {{other operation}} + transform.sequence failures(propagate) attributes {transform.target_tag="transform"} { + ^bb0(%arg0: !transform.any_op): + } + + module attributes {transform.target_tag="payload"} {} +} + +// ----- + +module { + // expected-error @below {{expected the transform entry point to be a top-level transform op}} + func.func 
private @foo() attributes {transform.target_tag="transform"} + + module attributes {transform.target_tag="payload"} {} +} + +// ----- + +module { + transform.sequence failures(suppress) attributes {transform.target_tag="transform"} { + ^bb0(%arg0: !transform.any_op): + transform.test_print_remark_at_operand %arg0, "payload" : !transform.any_op + } + + // This will not be executed because it's not tagged. + transform.sequence failures(suppress) { + ^bb0(%arg0: !transform.any_op): + transform.test_print_remark_at_operand %arg0, "some other text that is not printed" : !transform.any_op + } + + module { + module {} + // expected-remark @below {{payload}} + module attributes {transform.target_tag="payload"} {} + module {} + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-source.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-source.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-external-source.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt %s +// No need to check anything else than parsing here, this is being used by another test as data. 
+ +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + transform.test_print_remark_at_operand %arg0, "outer" : !transform.any_op + transform.sequence %arg0 : !transform.any_op failures(propagate) attributes {transform.target_tag="transform"} { + ^bb1(%arg1: !transform.any_op): + transform.test_print_remark_at_operand %arg1, "inner" : !transform.any_op + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter-external.mlir b/mlir/test/Dialect/Transform/test-interpreter-external.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-external.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-source.mlir})" \ +// RUN: --verify-diagnostics + +// The schedule in the separate file emits remarks at the payload root. + +// expected-remark @below {{outer}} +// expected-remark @below {{inner}} +module {} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1091,3 +1091,22 @@ // expected-error @below {{attempting to assign a null parameter to this transform value}} %0 = transform.test_produce_null_param : !transform.param } + +// ----- + +// expected-error @below {{could not find a nested top-level transform op}} +// expected-note @below {{use the 'transform-file-name' option to provide transform as external file}} +module { +} + +// ----- + +// expected-note @below {{previous top-level transform op}} +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): +} + +// expected-error @below {{ore than one top-level transform op}} +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): +} diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt 
--- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt @@ -19,4 +19,5 @@ MLIRPass MLIRPDLDialect MLIRTransformDialect + MLIRTransformDialectTransforms ) diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" @@ -21,16 +22,22 @@ namespace { /// Simple pass that applies transform dialect ops directly contained in a /// module. + +template +class ModulePassWrapper : public PassWrapper> { +}; + class TestTransformDialectInterpreterPass - : public PassWrapper> { + : public transform::TransformInterpreterPassBase< + TestTransformDialectInterpreterPass, ModulePassWrapper> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestTransformDialectInterpreterPass) TestTransformDialectInterpreterPass() = default; TestTransformDialectInterpreterPass( - const TestTransformDialectInterpreterPass &) {} + const TestTransformDialectInterpreterPass &pass) + : TransformInterpreterPassBase(pass) {} StringRef getArgument() const override { return "test-transform-dialect-interpreter"; @@ -101,15 +108,12 @@ getContext(), bindSecondExtraToParams, extraMappingStorage)); } - ModuleOp module = getOperation(); - for (auto op : - module.getBody()->getOps()) { - if (failed(transform::applyTransforms( - module, op, extraMapping, - transform::TransformOptions().enableExpensiveChecks( - enableExpensiveChecks)))) - return signalPassFailure(); - } + options = 
options.enableExpensiveChecks(enableExpensiveChecks); + if (failed(transform::detail::interpreterBaseRunOnOperationImpl( + getOperation(), getArgument(), getSharedTransformModule(), + extraMapping, options, transformFileName, debugPayloadRootTag, + debugTransformRootTag, getBinaryName()))) + return signalPassFailure(); } Option enableExpensiveChecks{ @@ -134,6 +138,25 @@ *this, "bind-second-extra-to-params", llvm::cl::desc("bind the second extra argument of the top-level op to " "the given integer parameters")}; + Option transformFileName{ + *this, "transform-file-name", llvm::cl::init(""), + llvm::cl::desc( + "Optional filename containing a transform dialect specification to " + "apply. If left empty, the IR is assumed to contain one top-level " + "transform dialect operation somewhere in the module.")}; + Option debugPayloadRootTag{ + *this, "debug-payload-root-tag", llvm::cl::init(""), + llvm::cl::desc( + "Select the operation with 'transform.target_tag' attribute having " + "the given value as payload IR root. If empty select the pass anchor " + "operation as the payload IR root.")}; + Option debugTransformRootTag{ + *this, "debug-transform-root-tag", llvm::cl::init(""), + llvm::cl::desc( + "Select the operation with 'transform.target_tag' attribute having " + "the given value as container IR for top-level transform ops. This " + "allows user control on what transformation to apply. 
If empty, " + "select the container of the top-level transform op.")}; }; struct TestTransformDialectEraseSchedulePass diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -9204,8 +9204,10 @@ deps = [ ":Analysis", ":IR", + ":Parser", ":Pass", ":SideEffectInterfaces", + ":Support", ":TransformDialect", ":TransformDialectTransformsIncGen", "//llvm:Support", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -319,6 +319,7 @@ "//mlir:PDLDialect", "//mlir:Pass", "//mlir:TransformDialect", + "//mlir:TransformDialectTransforms", ], ) diff --git a/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel @@ -13,7 +13,12 @@ "//mlir:mlir-opt", "//mlir:mlir-translate", "//mlir/test:lit_data", - ], + ] + glob([ + "Transform/*-source.mlir", + ]) + ) + for src in glob( + include=["**/*.mlir"], + exclude=["Transform/*-source.mlir"] ) - for src in glob(["**/*.mlir"]) ]