diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_AFFINE_TRANSFORMOPS_AFFINETRANSFORMOPS_H
 #define MLIR_DIALECT_AFFINE_TRANSFORMOPS_AFFINETRANSFORMOPS_H
 
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/OpImplementation.h"
diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
@@ -9,6 +9,7 @@
 #ifndef AFFINE_TRANSFORM_OPS
 #define AFFINE_TRANSFORM_OPS
 
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
@@ -18,4 +19,45 @@
 
 def Transform_AffineForOp : Transform_ConcreteOpType<"affine.for">;
 
+def SimplifyBoundedAffineOpsOp
+    : Op<Transform_Dialect, "affine.simplify_bounded_affine_ops",
+         [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+          DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Simplify the targeted affine.min / affine.max ops given the supplied
+    lower and upper bounds for values that may be used as target op operands.
+    Lower bounds are inclusive and upper bounds are exclusive, i.e.,
+    `%1 in (0, 32)` constrains the value of `%1` to the range [0, 32).
+
+    Example:
+    ```
+    %0 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg1
+    %1 = transform.structured.match ops{["gpu.lane_id"]} in %arg1
+    // Multiple bounds can be specified.
+    transform.affine.simplify_bounded_affine_ops %0 with {%1 in (0, 32)}
+    ```
+
+    Bounded op handles (`%1`) must be mapped to ops that have a single result of
+    index type. The sets of target ops and bounded ops must not overlap. Each
+    lower bound must be strictly smaller than its corresponding upper bound.
+
+    #### Return modes
+
+    Target ops must be affine.min or affine.max ops. This transform consumes the
+    target handle and does not produce any handle. It reads the bounded op
+    handles.
+
+    TODO: Support affine.apply targets.
+    TODO: Allow mixed PDL_Operation/int64_t for lower_bounds and upper_bounds.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                       Variadic<PDL_Operation>:$bounded_values,
+                       DenseI64ArrayAttr:$lower_bounds,
+                       DenseI64ArrayAttr:$upper_bounds);
+  let results = (outs);
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
 #endif // Affine_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -7,13 +7,222 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
