diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
--- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
+++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
@@ -234,7 +234,21 @@
   let hasFolder = 1;
 }
 
-def AffineMinOp : Affine_Op<"min"> {
+// Common base for affine.min/affine.max: (map, index operands) -> index.
+class AffineMinMaxOpBase<string mnemonic, list<OpTrait> traits = []> :
+    Op<Affine_Dialect, mnemonic, traits> {
+  let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
+  let results = (outs Index);
+  let extraClassDeclaration = [{
+    static StringRef getMapAttrName() { return "map"; }
+  }];
+  let verifier = [{ return ::verifyAffineMinMaxOp(*this); }];
+  let printer = [{ return ::printAffineMinMaxOp(p, *this); }];
+  let parser = [{ return ::parseAffineMinMaxOp<$cppClass>(parser, result); }];
+  let hasFolder = 1;
+}
+
+def AffineMinOp : AffineMinMaxOpBase<"min"> {
   let summary = "min operation";
   let description = [{
     The "min" operation computes the minimum value result from a multi-result
@@ -244,12 +258,18 @@
 
       %0 = affine.min (d0) -> (1000, d0 + 512) (%i0) : index
   }];
-  let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
-  let results = (outs Index);
-  let extraClassDeclaration = [{
-    static StringRef getMapAttrName() { return "map"; }
+}
+
+def AffineMaxOp : AffineMinMaxOpBase<"max"> {
+  let summary = "max operation";
+  let description = [{
+    The "max" operation computes the maximum value result from a multi-result
+    affine map.
+
+    Example:
+
+      %0 = affine.max (d0) -> (1000, d0 + 512) (%i0) : index
   }];
-  let hasFolder = 1;
 }
 
 def AffinePrefetchOp : Affine_Op<"prefetch"> {
diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h
--- a/mlir/include/mlir/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/EDSC/Intrinsics.h
@@ -196,6 +196,7 @@
 using affine_if = OperationBuilder<AffineIfOp>;
 using affine_load = ValueBuilder<AffineLoadOp>;
 using affine_min = ValueBuilder<AffineMinOp>;
+using affine_max = ValueBuilder<AffineMaxOp>;
 using affine_store = OperationBuilder<AffineStoreOp>;
 using alloc = ValueBuilder<AllocOp>;
 using call = OperationBuilder<CallOp>;
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -258,16 +258,13 @@
   return value;
 }
 
-/// Emit instructions that correspond to the affine map in the lower bound
-/// applied to the respective operands, and compute the maximum value across
-/// the results.
-Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
-  auto lbValues = expandAffineMap(builder, op.getLoc(), op.getLowerBoundMap(),
-                                  op.getLowerBoundOperands());
-  if (!lbValues)
-    return nullptr;
-  return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::sgt, *lbValues,
-                                 builder);
+/// Emit instructions that correspond to computing the maximum value among the
+/// values of a (potentially) multi-output affine map applied to `operands`.
+static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map,
+                               ValueRange operands) {
+  if (auto values = expandAffineMap(builder, loc, map, operands))
+    return buildMinMaxReductionSeq(loc, CmpIPredicate::sgt, *values, builder);
+  return nullptr;
 }
 
 /// Emit instructions that correspond to computing the minimum value amoung the
@@ -287,6 +284,14 @@
                            op.getUpperBoundOperands());
 }
 
+/// Emit instructions that correspond to the affine map in the lower bound
+/// applied to the respective operands, and compute the maximum value across
+/// the results.
+Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
+  return lowerAffineMapMax(builder, op.getLoc(), op.getLowerBoundMap(),
+                           op.getLowerBoundOperands());
+}
+
 namespace {
 class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
 public:
@@ -304,6 +309,22 @@
   }
 };
 
+class AffineMaxLowering : public OpRewritePattern<AffineMaxOp> {
+public:
+  using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(AffineMaxOp op,
+                                     PatternRewriter &rewriter) const override {
+    Value reduced =
+        lowerAffineMapMax(rewriter, op.getLoc(), op.map(), op.operands());
+    if (!reduced)
+      return matchFailure();
+
+    rewriter.replaceOp(op, reduced);
+    return matchSuccess();
+  }
+};
+
 /// Affine terminators are removed.
 class AffineTerminatorLowering : public OpRewritePattern<AffineTerminatorOp> {
 public:
@@ -546,6 +567,7 @@
       AffineDmaWaitLowering,
       AffineLoadLowering,
       AffineMinLowering,
+      AffineMaxLowering,
       AffinePrefetchLowering,
       AffineStoreLowering,
       AffineForLowering,
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
--- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -1935,22 +1935,43 @@
 }
 
 //===----------------------------------------------------------------------===//
