diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -501,7 +501,8 @@ let summary = "Convert SCF parallel loop to OpenMP parallel + workshare " "constructs."; let constructor = "mlir::createConvertSCFToOpenMPPass()"; - let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect"]; + let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect", + "memref::MemRefDialect"]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -274,6 +274,7 @@ let regions = (region SizedRegion<1>:$bodyRegion); let hasCustomAssemblyFormat = 1; let hasVerifier = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -364,8 +365,6 @@ loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1)); SmallVector reductionVariables; reductionVariables.reserve(parallelOp.getNumReductions()); - Value token = rewriter.create( - loc, LLVM::LLVMPointerType::get(rewriter.getIntegerType(8))); for (Value init : parallelOp.getInitVals()) { assert((LLVM::isCompatibleType(init.getType()) || 
init.getType().isa()) && @@ -392,31 +391,31 @@ // Create the parallel wrapper. auto ompParallel = rewriter.create(loc); { + OpBuilder::InsertionGuard guard(rewriter); rewriter.createBlock(&ompParallel.region()); - // Replace SCF yield with OpenMP yield. { - OpBuilder::InsertionGuard innerGuard(rewriter); - rewriter.setInsertionPointToEnd(parallelOp.getBody()); - assert(llvm::hasSingleElement(parallelOp.getRegion()) && - "expected scf.parallel to have one block"); - rewriter.replaceOpWithNewOp( - parallelOp.getBody()->getTerminator(), ValueRange()); - } - - // Replace the loop. - auto loop = rewriter.create( - parallelOp.getLoc(), parallelOp.getLowerBound(), - parallelOp.getUpperBound(), parallelOp.getStep()); - rewriter.create(loc); - - rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.region(), - loop.region().begin()); - if (!reductionVariables.empty()) { - loop.reductionsAttr( - ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); - loop.reduction_varsMutable().append(reductionVariables); + auto scope = rewriter.create(parallelOp.getLoc(), + TypeRange()); + rewriter.create(loc); + OpBuilder::InsertionGuard allocaGuard(rewriter); + rewriter.createBlock(&scope.getBodyRegion()); + rewriter.setInsertionPointToStart(&scope.getBodyRegion().front()); + + // Replace the loop. 
+ auto loop = rewriter.create( + parallelOp.getLoc(), parallelOp.getLowerBound(), + parallelOp.getUpperBound(), parallelOp.getStep()); + rewriter.create(loc); + + rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.region(), + loop.region().begin()); + if (!reductionVariables.empty()) { + loop.reductionsAttr( + ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); + loop.reduction_varsMutable().append(reductionVariables); + } } } @@ -429,7 +428,6 @@ } rewriter.replaceOp(parallelOp, results); - rewriter.create(loc, token); return success(); } }; @@ -438,7 +436,8 @@ static LogicalResult applyPatterns(ModuleOp module) { ConversionTarget target(*module.getContext()); target.addIllegalOp(); - target.addLegalDialect(); + target.addLegalDialect(); RewritePatternSet patterns(module.getContext()); patterns.add(module.getContext()); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" @@ -258,6 +259,156 @@ regions.push_back(RegionSuccessor(&bodyRegion())); } +/// Given an operation, return whether this op is guaranteed to +/// allocate an AutomaticAllocationScopeResource +static bool isGuaranteedAutomaticAllocationScope(Operation *op) { + MemoryEffectOpInterface interface = dyn_cast(op); + if (!interface) + return false; + for (auto res : op->getResults()) { + if (auto effect = + interface.getEffectOnValue(res)) { + if (isa( + effect->getResource())) + return true; + } + } + return false; +} + +/// Given an operation, return whether this op could to +/// allocate an AutomaticAllocationScopeResource +static bool 
isPotentialAutomaticAllocationScope(Operation *op) { + MemoryEffectOpInterface interface = dyn_cast(op); + if (!interface) + return true; + for (auto res : op->getResults()) { + if (auto effect = + interface.getEffectOnValue(res)) { + if (isa( + effect->getResource())) + return true; + } + } + return false; +} + +/// Return whether this op is the last non terminating op +/// in a region. That is to say, it is in a one-block region +/// and is only followed by a terminator. This prevents +/// extending the lifetime of allocations. +static bool lastNonTerminatorInRegion(Operation *op) { + return op->getNextNode() == op->getBlock()->getTerminator() && + op->getParentRegion()->getBlocks().size() == 1; +} + +/// Inline an AllocaScopeOp if either the direct parent is an allocation scope +/// or it contains no allocation. +struct AllocaScopeInliner : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllocaScopeOp op, + PatternRewriter &rewriter) const override { + if (!op->getParentOp()->hasTrait()) { + bool hasPotentialAlloca = + op->walk([&](Operation *alloc) { + if (isPotentialAutomaticAllocationScope(alloc)) + return WalkResult::interrupt(); + return WalkResult::skip(); + }).wasInterrupted(); + if (hasPotentialAlloca) + return failure(); + } + + // Only apply if this is the last non-terminator + // op in the block (lest lifetime be extended) of a one + // block region + if (!lastNonTerminatorInRegion(op)) + return failure(); + + Block *block = &op.getRegion().front(); + Operation *terminator = block->getTerminator(); + ValueRange results = terminator->getOperands(); + rewriter.mergeBlockBefore(block, op); + rewriter.replaceOp(op, results); + rewriter.eraseOp(terminator); + return success(); + } +}; + +/// Move allocations into an allocation scope, if it is legal to +/// move them (e.g. their operands are available at the location +/// the op would be moved to).
+struct AllocaScopeHoister : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllocaScopeOp op, + PatternRewriter &rewriter) const override { + if (op->getParentOp()->hasTrait()) + return failure(); + + if (!op->getParentWithTrait()) + return failure(); + + Operation *lastParentWithoutScope = op->getParentOp(); + + // Only apply to if this is this last non-terminator + // op in the block (lest lifetime be extended) of a one + // block region + if (!lastNonTerminatorInRegion(op) || + !lastNonTerminatorInRegion(lastParentWithoutScope)) + return failure(); + + while (!lastParentWithoutScope->getParentOp() + ->hasTrait()) { + lastParentWithoutScope = lastParentWithoutScope->getParentOp(); + if (!lastNonTerminatorInRegion(lastParentWithoutScope)) + return failure(); + } + Operation *scope = lastParentWithoutScope->getParentOp(); + assert(scope->hasTrait()); + + Region *containingRegion = nullptr; + for (auto &r : lastParentWithoutScope->getRegions()) { + if (r.isAncestor(op->getParentRegion())) { + assert(containingRegion == nullptr && + "only one region can contain the op"); + containingRegion = &r; + } + } + assert(containingRegion && "op must be contained in a region"); + + SmallVector toHoist; + op->walk([&](Operation *alloc) { + if (!isGuaranteedAutomaticAllocationScope(alloc)) + return WalkResult::skip(); + + // If any operand is not defined before the location of + // lastParentWithoutScope (i.e. where we would hoist to), skip. 
+ if (llvm::any_of(alloc->getOperands(), [&](Value v) { + return containingRegion->isAncestor(v.getParentRegion()); + })) + return WalkResult::skip(); + toHoist.push_back(alloc); + return WalkResult::advance(); + }); + + if (!toHoist.size()) + return failure(); + rewriter.setInsertionPoint(lastParentWithoutScope); + for (auto op : toHoist) { + auto cloned = rewriter.clone(*op); + rewriter.replaceOp(op, cloned->getResults()); + } + return success(); + } +}; + +void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // AssumeAlignmentOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir --- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir @@ -21,12 +21,12 @@ %arg3 : index, %arg4 : index) { // CHECK: %[[CST:.*]] = arith.constant 0.0 // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 - // CHECK: llvm.intr.stacksave // CHECK: %[[BUF:.*]] = llvm.alloca %[[ONE]] x f32 // CHECK: llvm.store %[[CST]], %[[BUF]] %step = arith.constant 1 : index %zero = arith.constant 0.0 : f32 // CHECK: omp.parallel + // CHECK: memref.alloca_scope // CHECK: omp.wsloop // CHECK-SAME: reduction(@[[$REDF]] -> %[[BUF]] scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) @@ -43,7 +43,6 @@ } // CHECK: omp.terminator // CHECK: llvm.load %[[BUF]] - // CHECK: llvm.intr.stackrestore return } @@ -162,6 +161,7 @@ // CHECK: llvm.store %[[IONE]], %[[BUF2]] // CHECK: omp.parallel + // CHECK: memref.alloca_scope // CHECK: omp.wsloop // CHECK-SAME: reduction(@[[$REDF1]] -> %[[BUF1]] // CHECK-SAME: @[[$REDF2]] -> %[[BUF2]] diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir --- 
a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir @@ -4,6 +4,7 @@ func @parallel(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: omp.parallel { + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { // CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> () @@ -20,9 +21,11 @@ func @nested_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: omp.parallel { + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) { scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) { // CHECK: omp.parallel + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) { scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) { // CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> () @@ -41,6 +44,7 @@ func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: omp.parallel { + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) { scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) { // CHECK: "test.payload1"(%[[LVAR_AL1]]) : (index) -> () @@ -52,6 +56,7 @@ // CHECK: } // CHECK: omp.parallel { + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) { scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) { // CHECK: "test.payload2"(%[[LVAR_AL2]]) : (index) -> () diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ 
b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -552,3 +552,94 @@ // CHECK-LABEL: func @self_copy // CHECK-NEXT: return + +// ----- + +func @scopeMerge() { + memref.alloca_scope { + %cnt = "test.count"() : () -> index + %a = memref.alloca(%cnt) : memref + "test.use"(%a) : (memref) -> () + } + return +} +// CHECK: func @scopeMerge() { +// CHECK-NEXT: %[[cnt:.+]] = "test.count"() : () -> index +// CHECK-NEXT: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref +// CHECK-NEXT: "test.use"(%[[alloc]]) : (memref) -> () +// CHECK-NEXT: return +// CHECK-NEXT: } + +func @scopeMerge2() { + "test.region"() ({ + memref.alloca_scope { + %cnt = "test.count"() : () -> index + %a = memref.alloca(%cnt) : memref + "test.use"(%a) : (memref) -> () + } + "test.terminator"() : () -> () + }) : () -> () + return +} + +// CHECK: func @scopeMerge2() { +// CHECK-NEXT: "test.region"() ({ +// CHECK-NEXT: memref.alloca_scope { +// CHECK-NEXT: %[[cnt:.+]] = "test.count"() : () -> index +// CHECK-NEXT: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref +// CHECK-NEXT: "test.use"(%[[alloc]]) : (memref) -> () +// CHECK-NEXT: } +// CHECK-NEXT: "test.terminator"() : () -> () +// CHECK-NEXT: }) : () -> () +// CHECK-NEXT: return +// CHECK-NEXT: } + +func @scopeMerge3() { + %cnt = "test.count"() : () -> index + "test.region"() ({ + memref.alloca_scope { + %a = memref.alloca(%cnt) : memref + "test.use"(%a) : (memref) -> () + } + "test.terminator"() : () -> () + }) : () -> () + return +} + +// CHECK-NEXT: func @scopeMerge3() { +// CHECK-NEXT: %[[cnt:.+]] = "test.count"() : () -> index +// CHECK-NEXT: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref +// CHECK-NEXT: "test.region"() ({ +// CHECK-NEXT: memref.alloca_scope { +// CHECK-NEXT: "test.use"(%[[alloc]]) : (memref) -> () +// CHECK-NEXT: } +// CHECK-NEXT: "test.terminator"() : () -> () +// CHECK-NEXT: }) : () -> () +// CHECK-NEXT: return +// CHECK-NEXT: } + +func @scopeMerge4() { + %cnt = "test.count"() : () -> index + "test.region"() ({ + 
memref.alloca_scope { + %a = memref.alloca(%cnt) : memref + "test.use"(%a) : (memref) -> () + } + "test.op"() : () -> () + "test.terminator"() : () -> () + }) : () -> () + return +} + +// CHECK-NEXT: func @scopeMerge4() { +// CHECK-NEXT: %[[cnt:.+]] = "test.count"() : () -> index +// CHECK-NEXT: "test.region"() ({ +// CHECK-NEXT: memref.alloca_scope { +// CHECK-NEXT: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref +// CHECK-NEXT: "test.use"(%[[alloc]]) : (memref) -> () +// CHECK-NEXT: } +// CHECK-NEXT: "test.op"() : () -> () +// CHECK-NEXT: "test.terminator"() : () -> () +// CHECK-NEXT: }) : () -> () +// CHECK-NEXT: return +// CHECK-NEXT: }