+#include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
+using namespace mlir::transform;
+
+//===----------------------------------------------------------------------===//
+// SimplifyBoundedAffineOpsOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Simplify affine.min / affine.max ops with the given constraints. They are
+/// either rewritten to affine.apply or left unchanged.
+template <typename OpTy>
+struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  SimplifyAffineMinMaxOp(MLIRContext *ctx,
+                         const FlatAffineValueConstraints &constraints,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {}
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<AffineValueMap> simplified =
+        simplifyConstrainedMinMaxOp(op, constraints);
+    if (failed(simplified))
+      return failure();
+    rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(),
+                                               simplified->getOperands());
+    return success();
+  }
+
+  /// Constraints on the bounded values. Held by reference: the constraint set
+  /// must outlive every rewrite driven by this pattern.
+  const FlatAffineValueConstraints &constraints;
+};
+} // namespace
+
+DiagnosedSilenceableFailure
+SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
+                                  TransformState &state) {
+  // Get constraints for bounded values. The verifier guarantees that there
+  // is exactly one lower and one upper bound per bounded value handle.
+  SmallVector<int64_t> lbs;
+  SmallVector<int64_t> ubs;
+  SmallVector<Value> boundedValues;
+  DenseSet<Operation *> boundedOps;
+  for (const auto &it :
+       llvm::zip(getBoundedValues(), getLowerBounds(), getUpperBounds())) {
+    Value handle = std::get<0>(it);
+    ArrayRef<Operation *> boundedValueOps = state.getPayloadOps(handle);
+    for (Operation *op : boundedValueOps) {
+      if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+        auto diag =
+            emitDefiniteFailure()
+            << "expected bounded value handle to point to one or multiple "
+               "single-result index-typed ops";
+        diag.attachNote(op->getLoc()) << "multiple/non-index result";
+        return diag;
+      }
+      boundedValues.push_back(op->getResult(0));
+      boundedOps.insert(op);
+      lbs.push_back(std::get<1>(it));
+      ubs.push_back(std::get<2>(it));
+    }
+  }
+
+  // Build constraint set.
+  FlatAffineValueConstraints cstr;
+  for (const auto &it : llvm::zip(boundedValues, lbs, ubs)) {
+    unsigned pos;
+    if (!cstr.findVar(std::get<0>(it), &pos))
+      pos = cstr.appendSymbolVar(std::get<0>(it));
+    cstr.addBound(FlatAffineValueConstraints::BoundType::LB, pos,
+                  std::get<1>(it));
+    // Note: addBound bounds are inclusive, but specified UB is exclusive.
+    cstr.addBound(FlatAffineValueConstraints::BoundType::UB, pos,
+                  std::get<2>(it) - 1);
+  }
+
+  // Check all targets before rewriting any of them.
+  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+  for (Operation *target : targets) {
+    if (!isa<AffineMinOp, AffineMaxOp>(target)) {
+      auto diag = emitDefiniteFailure()
+                  << "target must be affine.min or affine.max";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+    if (boundedOps.contains(target)) {
+      auto diag = emitDefiniteFailure()
+                  << "target op result must not be constrained";
+      diag.attachNote(target->getLoc()) << "target/constrained op";
+      return diag;
+    }
+  }
+
+  // Both affine.min and affine.max are valid targets, so register the
+  // canonicalization patterns of both op kinds alongside the simplification.
+  RewritePatternSet patterns(getContext());
+  AffineMinOp::getCanonicalizationPatterns(patterns, getContext());
+  AffineMaxOp::getCanonicalizationPatterns(patterns, getContext());
+  patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
+                  SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+  // We do not compose chained affine.min / affine.max ops before simplifying
+  // them. Instead, we apply the simplification pattern to a fixpoint. Target
+  // ops may be erased by the driver; this is fine because the target handle
+  // is consumed.
+  (void)applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/false);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void SimplifyBoundedAffineOpsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  for (Value v : getBoundedValues())
+    onlyReadsHandle(v, effects);
+  modifiesPayload(effects);
+}
+
+LogicalResult SimplifyBoundedAffineOpsOp::verify() {
+  // `apply` and the custom printer pair bounded values with their bounds by
+  // position, so all three lists must have the same length.
+  if (getLowerBounds().size() != getBoundedValues().size())
+    return emitOpError() << "incorrect number of lower bounds, expected "
+                         << getBoundedValues().size() << " but found "
+                         << getLowerBounds().size();
+  if (getUpperBounds().size() != getBoundedValues().size())
+    return emitOpError() << "incorrect number of upper bounds, expected "
+                         << getBoundedValues().size() << " but found "
+                         << getUpperBounds().size();
+  // Upper bounds are exclusive. An empty range would make the constraint set
+  // infeasible, and rejecting it also keeps `ub - 1` in `apply` from
+  // overflowing.
+  for (const auto &it : llvm::zip(getLowerBounds(), getUpperBounds())) {
+    if (std::get<0>(it) >= std::get<1>(it))
+      return emitOpError() << "expected lower bound (" << std::get<0>(it)
+                           << ") to be smaller than upper bound ("
+                           << std::get<1>(it) << ")";
+  }
+  return success();
+}
+
+ParseResult SimplifyBoundedAffineOpsOp::parse(OpAsmParser &parser,
+                                              OperationState &result) {
+  OpAsmParser::UnresolvedOperand target;
+  auto pdlOperationType =
+      pdl::OperationType::get(parser.getBuilder().getContext());
+  if (parser.parseOperand(target) ||
+      parser.resolveOperand(target, pdlOperationType, result.operands))
+    return failure();
+
+  SmallVector<int64_t> lowerBounds, upperBounds;
+  if (parser.parseOptionalKeyword("with").succeeded()) {
+    if (parser.parseLBrace())
+      return failure();
+    while (true) {
+      // Parse `$bounded_value in (lb, ub)`.
+      OpAsmParser::UnresolvedOperand boundedValue;
+      int64_t lb, ub;
+      if (parser.parseOperand(boundedValue) ||
+          parser.resolveOperand(boundedValue, pdlOperationType,
+                                result.operands) ||
+          parser.parseKeyword("in") || parser.parseLParen() ||
+          parser.parseInteger(lb) || parser.parseComma() ||
+          parser.parseInteger(ub) || parser.parseRParen())
+        return failure();
+      lowerBounds.push_back(lb);
+      upperBounds.push_back(ub);
+      if (parser.parseOptionalRBrace().succeeded())
+        break;
+      if (parser.parseComma())
+        return failure();
+    }
+  }
+  result.addAttribute(
+      SimplifyBoundedAffineOpsOp::getLowerBoundsAttrName(result.name)
+          .getValue(),
+      parser.getBuilder().getDenseI64ArrayAttr(lowerBounds));
+  result.addAttribute(
+      SimplifyBoundedAffineOpsOp::getUpperBoundsAttrName(result.name)
+          .getValue(),
+      parser.getBuilder().getDenseI64ArrayAttr(upperBounds));
+
+  if (failed(parser.parseOptionalAttrDict(result.attributes)))
+    return failure();
+
+  return success();
+}
+
+void SimplifyBoundedAffineOpsOp::print(OpAsmPrinter &p) {
+  p << " " << getTarget() << " ";
+  if (!getBoundedValues().empty()) {
+    p << "with {";
+    // The verifier guarantees one lower and one upper bound per bounded value.
+    llvm::interleaveComma(
+        llvm::zip(getBoundedValues(), getLowerBounds(), getUpperBounds()), p,
+        [&](const auto &it) {
+          p << std::get<0>(it) << " in (" << std::get<1>(it) << ", "
+            << std::get<2>(it) << ")";
+        });
+    p << "} ";
+  }
+  p.printOptionalAttrDict(getOperation()->getAttrs(),
+                          {getLowerBoundsAttrName(), getUpperBoundsAttrName()});
+}
 
 //===----------------------------------------------------------------------===//
 // Transform op registration