-// AffineMinOp
+// AffineMinMaxOpBase
 //===----------------------------------------------------------------------===//
-//
-//   %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
-//
 
-static ParseResult parseAffineMinOp(OpAsmParser &parser,
-                                    OperationState &result) {
+template <typename T>
+static LogicalResult verifyAffineMinMaxOp(T op) {
+  // Verify that operand count matches affine map dimension and symbol count.
+  if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
+    return op.emitOpError(
+        "operand count and affine map dimension and symbol count must match");
+  return success();
+}
+
+/// Prints `op-name map(dims)[syms] attr-dict`; `[syms]` is omitted if empty.
+template <typename T>
+static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
+  p << op.getOperationName() << ' ' << op.getAttr(T::getMapAttrName());
+  auto operands = op.getOperands();
+  unsigned numDims = op.map().getNumDims();
+  p << '(' << operands.take_front(numDims) << ')';
+
+  if (operands.size() != numDims)
+    p << '[' << operands.drop_front(numDims) << ']';
+  p.printOptionalAttrDict(op.getAttrs(),
+                          /*elidedAttrs=*/{T::getMapAttrName()});
+}
+
+/// Parses the printAffineMinMaxOp form; operands and result are `index`.
+template <typename T>
+static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
+                                       OperationState &result) {
   auto &builder = parser.getBuilder();
   auto indexType = builder.getIndexType();
   SmallVector<OpAsmParser::OperandType, 8> dim_infos;
   SmallVector<OpAsmParser::OperandType, 8> sym_infos;
   AffineMapAttr mapAttr;
   return failure(
-      parser.parseAttribute(mapAttr, AffineMinOp::getMapAttrName(),
-                            result.attributes) ||
+      parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) ||
       parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) ||
       parser.parseOperandList(sym_infos,
                               OpAsmParser::Delimiter::OptionalSquare) ||
@@ -1960,25 +1981,12 @@
       parser.addTypeToList(indexType, result.types));
 }
 
-static void print(OpAsmPrinter &p, AffineMinOp op) {
-  p << op.getOperationName() << ' '
-    << op.getAttr(AffineMinOp::getMapAttrName());
-  auto operands = op.getOperands();
-  unsigned numDims = op.map().getNumDims();
-  p << '(' << operands.take_front(numDims) << ')';
-
-  if (operands.size() != numDims)
-    p << '[' << operands.drop_front(numDims) << ']';
-  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
-}
-
-static LogicalResult verify(AffineMinOp op) {
-  // Verify that operand count matches affine map dimension and symbol count.
-  if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
-    return op.emitOpError(
-        "operand count and affine map dimension and symbol count must match");
-  return success();
-}
+//===----------------------------------------------------------------------===//
+// AffineMinOp
+//===----------------------------------------------------------------------===//
+//
+//   %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
+//
 
 OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
   // Fold the affine map.
@@ -2004,6 +2012,36 @@
 }
 
 //===----------------------------------------------------------------------===//
+// AffineMaxOp
+//===----------------------------------------------------------------------===//
+//
+//   %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
+//
+
+OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
+  // Fold the affine map.
+  // TODO(andydavis, ntv, ouhang) Fold more cases: partial static information,
+  // max(some_affine, some_affine + constant, ...).
+  SmallVector<Attribute, 2> results;
+  if (failed(map().constantFold(operands, results)))
+    return {};
+
+  // Compute and return max of folded results (the first always seeds it).
+  int64_t max = std::numeric_limits<int64_t>::min();
+  int maxIndex = -1;
+  for (unsigned i = 0, e = results.size(); i < e; ++i) {
+    auto intAttr = results[i].cast<IntegerAttr>();
+    if (maxIndex < 0 || intAttr.getInt() > max) {
+      max = intAttr.getInt();
+      maxIndex = i;
+    }
+  }
+  if (maxIndex < 0)
+    return {};
+  return results[maxIndex];
+}
+
+//===----------------------------------------------------------------------===//
 // AffinePrefetchOp
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/AffineOps/canonicalize.mlir b/mlir/test/AffineOps/canonicalize.mlir
--- a/mlir/test/AffineOps/canonicalize.mlir
+++ b/mlir/test/AffineOps/canonicalize.mlir
@@ -526,3 +526,29 @@
   // CHECK-NEXT: return
   return
 }
