Index: mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
===================================================================
--- mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -621,11 +621,6 @@
     /// Get the number of dimensions.
     unsigned getNumDims();
 
-    operand_range getLowerBoundsOperands();
-    operand_range getUpperBoundsOperands();
-
-    AffineValueMap getLowerBoundsValueMap();
-    AffineValueMap getUpperBoundsValueMap();
     AffineValueMap getRangesValueMap();
 
     /// Get ranges as constants, may fail in dynamic case.
@@ -636,6 +631,17 @@
     MutableArrayRef<BlockArgument> getIVs() {
       return getBody()->getArguments();
     }
+
+    operand_range getLowerBoundsOperands();
+    AffineValueMap getLowerBoundsValueMap();
+    void setLowerBounds(ValueRange operands, AffineMap map);
+    void setLowerBoundsMap(AffineMap map);
+
+    operand_range getUpperBoundsOperands();
+    AffineValueMap getUpperBoundsValueMap();
+    void setUpperBounds(ValueRange operands, AffineMap map);
+    void setUpperBoundsMap(AffineMap map);
+
     void setSteps(ArrayRef<int64_t> newSteps);
 
     static StringRef getReductionsAttrName() { return "reductions"; }
@@ -643,6 +649,9 @@
     static StringRef getUpperBoundsMapAttrName() { return "upperBoundsMap"; }
     static StringRef getStepsAttrName() { return "steps"; }
   }];
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 def AffinePrefetchOp : Affine_Op<"prefetch"> {
Index: mlir/include/mlir/Dialect/Affine/IR/AffineValueMap.h
===================================================================
--- mlir/include/mlir/Dialect/Affine/IR/AffineValueMap.h
+++ mlir/include/mlir/Dialect/Affine/IR/AffineValueMap.h
@@ -74,6 +74,9 @@
   ArrayRef<Value> getOperands() const;
   AffineMap getAffineMap() const;
 
+  /// Return success if the map and/or operands have been modified.
+  LogicalResult canonicalize();
+
 private:
   // A mutable affine map.
   MutableAffineMap map;
Index: mlir/lib/Dialect/Affine/IR/AffineOps.cpp
===================================================================
--- mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -17,6 +17,8 @@
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
@@ -2506,8 +2508,47 @@
   return OpBuilder(getBody(), std::prev(getBody()->end()));
 }
 
+void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
+  assert(lbOperands.size() == map.getNumInputs());
+  assert(map.getNumResults() >= 1 && "bounds map has at least one result");
+
+  auto ubOperands = getUpperBoundsOperands();
+
+  SmallVector<Value, 4> newOperands(lbOperands);
+  newOperands.append(ubOperands.begin(), ubOperands.end());
+  getOperation()->setOperands(newOperands);
+
+  setAttr(getLowerBoundsMapAttrName(), AffineMapAttr::get(map));
+}
+
+void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
+  assert(ubOperands.size() == map.getNumInputs());
+  assert(map.getNumResults() >= 1 && "bounds map has at least one result");
+
+  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
+  newOperands.append(ubOperands.begin(), ubOperands.end());
+  getOperation()->setOperands(newOperands);
+
+  setAttr(getUpperBoundsMapAttrName(), AffineMapAttr::get(map));
+}
+
+void AffineParallelOp::setLowerBoundsMap(AffineMap map) {
+  auto lbMap = lowerBoundsMap();
+  assert(lbMap.getNumDims() == map.getNumDims() &&
+         lbMap.getNumSymbols() == map.getNumSymbols());
+  (void)lbMap;
+  setAttr(getLowerBoundsMapAttrName(), AffineMapAttr::get(map));
+}
+
+void AffineParallelOp::setUpperBoundsMap(AffineMap map) {
+  auto ubMap = upperBoundsMap();
+  assert(ubMap.getNumDims() == map.getNumDims() &&
+         ubMap.getNumSymbols() == map.getNumSymbols());
+  (void)ubMap;
+  setAttr(getUpperBoundsMapAttrName(), AffineMapAttr::get(map));
+}
+
 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
-  assert(newSteps.size() == getNumDims() && "steps & num dims mismatch");
   setAttr(getStepsAttrName(), getBodyBuilder().getI64ArrayAttr(newSteps));
 }
 
@@ -2542,6 +2583,216 @@
   return success();
 }
 
