diff --git a/mlir/include/mlir/Dialect/SCF/CMakeLists.txt b/mlir/include/mlir/Dialect/SCF/CMakeLists.txt --- a/mlir/include/mlir/Dialect/SCF/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SCF/CMakeLists.txt @@ -7,3 +7,5 @@ add_dependencies(mlir-headers MLIRSCFPassIncGen) add_mlir_doc(Passes SCFPasses ./ -gen-pass-doc) + +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/SCF/Patterns.h b/mlir/include/mlir/Dialect/SCF/Patterns.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SCF/Patterns.h @@ -0,0 +1,54 @@ +//===- Patterns.h - SCF dialect rewrite patterns ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SCF_PATTERNS_H +#define MLIR_DIALECT_SCF_PATTERNS_H + +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace scf { +/// Generate a pipelined version of the scf.for loop based on the schedule given +/// as option. This applies the mechanical transformation of changing the loop +/// and generating the prologue/epilogue for the pipelining and doesn't make any +/// decision regarding the schedule. +/// Based on the options the loop is split into several stages. +/// The transformation assumes that the scheduling given by user is valid. 
+/// For example if we break a loop into 3 stages named S0, S1, S2 we would +/// generate the following code with the number in parenthesis as the iteration +/// index: +/// S0(0) // Prologue +/// S0(1) S1(0) // Prologue +/// scf.for %I = %C0 to %N - 2 { +/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel +/// } +/// S1(N) S2(N-1) // Epilogue +/// S2(N) // Epilogue +class ForLoopPipeliningPattern : public OpRewritePattern { +public: + ForLoopPipeliningPattern(const PipeliningOption &options, + MLIRContext *context) + : OpRewritePattern(context), options(options) {} + LogicalResult matchAndRewrite(ForOp forOp, + PatternRewriter &rewriter) const override { + return returningMatchAndRewrite(forOp, rewriter); + } + + FailureOr returningMatchAndRewrite(ForOp forOp, + PatternRewriter &rewriter) const; + +protected: + PipeliningOption options; +}; + +} // namespace scf +} // namespace mlir + +#endif // MLIR_DIALECT_SCF_PATTERNS_H diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS SCFTransformOps.td) +mlir_tablegen(SCFTransformOps.h.inc -gen-op-decls) +mlir_tablegen(SCFTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRSCFTransformOpsIncGen) diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h @@ -0,0 +1,36 @@ +//===- SCFTransformOps.h - SCF transformation ops ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H +#define MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H + +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +namespace mlir { +namespace func { +class FuncOp; +} // namespace func +namespace scf { +class ForOp; +} // namespace scf +} // namespace mlir + +#define GET_OP_CLASSES +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace scf { +void registerTransformDialectExtension(DialectRegistry ®istry); +} // namespace scf +} // namespace mlir + +#endif // MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -0,0 +1,144 @@ +//===- SCFTransformOps.td - SCF (loop) transformation ops --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef SCF_TRANSFORM_OPS +#define SCF_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformEffects.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/PDL/IR/PDLTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" + +def GetParentForOp : Op]> { + let summary = "Gets a handle to the parent 'for' loop of the given operation"; + let description = [{ + Produces a handle to the n-th (default 1) parent `scf.for` loop for each + Payload IR operation associated with the operand. Fails if such a loop + cannot be found. The list of operations associated with the handle contains + parent operations in the same order as the list associated with the operand, + except for operations that are parents to more than one input which are only + present once. + }]; + + let arguments = + (ins PDL_Operation:$target, + DefaultValuedAttr, + "1">:$num_loops); + let results = (outs PDL_Operation:$parent); + + let assemblyFormat = "$target attr-dict"; +} + +def LoopOutlineOp : Op]> { + let summary = "Outlines a loop into a named function"; + let description = [{ + Moves the loop into a separate function with the specified name and + replaces the loop in the Payload IR with a call to that function. Takes + care of forwarding values that are used in the loop as function arguments. + If the operand is associated with more than one loop, each loop will be + outlined into a separate function. The provided name is used as a _base_ + for forming actual function names following SymbolTable auto-renaming + scheme to avoid duplicate symbols. Expects that all ops in the Payload IR + have a SymbolTable ancestor (typically true because of the top-level + module). 
Returns the handle to the list of outlined functions in the same + order as the operand handle. + }]; + + let arguments = (ins PDL_Operation:$target, + StrAttr:$func_name); + let results = (outs PDL_Operation:$transformed); + + let assemblyFormat = "$target attr-dict"; +} + +def LoopPeelOp : Op { + let summary = "Peels the last iteration of the loop"; + let description = [{ + Updates the given loop so that its step evenly divides its range and puts + the remaining iteration into a separate loop or a conditional. Note that + even though the Payload IR modification may be performed in-place, this + operation consumes the operand handle and produces a new one. Applies to + each loop associated with the operand handle individually. The results + follow the same order as the operand. + + Note: If it can be proven statically that the step already evenly divides + the range, this op is a no-op. In the absence of sufficient static + information, this op may peel a loop, even if the step always divides the + range evenly at runtime. + }]; + + let arguments = + (ins PDL_Operation:$target, + DefaultValuedAttr:$fail_if_already_divisible); + let results = (outs PDL_Operation:$transformed); + + let assemblyFormat = "$target attr-dict"; + + let extraClassDeclaration = [{ + ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop); + }]; +} + +def LoopPipelineOp : Op { + let summary = "Applies software pipelining to the loop"; + let description = [{ + Transforms the given loops one by one to achieve software pipelining for + each of them. That is, performs some amount of reads from memory before the + loop rather than inside the loop, the same amount of writes into memory + after the loop, and updates each iteration to read the data for a following + iteration rather than the current one. The amount is specified by the + attributes. The values read and about to be stored are transferred as loop + iteration arguments. 
Currently supports memref and vector transfer + operations as memory reads/writes. + }]; + + let arguments = (ins PDL_Operation:$target, + DefaultValuedAttr:$iteration_interval, + DefaultValuedAttr:$read_latency); + let results = (outs PDL_Operation:$transformed); + + let assemblyFormat = "$target attr-dict"; + + let extraClassDeclaration = [{ + ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop); + }]; +} + +def LoopUnrollOp : Op { + let summary = "Unrolls the given loop with the given unroll factor"; + let description = [{ + Unrolls each loop associated with the given handle to have up to the given + number of loop body copies per iteration. If the unroll factor is larger + than the loop trip count, the latter is used as the unroll factor instead. + Does not produce a new handle as the operation may result in the loop being + removed after a full unrolling. + }]; + + let arguments = (ins PDL_Operation:$target, + Confined:$factor); + + let assemblyFormat = "$target attr-dict"; + + let extraClassDeclaration = [{ + ::mlir::LogicalResult applyToOne(::mlir::scf::ForOp loop); + }]; +} + +#endif // SCF_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h --- a/mlir/include/mlir/Dialect/SCF/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms.h @@ -158,26 +158,8 @@ // TODO: add option to decide if the prologue should be peeled. }; -/// Populate patterns for SCF software pipelining transformation. -/// This transformation generates the pipelined loop and doesn't do any -/// assumptions on the schedule dictated by the option structure. -/// Software pipelining is usually done in two part. The first part of -/// pipelining is to schedule the loop and assign a stage and cycle to each -/// operations. This is highly dependent on the target and is implemented as an -/// heuristic based on operation latencies, and other hardware characteristics. 
-/// The second part is to take the schedule and generate the pipelined loop as -/// well as the prologue and epilogue. It is independent of the target. -/// This pattern only implement the second part. -/// For example if we break a loop into 3 stages named S0, S1, S2 we would -/// generate the following code with the number in parenthesis the iteration -/// index: -/// S0(0) // Prologue -/// S0(1) S1(0) // Prologue -/// scf.for %I = %C0 to %N - 2 { -/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel -/// } -/// S1(N) S2(N-1) // Epilogue -/// S2(N) // Epilogue +/// Populate patterns for SCF software pipelining transformation. See the +/// ForLoopPipeliningPattern for the transformation details. void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, const PipeliningOption &options); diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -28,6 +28,7 @@ class Value; namespace func { +class CallOp; class FuncOp; } // namespace func @@ -63,12 +64,13 @@ /// `outlinedFuncBody` to alloc simple canonicalizations. /// Creates a new FuncOp and thus cannot be used in a FuncOp pass. /// The client is responsible for providing a unique `funcName` that will not -/// collide with another FuncOp name. +/// collide with another FuncOp name. If `callOp` is provided, it will be set +/// to point to the operation that calls the outlined function. // TODO: support more than single-block regions. // TODO: more flexible constant handling. 
-FailureOr outlineSingleBlockRegion(RewriterBase &rewriter, - Location loc, Region ®ion, - StringRef funcName); +FailureOr +outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region ®ion, + StringRef funcName, func::CallOp *callOp = nullptr); /// Outline the then and/or else regions of `ifOp` as follows: /// - if `thenFn` is not null, `thenFnName` must be specified and the `then` diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -451,7 +451,7 @@ StringRef getName() override { return "transform.payload_ir"; } }; -/// Trait implementing the MemoryEffectOpInterface for single-operand +/// Trait implementing the MemoryEffectOpInterface for single-operand zero- or /// single-result operations that "consume" their operand and produce a new /// result. template @@ -468,6 +468,47 @@ effects.emplace_back(MemoryEffects::Free::get(), this->getOperation()->getOperand(0), TransformMappingResource::get()); + if (this->getOperation()->getNumResults() == 1) { + effects.emplace_back(MemoryEffects::Allocate::get(), + this->getOperation()->getResult(0), + TransformMappingResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), + this->getOperation()->getResult(0), + TransformMappingResource::get()); + } + effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); + } + + /// Checks that the op matches the expectations of this trait. 
+ static LogicalResult verifyTrait(Operation *op) { + static_assert(OpTy::template hasTrait(), + "expected single-operand op"); + static_assert(OpTy::template hasTrait() || + OpTy::template hasTrait(), + "expected zero- or single-result op"); + if (!op->getName().getInterface()) { + op->emitError() + << "FunctionalStyleTransformOpTrait should only be attached to ops " + "that implement MemoryEffectOpInterface"; + } + return success(); + } +}; + +/// Trait implementing the MemoryEffectOpInterface for single-operand +/// single-result operations that use their operand without consuming and +/// without modifying the Payload IR to produce a new handle. +template +class NavigationTransformOpTrait + : public OpTrait::TraitBase { +public: + /// This op produces handles to the Payload IR without consuming the original + /// handles and without modifying the IR itself. + void getEffects(SmallVectorImpl &effects) { + effects.emplace_back(MemoryEffects::Read::get(), + this->getOperation()->getOperand(0), + TransformMappingResource::get()); effects.emplace_back(MemoryEffects::Allocate::get(), this->getOperation()->getResult(0), TransformMappingResource::get()); @@ -475,19 +516,17 @@ this->getOperation()->getResult(0), TransformMappingResource::get()); effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); } - /// Checks that the op matches the expectations of this trait. + /// Checks that the op matches the expectation of this trait. 
static LogicalResult verifyTrait(Operation *op) { static_assert(OpTy::template hasTrait(), "expected single-operand op"); static_assert(OpTy::template hasTrait(), "expected single-result op"); if (!op->getName().getInterface()) { - op->emitError() - << "FunctionalStyleTransformOpTrait should only be attached to ops " - "that implement MemoryEffectOpInterface"; + op->emitError() << "NavigationTransformOpTrait should only be attached " + "to ops that implement MemoryEffectOpInterface"; } return success(); } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -47,6 +47,19 @@ "::mlir::transform::TransformState &":$state )>, ]; + + let extraSharedClassDeclaration = [{ + /// Emits a generic transform error for the current transform operation + /// targeting the given Payload IR operation and returns failure. Should + /// be only used as a last resort when the transformation itself provides + /// no further indication as to the reason of the failure. 
+ ::mlir::LogicalResult reportUnknownTransformError( + ::mlir::Operation *target) { + ::mlir::InFlightDiagnostic diag = $_op->emitError() << "failed to apply"; + diag.attachNote(target->getLoc()) << "attempted to apply to this op"; + return diag; + } + }]; } def FunctionalStyleTransformOpTrait @@ -58,4 +71,8 @@ let cppNamespace = "::mlir::transform"; } +def NavigationTransformOpTrait : NativeOpTrait<"NavigationTransformOpTrait"> { + let cppNamespace = "::mlir::transform"; +} + #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -19,7 +19,7 @@ def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent", [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + NavigationTransformOpTrait, MemoryEffectsOpInterface]> { let summary = "Gets handles to the closest isolated-from-above parents"; let description = [{ The handles defined by this Transform op correspond to the closest isolated diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -47,6 +47,7 @@ #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" @@ -107,6 +108,7 @@ // Register all dialect extensions. linalg::registerTransformDialectExtension(registry); + scf::registerTransformDialectExtension(registry); // Register all external models. 
arith::registerBufferizableOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -89,9 +89,7 @@ if (succeeded(depthwise)) return depthwise; - InFlightDiagnostic diag = emitError() << "failed to apply"; - diag.attachNote(target.getLoc()) << "attempted to apply to this op"; - return diag; + return reportUnknownTransformError(target); } //===----------------------------------------------------------------------===// @@ -107,9 +105,7 @@ if (succeeded(generic)) return generic; - InFlightDiagnostic diag = emitError() << "failed to apply"; - diag.attachNote(target.getLoc()) << "attempted to apply to this op"; - return diag; + return reportUnknownTransformError(target); } //===----------------------------------------------------------------------===// @@ -416,11 +412,8 @@ if (getVectorizePadding()) linalg::populatePadOpVectorizationPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) { - InFlightDiagnostic diag = emitError() << "failed to apply"; - diag.attachNote(target->getLoc()) << "target op"; - return diag; - } + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return reportUnknownTransformError(target); return target; } diff --git a/mlir/lib/Dialect/SCF/CMakeLists.txt b/mlir/lib/Dialect/SCF/CMakeLists.txt --- a/mlir/lib/Dialect/SCF/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/CMakeLists.txt @@ -16,5 +16,6 @@ MLIRSideEffectInterfaces ) +add_subdirectory(TransformOps) add_subdirectory(Transforms) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt @@ -0,0 +1,20 @@ 
+add_mlir_dialect_library(MLIRSCFTransformOps + SCFTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF/TransformOps + + DEPENDS + MLIRSCFTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRAffine + MLIRFunc + MLIRIR + MLIRPDL + MLIRSCF + MLIRSCFTransforms + MLIRSCFUtils + MLIRTransformDialect + MLIRVector +) diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -0,0 +1,232 @@ +//===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/Patterns.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +using namespace mlir; + +namespace { +/// A simple pattern rewriter that implements no special logic. 
+class SimpleRewriter : public PatternRewriter { +public: + SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} +}; +} // namespace + +//===----------------------------------------------------------------------===// +// GetParentForOp +//===----------------------------------------------------------------------===// + +LogicalResult +transform::GetParentForOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + SetVector parents; + for (Operation *target : state.getPayloadOps(getTarget())) { + scf::ForOp loop; + Operation *current = target; + for (unsigned i = 0, e = getNumLoops(); i < e; ++i) { + loop = current->getParentOfType(); + if (!loop) { + InFlightDiagnostic diag = emitError() << "could not find an '" + << scf::ForOp::getOperationName() + << "' parent"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + current = loop; + } + parents.insert(loop); + } + results.set(getResult().cast(), parents.getArrayRef()); + return success(); +} + +//===----------------------------------------------------------------------===// +// LoopOutlineOp +//===----------------------------------------------------------------------===// + +/// Wraps the given operation `op` into an `scf.execute_region` operation. Uses +/// the provided rewriter for all operations to remain compatible with the +/// rewriting infra, as opposed to just splicing the op in place. 
+static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
+                                                Operation *op) {
+  // Only single-region ops can be wrapped; a null op signals failure and
+  // leaves the IR untouched.
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  scf::ExecuteRegionOp executeRegionOp =
+      b.create(op->getLoc(), op->getResultTypes());
+  {
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
+    // Clone the op without its region, then move the original region into the
+    // still-empty region of the clone.
+    Operation *clonedOp = b.cloneWithoutRegions(*op);
+    Region &clonedRegion = clonedOp->getRegions().front();
+    assert(clonedRegion.empty() && "expected empty region");
+    b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
+                         clonedRegion.end());
+    // Forward the results of the clone out of the wrapper block.
+    b.create(op->getLoc(), clonedOp->getResults());
+  }
+  // `op` is replaced through the rewriter here; callers must not dereference
+  // it after a successful return.
+  b.replaceOp(op, executeRegionOp.getResults());
+  return executeRegionOp;
+}
+
+/// Outlines every payload op associated with the target handle into its own
+/// function and associates the result handle with the outlined functions, in
+/// the order of the operand handle. Emits an error and returns failure as soon
+/// as one payload op cannot be wrapped or outlined.
+LogicalResult
+transform::LoopOutlineOp::apply(transform::TransformResults &results,
+                                transform::TransformState &state) {
+  SmallVector transformed;
+  // One symbol table per enclosing symbol-table op, created lazily and reused
+  // across payload ops.
+  DenseMap symbolTables;
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    // Capture everything needed from `target` up front: it is replaced by
+    // wrapInExecuteRegion below and must not be touched afterwards.
+    Location location = target->getLoc();
+    Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
+    SimpleRewriter rewriter(getContext());
+    scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
+    if (!exec) {
+      InFlightDiagnostic diag = emitError() << "failed to outline";
+      diag.attachNote(location) << "target op";
+      return diag;
+    }
+    func::CallOp call;
+    FailureOr outlined = outlineSingleBlockRegion(
+        rewriter, location, exec.getRegion(), getFuncName(), &call);
+
+    if (failed(outlined)) {
+      // Same diagnostic as reportUnknownTransformError, but built from the
+      // saved location because `target` no longer exists at this point.
+      InFlightDiagnostic diag = emitError() << "failed to apply";
+      diag.attachNote(location) << "attempted to apply to this op";
+      return diag;
+    }
+
+    if (symbolTableOp) {
+      SymbolTable &symbolTable =
+          symbolTables.try_emplace(symbolTableOp, symbolTableOp)
+              .first->getSecond();
+      // Insert the function into the symbol table, then point the call at the
+      // function so the callee reference matches the symbol after insertion.
+      symbolTable.insert(*outlined);
+      call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
+    }
+    transformed.push_back(*outlined);
+  }
+  results.set(getTransformed().cast(), transformed);
+  return success();
+}
+
+//===----------------------------------------------------------------------===// +// LoopPeelOp +//===----------------------------------------------------------------------===// + +FailureOr transform::LoopPeelOp::applyToOne(scf::ForOp loop) { + scf::ForOp result; + IRRewriter rewriter(loop->getContext()); + LogicalResult status = + scf::peelAndCanonicalizeForLoop(rewriter, loop, result); + if (failed(status)) { + if (getFailIfAlreadyDivisible()) + return reportUnknownTransformError(loop); + return loop; + } + return result; +} + +//===----------------------------------------------------------------------===// +// LoopPipelineOp +//===----------------------------------------------------------------------===// + +/// Callback for PipeliningOption. Populates `schedule` with the mapping from an +/// operation to its logical time position given the iteration interval and the +/// read latency. The latter is only relevant for vector transfers. +static void +loopScheduling(scf::ForOp forOp, + std::vector> &schedule, + unsigned iterationInterval, unsigned readLatency) { + auto getLatency = [&](Operation *op) -> unsigned { + if (isa(op)) + return readLatency; + return 1; + }; + + DenseMap opCycles; + std::map> wrappedSchedule; + for (Operation &op : forOp.getBody()->getOperations()) { + if (isa(op)) + continue; + unsigned earlyCycle = 0; + for (Value operand : op.getOperands()) { + Operation *def = operand.getDefiningOp(); + if (!def) + continue; + earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def)); + } + opCycles[&op] = earlyCycle; + wrappedSchedule[earlyCycle % iterationInterval].push_back(&op); + } + for (auto it : wrappedSchedule) { + for (Operation *op : it.second) { + unsigned cycle = opCycles[op]; + schedule.push_back(std::make_pair(op, cycle / iterationInterval)); + } + } +} + +FailureOr transform::LoopPipelineOp::applyToOne(scf::ForOp loop) { + scf::PipeliningOption options; + options.getScheduleFn = + [this](scf::ForOp forOp, + std::vector> 
&schedule) mutable { + loopScheduling(forOp, schedule, getIterationInterval(), + getReadLatency()); + }; + + scf::ForLoopPipeliningPattern pattern(options, loop->getContext()); + SimpleRewriter rewriter(getContext()); + rewriter.setInsertionPoint(loop); + FailureOr patternResult = + pattern.returningMatchAndRewrite(loop, rewriter); + if (failed(patternResult)) + return reportUnknownTransformError(loop); + return patternResult; +} + +//===----------------------------------------------------------------------===// +// LoopUnrollOp +//===----------------------------------------------------------------------===// + +LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop) { + if (failed(loopUnrollByFactor(loop, getFactor()))) + return reportUnknownTransformError(loop); + return success(); +} + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +class SCFTransformDialectExtension + : public transform::TransformDialectExtension< + SCFTransformDialectExtension> { +public: + SCFTransformDialectExtension() { + declareDependentDialect(); + declareDependentDialect(); + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" + >(); + } +}; +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" + +void mlir::scf::registerTransformDialectExtension(DialectRegistry ®istry) { + registry.addExtensions(); +} diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -12,6 +12,7 @@ #include "PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/SCF/Patterns.h" #include "mlir/Dialect/SCF/SCF.h" #include 
"mlir/Dialect/SCF/Transforms.h" #include "mlir/Dialect/SCF/Utils/Utils.h" @@ -436,76 +437,53 @@ it->second[idx] = el; } -/// Generate a pipelined version of the scf.for loop based on the schedule given -/// as option. This applies the mechanical transformation of changing the loop -/// and generating the prologue/epilogue for the pipelining and doesn't make any -/// decision regarding the schedule. -/// Based on the option the loop is split into several stages. -/// The transformation assumes that the scheduling given by user is valid. -/// For example if we break a loop into 3 stages named S0, S1, S2 we would -/// generate the following code with the number in parenthesis the iteration -/// index: -/// S0(0) // Prologue -/// S0(1) S1(0) // Prologue -/// scf.for %I = %C0 to %N - 2 { -/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel -/// } -/// S1(N) S2(N-1) // Epilogue -/// S2(N) // Epilogue -struct ForLoopPipelining : public OpRewritePattern { - ForLoopPipelining(const PipeliningOption &options, MLIRContext *context) - : OpRewritePattern(context), options(options) {} - LogicalResult matchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const override { +} // namespace - LoopPipelinerInternal pipeliner; - if (!pipeliner.initializeLoopInfo(forOp, options)) - return failure(); +FailureOr ForLoopPipeliningPattern::returningMatchAndRewrite( + ForOp forOp, PatternRewriter &rewriter) const { - // 1. Emit prologue. - pipeliner.emitPrologue(rewriter); + LoopPipelinerInternal pipeliner; + if (!pipeliner.initializeLoopInfo(forOp, options)) + return failure(); - // 2. Track values used across stages. When a value cross stages it will - // need to be passed as loop iteration arguments. - // We first collect the values that are used in a different stage than where - // they are defined. - llvm::MapVector - crossStageValues = pipeliner.analyzeCrossStageValues(); + // 1. Emit prologue. 
+ pipeliner.emitPrologue(rewriter); - // Mapping between original loop values used cross stage and the block - // arguments associated after pipelining. A Value may map to several - // arguments if its liverange spans across more than 2 stages. - llvm::DenseMap, unsigned> loopArgMap; - // 3. Create the new kernel loop and return the block arguments mapping. - ForOp newForOp = - pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); - // Create the kernel block, order ops based on user choice and remap - // operands. - pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, rewriter); + // 2. Track values used across stages. When a value cross stages it will + // need to be passed as loop iteration arguments. + // We first collect the values that are used in a different stage than where + // they are defined. + llvm::MapVector + crossStageValues = pipeliner.analyzeCrossStageValues(); - llvm::SmallVector returnValues = - newForOp.getResults().take_front(forOp->getNumResults()); - if (options.peelEpilogue) { - // 4. Emit the epilogue after the new forOp. - rewriter.setInsertionPointAfter(newForOp); - returnValues = pipeliner.emitEpilogue(rewriter); - } - // 5. Erase the original loop and replace the uses with the epilogue output. - if (forOp->getNumResults() > 0) - rewriter.replaceOp(forOp, returnValues); - else - rewriter.eraseOp(forOp); + // Mapping between original loop values used cross stage and the block + // arguments associated after pipelining. A Value may map to several + // arguments if its liverange spans across more than 2 stages. + llvm::DenseMap, unsigned> loopArgMap; + // 3. Create the new kernel loop and return the block arguments mapping. + ForOp newForOp = + pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); + // Create the kernel block, order ops based on user choice and remap + // operands. 
+ pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, rewriter); - return success(); + llvm::SmallVector<Value> returnValues = + newForOp.getResults().take_front(forOp->getNumResults()); + if (options.peelEpilogue) { + // 4. Emit the epilogue after the new forOp. + rewriter.setInsertionPointAfter(newForOp); + returnValues = pipeliner.emitEpilogue(rewriter); } + // 5. Erase the original loop and replace the uses with the epilogue output. + if (forOp->getNumResults() > 0) + rewriter.replaceOp(forOp, returnValues); + else + rewriter.eraseOp(forOp); -protected: - PipeliningOption options; -}; - -} // namespace + return newForOp; +} void mlir::scf::populateSCFLoopPipeliningPatterns( RewritePatternSet &patterns, const PipeliningOption &options) { - patterns.add<ForLoopPipelining>(options, patterns.getContext()); + patterns.add<ForLoopPipeliningPattern>(options, patterns.getContext()); } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -105,13 +105,16 @@ /// Assumes the FuncOp result types is the type of the yielded operands of the /// single block. This constraint makes it easy to determine the result. /// This method also clones the `arith::ConstantIndexOp` at the start of -/// `outlinedFuncBody` to alloc simple canonicalizations. +/// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is +/// provided, it will be set to point to the operation that calls the outlined +/// function. // TODO: support more than single-block regions. // TODO: more flexible constant handling.
FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, - StringRef funcName) { + StringRef funcName, + func::CallOp *callOp) { assert(!funcName.empty() && "funcName cannot be empty"); if (!region.hasOneBlock()) return failure(); @@ -176,8 +179,9 @@ SmallVector<Value> callValues; llvm::append_range(callValues, newBlock->getArguments()); llvm::append_range(callValues, outlinedValues); - Operation *call = - rewriter.create<func::CallOp>(loc, outlinedFunc, callValues); + auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues); + if (callOp) + *callOp = call; // `originalTerminator` was moved to `outlinedFuncBody` and is still valid. // Clone `originalTerminator` to take the callOp results then erase it from diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -137,17 +137,6 @@ return success(); } -void transform::GetClosestIsolatedParentOp::getEffects( - SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { - effects.emplace_back(MemoryEffects::Read::get(), getTarget(), - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Allocate::get(), getParent(), - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), getParent(), - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); -} - //===----------------------------------------------------------------------===// // PDLMatchOp //===----------------------------------------------------------------------===// diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -125,6 +125,16 @@ dialects/transform/__init__.py DIALECT_NAME transform) +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE
dialects/SCFLoopTransformOps.td + SOURCES + dialects/_loop_transform_ops_ext.py + dialects/transform/loop.py + DIALECT_NAME transform + EXTENSION_NAME loop_transform) + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/SCFLoopTransformOps.td b/mlir/python/mlir/dialects/SCFLoopTransformOps.td new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/SCFLoopTransformOps.td @@ -0,0 +1,21 @@ +//===-- SCFLoopTransformOps.td -----------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the loop transform ops +// provided by the SCF (and other) dialects. +// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS +#define PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.td" + +#endif // PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py @@ -0,0 +1,113 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ._ods_common import get_op_result_or_value as _get_op_result_or_value + from ..dialects import pdl +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Union + + +def _get_int64_attr(arg: Optional[Union[int, IntegerAttr]], + default_value: int = None): + if isinstance(arg, IntegerAttr): + return arg + + if arg is None: + assert default_value is not None, "must provide default value" + arg = default_value + + return IntegerAttr.get(IntegerType.get_signless(64), arg) + + +class GetParentForOp: + """Extension for GetParentForOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + num_loops: int = 1, + ip=None, + loc=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + num_loops=_get_int64_attr(num_loops, default_value=1), + ip=ip, + loc=loc) + + +class LoopOutlineOp: + """Extension for LoopOutlineOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + func_name: Union[str, StringAttr], + ip=None, + loc=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + func_name=(func_name if isinstance(func_name, StringAttr) else + StringAttr.get(func_name)), + ip=ip, + loc=loc) + + +class LoopPeelOp: + """Extension for LoopPeelOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + fail_if_already_divisible: Union[bool, BoolAttr] = False, + ip=None, + loc=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + fail_if_already_divisible=(fail_if_already_divisible if isinstance( + fail_if_already_divisible, BoolAttr) else + BoolAttr.get(fail_if_already_divisible)), + ip=ip, + loc=loc) + + +class LoopPipelineOp: + """Extension for LoopPipelineOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + iteration_interval: Optional[Union[int, 
IntegerAttr]] = None, + read_latency: Optional[Union[int, IntegerAttr]] = None, + ip=None, + loc=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + iteration_interval=_get_int64_attr(iteration_interval, default_value=1), + read_latency=_get_int64_attr(read_latency, default_value=10), + ip=ip, + loc=loc) + + +class LoopUnrollOp: + """Extension for LoopUnrollOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + factor: Union[int, IntegerAttr], + ip=None, + loc=None): + super().__init__( + _get_op_result_or_value(target), + factor=_get_int64_attr(factor), + ip=ip, + loc=loc) diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/loop.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._loop_transform_ops_gen import * diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SCF/transform-ops.mlir @@ -0,0 +1,264 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @get_parent_for_op +func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) { + // expected-remark @below {{first loop}} + scf.for %i = %arg0 to %arg1 step %arg2 { + // expected-remark @below {{second loop}} + scf.for %j = %arg0 to %arg1 step %arg2 { + // expected-remark @below {{third loop}} + scf.for %k = %arg0 to %arg1 step %arg2 { + arith.addi %i, %j : index + } + } + } + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_addi : benefit(1) { + %args = operands + %results = types + %op = operation "arith.addi"(%args : 
!pdl.range<value>) -> (%results : !pdl.range<type>) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_addi in %arg1 + // CHECK: = transform.loop.get_parent_for + %1 = transform.loop.get_parent_for %0 + %2 = transform.loop.get_parent_for %0 { num_loops = 2 } + %3 = transform.loop.get_parent_for %0 { num_loops = 3 } + transform.test_print_remark_at_operand %1, "third loop" + transform.test_print_remark_at_operand %2, "second loop" + transform.test_print_remark_at_operand %3, "first loop" + } +} + +// ----- + +func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) { + // expected-note @below {{target op}} + arith.addi %arg0, %arg1 : index + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_addi : benefit(1) { + %args = operands + %results = types + %op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_addi in %arg1 + // expected-error @below {{could not find an 'scf.for' parent}} + %1 = transform.loop.get_parent_for %0 + } +} + +// ----- + +// Outlined functions: +// +// CHECK: func @foo(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}) +// CHECK: scf.for +// CHECK: arith.addi +// +// CHECK: func @foo[[SUFFIX:.+]](%{{.+}}, %{{.+}}, %{{.+}}) +// CHECK: scf.for +// CHECK: arith.addi +// +// CHECK-LABEL @loop_outline_op +func.func @loop_outline_op(%arg0: index, %arg1: index, %arg2: index) { + // CHECK: scf.for + // CHECK-NOT: scf.for + // CHECK: scf.execute_region + // CHECK: func.call @foo + scf.for %i = %arg0 to %arg1 step %arg2 { + scf.for %j = %arg0 to %arg1 step %arg2 { + arith.addi %i, %j : index + } + } + // CHECK: scf.execute_region + // CHECK-NOT: scf.for + // CHECK: func.call @foo[[SUFFIX]] + scf.for %j = %arg0 to %arg1 step %arg2 { + arith.addi %j, %j : index + } + return +} + +transform.with_pdl_patterns { +^bb0(%arg0:
!pdl.operation): + pdl.pattern @match_addi : benefit(1) { + %args = operands + %results = types + %op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_addi in %arg1 + %1 = transform.loop.get_parent_for %0 + // CHECK: = transform.loop.outline %{{.*}} + transform.loop.outline %1 {func_name = "foo"} + } +} + +// ----- + +func.func private @cond() -> i1 +func.func private @body() + +func.func @loop_outline_op_multi_region() { + // expected-note @below {{target op}} + scf.while : () -> () { + %0 = func.call @cond() : () -> i1 + scf.condition(%0) + } do { + ^bb0: + func.call @body() : () -> () + scf.yield + } + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_while : benefit(1) { + %args = operands + %results = types + %op = operation "scf.while"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_while in %arg1 + // expected-error @below {{failed to outline}} + transform.loop.outline %0 {func_name = "foo"} + } +} + +// ----- + +// CHECK-LABEL: @loop_peel_op +func.func @loop_peel_op() { + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[C42:.+]] = arith.constant 42 + // CHECK: %[[C5:.+]] = arith.constant 5 + // CHECK: %[[C40:.+]] = arith.constant 40 + // CHECK: scf.for %{{.+}} = %[[C0]] to %[[C40]] step %[[C5]] + // CHECK: arith.addi + // CHECK: scf.for %{{.+}} = %[[C40]] to %[[C42]] step %[[C5]] + // CHECK: arith.addi + %0 = arith.constant 0 : index + %1 = arith.constant 42 : index + %2 = arith.constant 5 : index + scf.for %i = %0 to %1 step %2 { + arith.addi %i, %i : index + } + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_addi : benefit(1) { + %args = operands + %results = types + %op = operation "arith.addi"(%args :
!pdl.range<value>) -> (%results : !pdl.range<type>) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_addi in %arg1 + %1 = transform.loop.get_parent_for %0 + transform.loop.peel %1 + } +} + +// ----- + +func.func @loop_pipeline_op(%A: memref<?xf32>, %result: memref<?xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %cf = arith.constant 1.0 : f32 + // CHECK: memref.load %[[MEMREF:.+]][%{{.+}}] + // CHECK: memref.load %[[MEMREF]] + // CHECK: arith.addf + // CHECK: scf.for + // CHECK: memref.load + // CHECK: arith.addf + // CHECK: memref.store + // CHECK: arith.addf + // CHECK: memref.store + // CHECK: memref.store + // expected-remark @below {{transformed}} + scf.for %i0 = %c0 to %c4 step %c1 { + %A_elem = memref.load %A[%i0] : memref<?xf32> + %A1_elem = arith.addf %A_elem, %cf : f32 + memref.store %A1_elem, %result[%i0] : memref<?xf32> + } + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_addf : benefit(1) { + %args = operands + %results = types + %op = operation "arith.addf"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_addf in %arg1 + %1 = transform.loop.get_parent_for %0 + %2 = transform.loop.pipeline %1 + // Verify that the returned handle is usable.
+ transform.test_print_remark_at_operand %2, "transformed" + } +} + +// ----- + +// CHECK-LABEL: @loop_unroll_op +func.func @loop_unroll_op() { + %c0 = arith.constant 0 : index + %c42 = arith.constant 42 : index + %c5 = arith.constant 5 : index + // CHECK: scf.for %[[I:.+]] = + scf.for %i = %c0 to %c42 step %c5 { + // CHECK-COUNT-4: arith.addi %[[I]] + arith.addi %i, %i : index + } + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_addi : benefit(1) { + %args = operands + %results = types + %op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_addi in %arg1 + %1 = transform.loop.get_parent_for %0 + transform.loop.unroll %1 { factor = 4 } + } +} + diff --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py new file mode 100644 --- /dev/null +++ b/mlir/test/python/dialects/transform_loop_ext.py @@ -0,0 +1,71 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import transform +from mlir.dialects import pdl +from mlir.dialects.transform import loop + + +def run(f): + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + print("\nTEST:", f.__name__) + f() + print(module) + return f + + +@run +def getParentLoop(): + sequence = transform.SequenceOp() + with InsertionPoint(sequence.body): + loop.GetParentForOp(sequence.bodyTarget, num_loops=2) + transform.YieldOp() + # CHECK-LABEL: TEST: getParentLoop + # CHECK: = transform.loop.get_parent_for % + # CHECK: num_loops = 2 + + +@run +def loopOutline(): + sequence = transform.SequenceOp() + with InsertionPoint(sequence.body): + loop.LoopOutlineOp(sequence.bodyTarget, func_name="foo") + transform.YieldOp() + # CHECK-LABEL: TEST: loopOutline + # CHECK: = transform.loop.outline % + # CHECK: func_name = "foo" + + +@run
+def loopPeel(): + sequence = transform.SequenceOp() + with InsertionPoint(sequence.body): + loop.LoopPeelOp(sequence.bodyTarget) + transform.YieldOp() + # CHECK-LABEL: TEST: loopPeel + # CHECK: = transform.loop.peel % + + +@run +def loopPipeline(): + sequence = transform.SequenceOp() + with InsertionPoint(sequence.body): + loop.LoopPipelineOp(sequence.bodyTarget, iteration_interval=3) + transform.YieldOp() + # CHECK-LABEL: TEST: loopPipeline + # CHECK: = transform.loop.pipeline % + # CHECK-DAG: iteration_interval = 3 + # CHECK-DAG: read_latency = 10 + + +@run +def loopUnroll(): + sequence = transform.SequenceOp() + with InsertionPoint(sequence.body): + loop.LoopUnrollOp(sequence.bodyTarget, factor=42) + transform.YieldOp() + # CHECK-LABEL: TEST: loopUnroll + # CHECK: transform.loop.unroll % + # CHECK: factor = 42 diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1868,6 +1868,7 @@ hdrs = [ "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h", "include/mlir/Dialect/SCF/Passes.h", + "include/mlir/Dialect/SCF/Patterns.h", "include/mlir/Dialect/SCF/Transforms.h", ], includes = ["include"], @@ -1892,6 +1893,59 @@ ], ) +td_library( + name = "SCFTransformOpsTdFiles", + srcs = [ + "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td", + ], + includes = ["include"], + deps = [ + ":PDLDialect", + ":TransformDialectTdFiles", + ], +) + +gentbl_cc_library( + name = "SCFTransformOpsIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-decls"], + "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h.inc", + ), + ( + ["-gen-op-defs"], + "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td", + deps = [ + ":SCFTransformOpsTdFiles", + ], +) + +cc_library( 
+ name = "SCFTransformOps", + srcs = glob(["lib/Dialect/SCF/TransformOps/*.cpp"]), + hdrs = glob(["include/mlir/Dialect/SCF/TransformOps/*.h"]), + includes = ["include"], + deps = [ + ":Affine", + ":FuncDialect", + ":IR", + ":PDLDialect", + ":SCFDialect", + ":SCFTransformOpsIncGen", + ":SCFTransforms", + ":SCFUtils", + ":SideEffectInterfaces", + ":TransformDialect", + ":VectorOps", + "//llvm:Support", + ], +) + ##---------------------------------------------------------------------------## # SparseTensor dialect. ##---------------------------------------------------------------------------## @@ -2601,6 +2655,7 @@ ], exclude = [ "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h", + "include/mlir/Dialect/SCF/Patterns.h", "include/mlir/Dialect/SCF/Transforms.h", ], ), @@ -6299,6 +6354,7 @@ ":SCFPassIncGen", ":SCFToGPUPass", ":SCFToStandard", + ":SCFTransformOps", ":SCFTransforms", ":SDBM", ":SPIRVDialect", diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel @@ -878,11 +878,33 @@ ], ) +gentbl_filegroup( + name = "LoopTransformOpsPyGen", + tbl_outs = [ + ( + [ + "-gen-python-op-bindings", + "-bind-dialect=transform", + "-dialect-extension=loop_transform", + ], + "mlir/dialects/_loop_transform_ops_gen.py", + ), + ], + tblgen = "//mlir:mlir-tblgen", + td_file = "mlir/dialects/SCFLoopTransformOps.td", + deps = [ + ":TransformOpsPyTdFiles", + "//mlir:SCFTransformOpsTdFiles", + ], +) + filegroup( name = "TransformOpsPyFiles", srcs = [ + "mlir/dialects/_loop_transform_ops_ext.py", "mlir/dialects/_structured_transform_ops_ext.py", "mlir/dialects/_transform_ops_ext.py", + ":LoopTransformOpsPyGen", ":StructuredTransformOpsPyGen", ":TransformOpsPyGen", ],