diff --git a/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt
--- a/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt
@@ -8,6 +8,7 @@
   MLIRAffineTransformOpsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRAffineAnalysis
   MLIRAffineDialect
   MLIRFuncDialect
   MLIRIR
diff --git a/mlir/test/Dialect/Affine/transform-op-simplify-bounded-affine-ops.mlir b/mlir/test/Dialect/Affine/transform-op-simplify-bounded-affine-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Affine/transform-op-simplify-bounded-affine-ops.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect --test-transform-dialect-interpreter --split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK: func @simplify_min_max()
+// CHECK-DAG: %[[c50:.*]] = arith.constant 50 : index
+// CHECK-DAG: %[[c100:.*]] = arith.constant 100 : index
+// CHECK: return %[[c50]], %[[c100]]
+func.func @simplify_min_max() -> (index, index) {
+  %0 = "test.some_op"() : () -> (index)
+  %1 = affine.min affine_map<()[s0] -> (50, 100 - s0)>()[%0]
+  %2 = affine.max affine_map<()[s0] -> (100, 80 + s0)>()[%0]
+  return %1, %2 : index, index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg1
+  %1 = transform.structured.match ops{["test.some_op"]} in %arg1
+  transform.affine.simplify_bounded_affine_ops %0 with {%1 in (0, 20)}
+}
+
+// -----
+
+// CHECK: func @simplify_min_sequence()
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: return %[[c1]]
+func.func @simplify_min_sequence() -> index {
+  %1 = "test.workgroup_id"() : () -> (index)
+  %2 = affine.min affine_map<()[s0] -> (s0 * -32 + 1023, 32)>()[%1]
+  %3 = "test.thread_id"() : () -> (index)
+  %4 = affine.min affine_map<()[s0, s1] -> (s0 - s1 * (s0 ceildiv 32), s0 ceildiv 32)>()[%2, %3]
+  return %4 : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["affine.min"]} in %arg1
+  %1 = transform.structured.match ops{["test.workgroup_id"]} in %arg1
+  %2 = transform.structured.match ops{["test.thread_id"]} in %arg1
+  transform.affine.simplify_bounded_affine_ops %0 with {%1 in (0, 31), %2 in (0, 31)}
+}
+
+// -----
+
+func.func @invalid_bounds() -> index {
+  %0 = "test.some_op"() : () -> (index)
+  %1 = affine.min affine_map<()[s0] -> (50, 100 - s0)>()[%0]
+  return %1 : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["affine.min"]} in %arg1
+  %1 = transform.structured.match ops{["test.some_op"]} in %arg1
+  // expected-error @below {{expected lower bound (20) to be smaller than upper bound (20)}}
+  transform.affine.simplify_bounded_affine_ops %0 with {%1 in (20, 20)}
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1201,13 +1201,16 @@
     hdrs = glob(["include/mlir/Dialect/Affine/TransformOps/*.h"]),
     includes = ["include"],
     deps = [
+        ":AffineAnalysis",
         ":AffineDialect",
         ":AffineTransformOpsIncGen",
         ":AffineTransforms",
         ":AffineUtils",
         ":FuncDialect",
         ":IR",
+        ":PDLDialect",
         ":TransformDialect",
+        ":Transforms",
         ":VectorDialect",
     ],
 )