+namespace {
+/// This pattern removes affine.parallel ops with no induction variables.
+struct AffineParallelRank0LoopRemover
+    : public OpRewritePattern<AffineParallelOp> {
+  using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineParallelOp op,
+                                PatternRewriter &rewriter) const override {
+    // Check that there are no induction variables
+    if (op.getNumDims())
+      return failure();
+
+    // Only remove ops that don't have any custom attributes (i.e. those not
+    // defined by the op itself).
+    StringSet<> opAttrs{AffineParallelOp::getReductionsAttrName(),
+                        AffineParallelOp::getLowerBoundsMapAttrName(),
+                        AffineParallelOp::getUpperBoundsMapAttrName(),
+                        AffineParallelOp::getStepsAttrName()};
+    for (auto attr : op.getAttrs()) {
+      if (!opAttrs.count(attr.first.strref()))
+        return failure();
+    }
+
+    // Remove the affine.parallel wrapper, retain the body in the same location
+    auto &parentOps = op.getOperation()->getBlock()->getOperations();
+    auto &parallelBodyOps = op.region().front().getOperations();
+    auto yield = mlir::cast<AffineYieldOp>(std::prev(parallelBodyOps.end()));
+    // Copy the yielded values: the yield is erased together with `op`.
+    SmallVector<Value, 4> yieldedValues(yield.operands());
+    parentOps.splice(mlir::Block::iterator(op), parallelBodyOps,
+                     parallelBodyOps.begin(), std::prev(parallelBodyOps.end()));
+    // Go through the rewriter so the driver is notified of the replacement.
+    rewriter.replaceOp(op, yieldedValues);
+    return success();
+  }
+};
+
+/// This pattern removes induction variables with a trip count of one.
+struct AffineParallelTripCount1IndexRemover
+    : public OpRewritePattern<AffineParallelOp> {
+  using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineParallelOp op,
+                                PatternRewriter &rewriter) const override {
+    auto ranges = op.getRangesValueMap();
+    auto *body = op.getBody();
+    SmallVector<AffineExpr, 4> newLowerBounds;
+    SmallVector<AffineExpr, 4> newUpperBounds;
+    SmallVector<int64_t, 4> newSteps;
+    SmallVector<BlockArgument, 4> argsToRemove;
+    for (unsigned i = 0, e = body->getNumArguments(); i < e; i++) {
+      // Is the range a constant value matching the step size?
+      auto constExpr = ranges.getResult(i).dyn_cast<AffineConstantExpr>();
+      int64_t step = op.steps()[i].template cast<IntegerAttr>().getInt();
+      if (constExpr && constExpr.getValue() == step) {
+        // Mark argument for removal and replacement with its lower bound.
+        argsToRemove.push_back(body->getArgument(i));
+      } else {
+        // Keep argument
+        newLowerBounds.push_back(op.lowerBoundsMap().getResult(i));
+        newUpperBounds.push_back(op.upperBoundsMap().getResult(i));
+        newSteps.push_back(step);
+      }
+    }
+
+    // If no arguments need removal, return failure to match.
+    if (argsToRemove.empty())
+      return failure();
+
+    // Erase in reverse so argument numbers still index the original maps.
+    for (auto arg : llvm::reverse(argsToRemove)) {
+      auto argNumber = arg.getArgNumber();
+      auto lowerBoundValue = rewriter.create<AffineApplyOp>(
+          op.getLoc(), op.lowerBoundsMap().getSubMap({argNumber}),
+          op.getLowerBoundsOperands());
+      arg.replaceAllUsesWith(lowerBoundValue);
+      body->eraseArgument(argNumber);
+    }
+
+    // Update attributes and return success
+    auto newLower = AffineMap::get(op.lowerBoundsMap().getNumDims(),
+                                   op.lowerBoundsMap().getNumSymbols(),
+                                   newLowerBounds, op.getContext());
+    auto newUpper = AffineMap::get(op.upperBoundsMap().getNumDims(),
+                                   op.upperBoundsMap().getNumSymbols(),
+                                   newUpperBounds, op.getContext());
+    op.setAttr(AffineParallelOp::getLowerBoundsMapAttrName(),
+               AffineMapAttr::get(newLower));
+    op.setAttr(AffineParallelOp::getUpperBoundsMapAttrName(),
+               AffineMapAttr::get(newUpper));
+    op.setSteps(newSteps);
+    return success();
+  }
+};
+
+struct SimplifyAffineParallel : public OpRewritePattern<AffineParallelOp> {
+  using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineParallelOp op,
+                                PatternRewriter &rewriter) const override {
+
+    auto stepsAttrs = op.steps();
+    auto lbMap = op.lowerBoundsMap();
+
+    SmallVector<int64_t, 4> steps;
+    bool isWorkPending = false;
+    for (unsigned i = 0, e = stepsAttrs.size(); i < e; ++i) {
+      auto step = stepsAttrs[i].cast<IntegerAttr>().getInt();
+      steps.push_back(step);
+      auto lbExpr = lbMap.getResult(i).dyn_cast<AffineConstantExpr>();
+      isWorkPending |= (!lbExpr || lbExpr.getValue() || step != 1);
+    }
+
+    // No need to do any work if the parallel op is already simplified.
+    if (!isWorkPending)
+      return failure();
+
+    auto ranges = op.getRangesValueMap();
+    auto zeroExpr = rewriter.getAffineConstantExpr(0);
+    rewriter.setInsertionPointToStart(op.getBody());
+    SmallVector<AffineExpr, 4> lbExprs;
+    SmallVector<AffineExpr, 4> ubExprs;
+    for (unsigned i = 0, e = steps.size(); i < e; ++i) {
+      auto step = steps[i];
+
+      // Adjust the lower bound to be 0.
+      lbExprs.push_back(zeroExpr);
+
+      // Adjust the upper bound to the trip count: 'ceildiv(range, step)'
+      auto ubExpr = ranges.getResult(i).ceilDiv(step);
+      ubExprs.push_back(ubExpr);
+
+      // Adjust the corresponding IV: 'lb + i * step'
+      auto iv = op.getBody()->getArgument(i);
+      auto lbExpr = lbMap.getResult(i);
+      auto nDims = lbMap.getNumDims();
+      auto expr = lbExpr + rewriter.getAffineDimExpr(nDims) * step;
+      auto map = AffineMap::get(/*dimCount=*/nDims + 1,
+                                /*symbolCount=*/lbMap.getNumSymbols(), expr);
+
+      // Use an 'affine.apply' op that will be simplified later in subsequent
+      // canonicalizations.
+      auto lbOperands = op.getLowerBoundsOperands();
+      auto dimOperands = lbOperands.take_front(nDims);
+      auto symbolOperands = lbOperands.drop_front(nDims);
+      SmallVector<Value, 4> applyOperands{dimOperands};
+      applyOperands.push_back(iv);
+      applyOperands.append(symbolOperands.begin(), symbolOperands.end());
+      auto apply =
+          rewriter.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
+      iv.replaceAllUsesExcept(apply, SmallPtrSet<Operation *, 1>{apply});
+    }
+
+    SmallVector<int64_t, 4> newSteps(op.getNumDims(), 1);
+    op.setSteps(newSteps);
+    auto newLowerMap = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
+                                      lbExprs, rewriter.getContext());
+    op.setLowerBounds({}, newLowerMap);
+    auto newUpperMap =
+        AffineMap::get(ranges.getNumDims(), ranges.getNumSymbols(), ubExprs,
+                       rewriter.getContext());
+    op.setUpperBounds(ranges.getOperands(), newUpperMap);
+
+    return success();
+  }
+};
+
+} // end anonymous namespace
+
+LogicalResult AffineValueMap::canonicalize() {
+  SmallVector<Value, 4> newOperands{operands};
+  auto newMap = getAffineMap();
+  composeAffineMapAndOperands(&newMap, &newOperands);
+  if (newMap == getAffineMap() && newOperands == operands)
+    return failure();
+  reset(newMap, newOperands);
+  return success();
+}
+
+/// Canonicalize the bounds of the given loop.
+static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
+  auto lb = op.getLowerBoundsValueMap();
+  auto lbCanonicalized = succeeded(lb.canonicalize());
+
+  auto ub = op.getUpperBoundsValueMap();
+  auto ubCanonicalized = succeeded(ub.canonicalize());
+
+  // Any canonicalization change always leads to updated map(s).
+  if (!lbCanonicalized && !ubCanonicalized)
+    return failure();
+
+  if (lbCanonicalized)
+    op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
+  if (ubCanonicalized)
+    op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
+
+  return success();
+}
+
+void AffineParallelOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<AffineParallelRank0LoopRemover,
+                 AffineParallelTripCount1IndexRemover,
+                 SimplifyAffineParallel>(context);
+}
+
+LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands,
+                                     SmallVectorImpl<OpFoldResult> &results) {
+  return canonicalizeLoopBounds(*this);
+}
+
 static void print(OpAsmPrinter &p, AffineParallelOp op) {
   p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (";
   p.printAffineMapOfSSAIds(op.lowerBoundsMapAttr(),
@@ -2642,7 +2893,7 @@
   }
 
   // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
-  // quoted strings a member of the enum AtomicRMWKind.
+  // quoted strings are a member of the enum AtomicRMWKind.
   SmallVector<Attribute, 4> reductions;
   if (succeeded(parser.parseOptionalKeyword("reduce"))) {
     if (parser.parseLParen())
Index: mlir/test/Dialect/Affine/affine-fold.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Affine/affine-fold.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt -canonicalize -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @affine_parallel_rank0
+func @affine_parallel_rank0(%out: memref<f32>) {
+  // CHECK-NEXT: constant
+  %cst = constant 0.0 : f32
+  // CHECK-NEXT: affine.store
+  affine.parallel () = () to () {
+    affine.parallel () = () to () {
+      affine.store %cst, %out[] : memref<f32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @affine_parallel_range1
+func @affine_parallel_range1() {
+  // CHECK-NEXT: constant
+  %cst = constant 1.0 : f32
+  // CHECK-NEXT: alloc
+  %0 = alloc() : memref<2x4xf32>
+  // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, 1] : memref<2x4xf32>
+  affine.parallel (%i, %j) = (0, 1) to (2, 2) step (2, 1) {
+    affine.store %cst, %0[%i, %j] : memref<2x4xf32>
+  }
+  // CHECK-NEXT: return
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @affine_parallel_partial_range1
+func @affine_parallel_partial_range1() {
+  // CHECK-NEXT: constant
+  %cst = constant 1.0 : f32
+  // CHECK-NEXT: alloc
+  %0 = alloc() : memref<2x4xf32>
+  // CHECK-NEXT: affine.parallel (%{{.*}}) = (0) to (10)
+  affine.parallel (%i, %j) = (0, 1) to (10, 2) {
+    // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, 1]
+    affine.store %cst, %0[%i, %j] : memref<2x4xf32>
+  }
+  // CHECK: return
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @simplify_parallel
+func @simplify_parallel() {
+  %cst = constant 1.0 : f32
+  %0 = alloc() : memref<2x4xf32>
+  // CHECK: affine.parallel (%[[i:.*]], %[[j:.*]]) = (0, 0) to (4, 2) {
+  affine.parallel (%i, %j) = (0, 1) to (10, 5) step (3, 2) {
+    // CHECK: affine.parallel (%[[k:.*]]) = (0) to (%[[j]] * 2 - %[[i]] * 3 + 1) {
+    affine.parallel (%k) = (%i) to (%j) {
+      // CHECK: affine.store %{{.*}}, %{{.*}}[%[[i]] * 3, %[[i]] * 3 + %[[k]]] : memref<2x4xf32>
+      affine.store %cst, %0[%i, %k] : memref<2x4xf32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @simplify_parallel_non_divisible
+func @simplify_parallel_non_divisible() {
+  %cst = constant 1.0 : f32
+  %0 = alloc() : memref<10xf32>
+  // CHECK: affine.parallel (%[[i:.*]]) = (0) to (4) {
+  affine.parallel (%i) = (0) to (10) step (3) {
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%[[i]] * 3] : memref<10xf32>
+    affine.store %cst, %0[%i] : memref<10xf32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @affine_parallel_const_bounds
+func @affine_parallel_const_bounds() {
+  %cst = constant 1.0 : f32
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %0 = alloc() : memref<4xf32>
+  // CHECK: affine.parallel (%{{.*}}) = (0) to (4)
+  affine.parallel (%i) = (%c0) to (%c0 + %c4) {
+    affine.store %cst, %0[%i] : memref<4xf32>
+  }
+  return
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> (d0 * 5)>
+#map1 = affine_map<(d0) -> (d0 * 10)>
+
+// CHECK-LABEL: func @affine_parallel_fold_bounds
+func @affine_parallel_fold_bounds() {
+  %cst = constant 1.0 : f32
+  %0 = alloc() : memref<100x100xf32>
+  // CHECK: affine.parallel (%[[i0:.*]], %[[j0:.*]]) =
+  affine.parallel (%i0, %j0) = (0, 0) to (100, 10) {
+    %2 = affine.apply #map0(%i0)
+    %3 = affine.apply #map1(%j0)
+    // CHECK-NOT: affine.apply
+    // CHECK: affine.parallel (%[[i1:.*]], %[[j1:.*]]) = (0, 0) to (5, 10) {
+    affine.parallel (%i1, %j1) = (%2, %3) to (%2 + 5, %3 + 10) {
+      // CHECK: affine.store %{{.*}}, %{{.*}}[%[[i0]] * 5 + %[[i1]], %[[j0]] * 10 + %[[j1]]]
+      affine.store %cst, %0[%i1, %j1] : memref<100x100xf32>
+    }
+  }
+  return
+}