+
+// -----
+
+func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %c511 = constant 511 : index
+  %c600 = constant 600 : index
+  %0 = affine.max affine_map<(d0)[s0] -> (1000, d0 + 512, s0 + 1)> (%c600)[%c511]
+  "op0"(%0) : (index) -> ()
+  // CHECK: %[[CST:.*]] = constant 1112 : index
+  // CHECK-NEXT: "op0"(%[[CST]]) : (index) -> ()
+  // CHECK-NEXT: return
+  return
+}
+
+// -----
+
+func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %c3 = constant 3 : index
+  %c20 = constant 20 : index
+  %0 = affine.max affine_map<(d0)[s0] -> (1000, d0 floordiv 4, (s0 mod 5) + 1)> (%c20)[%c3]
+  "op0"(%0) : (index) -> ()
+  // CHECK: %[[CST:.*]] = constant 1000 : index
+  // CHECK-NEXT: "op0"(%[[CST]]) : (index) -> ()
+  // CHECK-NEXT: return
+  return
+}
diff --git a/mlir/test/AffineOps/invalid.mlir b/mlir/test/AffineOps/invalid.mlir
--- a/mlir/test/AffineOps/invalid.mlir
+++ b/mlir/test/AffineOps/invalid.mlir
@@ -168,3 +168,33 @@
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: @affine_max
+func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{operand count and affine map dimension and symbol count must match}}
+  %0 = affine.max affine_map<(d0) -> (d0)> (%arg0, %arg1)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @affine_max
+func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{operand count and affine map dimension and symbol count must match}}
+  %0 = affine.max affine_map<()[s0] -> (s0)> (%arg0, %arg1)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @affine_max
+func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{operand count and affine map dimension and symbol count must match}}
+  %0 = affine.max affine_map<(d0) -> (d0)> ()
+
+  return
+}
diff --git a/mlir/test/AffineOps/ops.mlir b/mlir/test/AffineOps/ops.mlir
--- a/mlir/test/AffineOps/ops.mlir
+++ b/mlir/test/AffineOps/ops.mlir
@@ -79,6 +79,19 @@
   return
 }
 
+// CHECK-LABEL: @affine_max
+func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
+  // CHECK: affine.max #[[MAP0]](%arg0)[%arg1]
+  %0 = affine.max affine_map<(d0)[s0] -> (1000, d0 + 512, s0)> (%arg0)[%arg1]
+  // CHECK: affine.max #[[MAP1]](%arg0, %arg1)[%arg2]
+  %1 = affine.max affine_map<(d0, d1)[s0] -> (d0 - d1, s0 + 512)> (%arg0, %arg1)[%arg2]
+  // CHECK: affine.max #[[MAP2]]()[%arg1, %arg2]
+  %2 = affine.max affine_map<()[s0, s1] -> (s0 - s1, 11)> ()[%arg1, %arg2]
+  // CHECK: affine.max #[[MAP3]]()
+  %3 = affine.max affine_map<()[] -> (77, 78, 79)> ()[]
+  return
+}
+
 // -----
 
 func @valid_symbols(%arg0: index, %arg1: index, %arg2: index) {
diff --git a/mlir/test/Transforms/lower-affine.mlir b/mlir/test/Transforms/lower-affine.mlir
--- a/mlir/test/Transforms/lower-affine.mlir
+++ b/mlir/test/Transforms/lower-affine.mlir
@@ -605,3 +605,18 @@
   %0 = affine.min affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
   return %0 : index
 }
+
+// CHECK-LABEL: func @affine_max
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func @affine_max(%arg0: index, %arg1: index) -> index {
+  // CHECK: %[[Cm1:.*]] = constant -1
+  // CHECK: %[[neg1:.*]] = muli %[[ARG1]], %[[Cm1]]
+  // CHECK: %[[first:.*]] = addi %[[ARG0]], %[[neg1]]
+  // CHECK: %[[Cm2:.*]] = constant -1
+  // CHECK: %[[neg2:.*]] = muli %[[ARG0]], %[[Cm2]]
+  // CHECK: %[[second:.*]] = addi %[[ARG1]], %[[neg2]]
+  // CHECK: %[[cmp:.*]] = cmpi "sgt", %[[first]], %[[second]]
+  // CHECK: select %[[cmp]], %[[first]], %[[second]]
+  %0 = affine.max affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
+  return %0 : index
+}