diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h --- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h +++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_AFFINE_TRANSFORMOPS_AFFINETRANSFORMOPS_H #define MLIR_DIALECT_AFFINE_TRANSFORMOPS_AFFINETRANSFORMOPS_H +#include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td --- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td +++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td @@ -9,6 +9,7 @@ #ifndef AFFINE_TRANSFORM_OPS #define AFFINE_TRANSFORM_OPS +include "mlir/Dialect/PDL/IR/PDLTypes.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" @@ -18,4 +19,49 @@ def Transform_AffineForOp : Transform_ConcreteOpType<"affine.for">; +def SimplifyBoundedAffineOpsOp + : Op<Transform_Dialect, "affine.simplify_bounded_affine_ops", + [DeclareOpInterfaceMethods<TransformOpInterface>, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> { + let description = [{ + Simplify the targeted affine.min / affine.max ops given the supplied + lower and upper bounds for values that may be used as target op operands. + + Example: + ``` + %0 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg1 + %1 = transform.structured.match ops{["gpu.lane_id"]} in %arg1 + transform.affine.simplify_bounded_affine_ops %0 with [%1] within [0] and [32] + + // Multiple bounds can be specified. 
+ transform.affine.simplify_bounded_affine_ops %0 with [%1, %2] within [0, 5] and [32, 50] + ``` + + Bounded op handles (`%1` and `%2`) must be mapped to ops that have a single + result of index type. The sets of target ops and bounded ops must not + overlap. + + #### Return modes + + Target ops must be affine.min or affine.max ops. This transform consumes the + target handle and does not produce any handle. It reads the bounded op + handles. + + TODO: Support affine.apply targets. + TODO: Allow mixed PDL_Operation/int64_t for lower_bounds and upper_bounds. + }]; + + let arguments = (ins PDL_Operation:$target, + Variadic<PDL_Operation>:$bounded_values, + DenseI64ArrayAttr:$lower_bounds, + DenseI64ArrayAttr:$upper_bounds); + let results = (outs); + let hasVerifier = 1; + + let assemblyFormat = [{ + $target `with` `[` $bounded_values `]` + `within` $lower_bounds `and` $upper_bounds attr-dict + }]; +} + #endif // Affine_TRANSFORM_OPS diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -3255,7 +3255,6 @@ AffineMap map = affineOp.getAffineMap(); if (failed(canonicalizeMapExprAndTermOrder(map))) return failure(); - rewriter.replaceOpWithNewOp(affineOp, map, affineOp.getMapOperands()); return success(); } diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -7,13 +7,141 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" +#include "mlir/Dialect/Affine/Analysis/AffineStructures.h" +#include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include 
"mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; +using namespace mlir::transform; + +//===----------------------------------------------------------------------===// +// SimplifyBoundedAffineOpsOp +//===----------------------------------------------------------------------===// + +LogicalResult SimplifyBoundedAffineOpsOp::verify() { + if (getLowerBounds().size() != getBoundedValues().size()) + return emitOpError() << "incorrect number of lower bounds, expected " + << getBoundedValues().size() << " but found " + << getLowerBounds().size(); + if (getUpperBounds().size() != getBoundedValues().size()) + return emitOpError() << "incorrect number of upper bounds, expected " + << getBoundedValues().size() << " but found " + << getUpperBounds().size(); + return success(); +} + +namespace { +/// Simplify affine.min / affine.max ops with the given constraints. They are +/// either rewritten to affine.apply or left unchanged. 
+template <typename OpTy> +struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> { + using OpRewritePattern<OpTy>::OpRewritePattern; + SimplifyAffineMinMaxOp(MLIRContext *ctx, + const FlatAffineValueConstraints &constraints, + PatternBenefit benefit = 1) + : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + FailureOr<AffineValueMap> simplified = + simplifyConstrainedMinMaxOp(op, constraints); + if (failed(simplified)) + return failure(); + rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(), + simplified->getOperands()); + return success(); + } + + const FlatAffineValueConstraints &constraints; +}; +} // namespace + +DiagnosedSilenceableFailure +SimplifyBoundedAffineOpsOp::apply(TransformResults &results, + TransformState &state) { + // Get constraints for bounded values. + SmallVector<int64_t> lbs; + SmallVector<int64_t> ubs; + SmallVector<Value> boundedValues; + DenseSet<Operation *> boundedOps; + for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(), + getUpperBounds())) { + Value handle = std::get<0>(it); + ArrayRef<Operation *> boundedValueOps = state.getPayloadOps(handle); + for (Operation *op : boundedValueOps) { + if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { + auto diag = + emitDefiniteFailure() + << "expected bounded value handle to point to one or multiple " + "single-result index-typed ops"; + diag.attachNote(op->getLoc()) << "multiple/non-index result"; + return diag; + } + boundedValues.push_back(op->getResult(0)); + boundedOps.insert(op); + lbs.push_back(std::get<1>(it)); + ubs.push_back(std::get<2>(it)); + } + } + + // Build constraint set. 
+ FlatAffineValueConstraints cstr; + for (const auto &it : llvm::zip(boundedValues, lbs, ubs)) { + unsigned pos; + if (!cstr.findVar(std::get<0>(it), &pos)) + pos = cstr.appendSymbolVar(std::get<0>(it)); + cstr.addBound(FlatAffineValueConstraints::BoundType::LB, pos, + std::get<1>(it)); + // Note: addBound bounds are inclusive, but specified UB is exclusive. + cstr.addBound(FlatAffineValueConstraints::BoundType::UB, pos, + std::get<2>(it) - 1); + } + + // Transform all targets. + ArrayRef<Operation *> targets = state.getPayloadOps(getTarget()); + for (Operation *target : targets) { + if (!isa<AffineMinOp, AffineMaxOp>(target)) { + auto diag = emitDefiniteFailure() + << "target must be affine.min or affine.max"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + if (boundedOps.contains(target)) { + auto diag = emitDefiniteFailure() + << "target op result must not be constrained"; + diag.attachNote(target->getLoc()) << "target/constrained op"; + return diag; + } + } + SmallVector<Operation *> transformed; + RewritePatternSet patterns(getContext()); + // Canonicalization patterns are needed so that affine.apply ops are composed + // with the remaining affine.min/max ops. + AffineMaxOp::getCanonicalizationPatterns(patterns, getContext()); + AffineMinOp::getCanonicalizationPatterns(patterns, getContext()); + patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>, + SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + // Apply the simplification pattern to a fixpoint. 
+ (void)applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true); + return DiagnosedSilenceableFailure::success(); +} + +void SimplifyBoundedAffineOpsOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + consumesHandle(getTarget(), effects); + for (Value v : getBoundedValues()) + onlyReadsHandle(v, effects); + modifiesPayload(effects); +} //===----------------------------------------------------------------------===// // Transform op registration diff --git a/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt @@ -8,6 +8,7 @@ MLIRAffineTransformOpsIncGen LINK_LIBS PUBLIC + MLIRAffineAnalysis MLIRAffineDialect MLIRFuncDialect MLIRIR diff --git a/mlir/test/Dialect/Affine/transform-op-simplify-bounded-affine-ops.mlir b/mlir/test/Dialect/Affine/transform-op-simplify-bounded-affine-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Affine/transform-op-simplify-bounded-affine-ops.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect \ +// RUN: --test-transform-dialect-interpreter -verify-diagnostics \ +// RUN: --split-input-file | FileCheck %s + +// CHECK: func @simplify_min_max() +// CHECK-DAG: %[[c50:.*]] = arith.constant 50 : index +// CHECK-DAG: %[[c100:.*]] = arith.constant 100 : index +// CHECK: return %[[c50]], %[[c100]] +func.func @simplify_min_max() -> (index, index) { + %0 = "test.some_op"() : () -> (index) + %1 = affine.min affine_map<()[s0] -> (50, 100 - s0)>()[%0] + %2 = affine.max affine_map<()[s0] -> (100, 80 + s0)>()[%0] + return %1, %2 : index, index +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg1 + %1 = transform.structured.match ops{["test.some_op"]} in %arg1 + transform.affine.simplify_bounded_affine_ops %0 with [%1] within 
[0] and [20] +} + +// 
----- + +// CHECK: func @simplify_min_sequence() +// CHECK: %[[c1:.*]] = arith.constant 1 : index +// CHECK: return %[[c1]] +func.func @simplify_min_sequence() -> index { + %1 = "test.workgroup_id"() : () -> (index) + %2 = affine.min affine_map<()[s0] -> (s0 * -32 + 1023, 32)>()[%1] + %3 = "test.thread_id"() : () -> (index) + %4 = affine.min affine_map<()[s0, s1] -> (s0 - s1 * (s0 ceildiv 32), s0 ceildiv 32)>()[%2, %3] + return %4 : index +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["affine.min"]} in %arg1 + %1 = transform.structured.match ops{["test.workgroup_id"]} in %arg1 + %2 = transform.structured.match ops{["test.thread_id"]} in %arg1 + transform.affine.simplify_bounded_affine_ops %0 with [%1, %2] within [0, 0] and [31, 31] +} + +// ----- + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["affine.min"]} in %arg1 + // expected-error@+1 {{incorrect number of lower bounds, expected 0 but found 1}} + transform.affine.simplify_bounded_affine_ops %0 with [] within [0] and [] +} + +// ----- + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["affine.min"]} in %arg1 + // expected-error@+1 {{incorrect number of upper bounds, expected 0 but found 1}} + transform.affine.simplify_bounded_affine_ops %0 with [] within [] and [5] +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1201,13 +1201,16 @@ hdrs = glob(["include/mlir/Dialect/Affine/TransformOps/*.h"]), includes = ["include"], deps = [ + ":AffineAnalysis", ":AffineDialect", ":AffineTransformOpsIncGen", ":AffineTransforms", ":AffineUtils", ":FuncDialect", ":IR", + ":PDLDialect", ":TransformDialect", + ":Transforms", ":VectorDialect", ], )