// Patch overview (commentary placed ahead of the first `diff --git` header;
// the patch payload below is unchanged).
//
// Files touched, in order:
//   * mlir/include/mlir/Conversion/Passes.td
//       Adds "memref::MemRefDialect" to the dependentDialects of the pass
//       whose constructor is mlir::createConvertSCFToOpenMPPass().
//   * mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
//       Adds `let hasCanonicalizer = 1;` to an op that has a `$bodyRegion`.
//       The op name is outside the hunk; presumably AllocaScopeOp, since
//       MemRefOps.cpp gains AllocaScopeOp::getCanonicalizationPatterns.
//   * mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
//       Drops the `token` value created before the init-value loop and the
//       trailing `rewriter.create(loc, token)`. Inside the parallel wrapper's
//       new block it now creates a `scope` op with an empty result TypeRange,
//       creates a block in scope.getBodyRegion(), and builds the loop there;
//       an InsertionGuard is added around this work. The parallel body is
//       still inlined with inlineRegionBefore and reductions are attached the
//       same way as before. The conversion target additionally gets a
//       rewritten addLegalDialect line.
//       NOTE(review): the hunk deletes the "Replace SCF yield with OpenMP
//       yield" block and no replacement is visible in this hunk - confirm the
//       terminator of the inlined body is handled elsewhere.
//   * mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
//       Moves AllocaScopeOp::print/parse/verify/getSuccessorRegions from
//       above the AssumeAlignmentOp section to just above the "TableGen'd op
//       method definitions" section (removed and added text match), includes
//       SideEffectInterfaces.h, and adds:
//         - isStackAlloca(op): returns false when `op` does not cast to
//           MemoryEffectOpInterface; otherwise queries the effect on
//           op->getResult(0) and returns true when the effect's resource
//           passes the isa check.
//           NOTE(review): result 0 is read without checking that the op has
//           any results, and both patterns below pass it every op visited by
//           walk() - confirm zero-result ops implementing the interface
//           cannot reach this. The helper is also declared without `static`
//           or an anonymous namespace - confirm that is intended.
//         - AllocaScopeInliner: rewrites when the direct parent has the
//           checked trait OR the walk over the op is never interrupted (no
//           op for which isStackAlloca is true); merges the body block
//           before the op, replaces the op with the terminator's operands,
//           then erases the terminator.
//         - AllocaScopeHoister: bails out when the direct parent has the
//           checked trait; otherwise climbs parents until the next parent
//           has the trait, collects ops accepted by isStackAlloca whose
//           operands are not defined inside a region of that ancestor,
//           clones them before the ancestor and replaces the originals.
//           NOTE(review): the climbing loop calls getParentOp() twice per
//           step without a null check - confirm a trait-carrying ancestor
//           always exists. The final `for (auto op : toHoist)` shadows the
//           `op` parameter.
//         - AllocaScopeOp::getCanonicalizationPatterns: `results.add(context)`.
//   * mlir/test/Conversion/SCFToOpenMP/{reductions,scf-to-openmp}.mlir
//       Remove the CHECK lines for llvm.intr.stacksave/stackrestore and
//       expect `memref.alloca_scope` between `omp.parallel` and `omp.wsloop`.
//   * mlir/test/Dialect/MemRef/canonicalize.mlir
//       Adds @scopeMerge (scope directly in a func is inlined), @scopeMerge2
//       (alloca operand defined inside the scope: nothing changes) and
//       @scopeMerge3 (operand defined outside "test.region": the alloca is
//       expected before "test.region" and the scope is gone).
//       NOTE(review): @scopeMerge3 is anchored with CHECK-NEXT while the
//       other two functions use CHECK - confirm that is intended.
//
// NOTE(review): this copy of the patch looks extraction-damaged. Newlines are
// collapsed, text that would sit in angle brackets appears to be missing
// (e.g. `rewriter.create(loc)`, `dyn_cast(op)`, `SmallVector toHoist`,
// `memref` types without a shape), and the getSuccessorRegions parameter
// spelled with a registered-trademark sign is presumably `&regions`.
// Regenerate from the original revision before applying.
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -501,7 +501,7 @@ let summary = "Convert SCF parallel loop to OpenMP parallel + workshare " "constructs."; let constructor = "mlir::createConvertSCFToOpenMPPass()"; - let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect"]; + let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect", "memref::MemRefDialect"]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -274,6 +274,7 @@ let regions = (region SizedRegion<1>:$bodyRegion); let hasCustomAssemblyFormat = 1; let hasVerifier = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -364,8 +365,6 @@ loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1)); SmallVector reductionVariables; reductionVariables.reserve(parallelOp.getNumReductions()); - Value token = rewriter.create( - loc, LLVM::LLVMPointerType::get(rewriter.getIntegerType(8))); for (Value init : parallelOp.getInitVals()) { assert((LLVM::isCompatibleType(init.getType()) || 
init.getType().isa()) && @@ -392,31 +391,31 @@ // Create the parallel wrapper. auto ompParallel = rewriter.create(loc); { + OpBuilder::InsertionGuard guard(rewriter); rewriter.createBlock(&ompParallel.region()); - // Replace SCF yield with OpenMP yield. { - OpBuilder::InsertionGuard innerGuard(rewriter); - rewriter.setInsertionPointToEnd(parallelOp.getBody()); - assert(llvm::hasSingleElement(parallelOp.getRegion()) && - "expected scf.parallel to have one block"); - rewriter.replaceOpWithNewOp( - parallelOp.getBody()->getTerminator(), ValueRange()); + auto scope = rewriter.create(parallelOp.getLoc(), + TypeRange()); + rewriter.create(loc); + OpBuilder::InsertionGuard allocaGuard(rewriter); + rewriter.createBlock(&scope.getBodyRegion()); + rewriter.setInsertionPointToStart(&scope.getBodyRegion().front()); + + // Replace the loop. + auto loop = rewriter.create( + parallelOp.getLoc(), parallelOp.getLowerBound(), + parallelOp.getUpperBound(), parallelOp.getStep()); + rewriter.create(loc); + + rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.region(), + loop.region().begin()); + if (!reductionVariables.empty()) { + loop.reductionsAttr( + ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); + loop.reduction_varsMutable().append(reductionVariables); } - - // Replace the loop. 
- auto loop = rewriter.create( - parallelOp.getLoc(), parallelOp.getLowerBound(), - parallelOp.getUpperBound(), parallelOp.getStep()); - rewriter.create(loc); - - rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.region(), - loop.region().begin()); - if (!reductionVariables.empty()) { - loop.reductionsAttr( - ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); - loop.reduction_varsMutable().append(reductionVariables); } } @@ -429,7 +428,6 @@ } rewriter.replaceOp(parallelOp, results); - rewriter.create(loc, token); return success(); } }; @@ -438,7 +436,8 @@ static LogicalResult applyPatterns(ModuleOp module) { ConversionTarget target(*module.getContext()); target.addIllegalOp(); - target.addLegalDialect(); + target.addLegalDialect(); RewritePatternSet patterns(module.getContext()); patterns.add(module.getContext()); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" @@ -202,62 +203,6 @@ context); } -//===----------------------------------------------------------------------===// -// AllocaScopeOp -//===----------------------------------------------------------------------===// - -void AllocaScopeOp::print(OpAsmPrinter &p) { - bool printBlockTerminators = false; - - p << ' '; - if (!results().empty()) { - p << " -> (" << getResultTypes() << ")"; - printBlockTerminators = true; - } - p << ' '; - p.printRegion(bodyRegion(), - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); - p.printOptionalAttrDict((*this)->getAttrs()); -} - -ParseResult AllocaScopeOp::parse(OpAsmParser &parser, 
OperationState &result) { - // Create a region for the body. - result.regions.reserve(1); - Region *bodyRegion = result.addRegion(); - - // Parse optional results type list. - if (parser.parseOptionalArrowTypeList(result.types)) - return failure(); - - // Parse the body region. - if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{})) - return failure(); - AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(), - result.location); - - // Parse the optional attribute list. - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - return success(); -} - -LogicalResult AllocaScopeOp::verify() { - return RegionBranchOpInterface::verifyTypes(*this); -} - -void AllocaScopeOp::getSuccessorRegions( - Optional index, ArrayRef operands, - SmallVectorImpl ®ions) { - if (index.hasValue()) { - regions.push_back(RegionSuccessor(getResults())); - return; - } - - regions.push_back(RegionSuccessor(&bodyRegion())); -} - //===----------------------------------------------------------------------===// // AssumeAlignmentOp //===----------------------------------------------------------------------===// @@ -2446,6 +2391,151 @@ return OpFoldResult(); } +//===----------------------------------------------------------------------===// +// AllocaScopeOp +//===----------------------------------------------------------------------===// + +void AllocaScopeOp::print(OpAsmPrinter &p) { + bool printBlockTerminators = false; + + p << ' '; + if (!results().empty()) { + p << " -> (" << getResultTypes() << ")"; + printBlockTerminators = true; + } + p << ' '; + p.printRegion(bodyRegion(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/printBlockTerminators); + p.printOptionalAttrDict((*this)->getAttrs()); +} + +ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { + // Create a region for the body. + result.regions.reserve(1); + Region *bodyRegion = result.addRegion(); + + // Parse optional results type list. 
+ if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + + // Parse the body region. + if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{})) + return failure(); + AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(), + result.location); + + // Parse the optional attribute list. + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + return success(); +} + +LogicalResult AllocaScopeOp::verify() { + return RegionBranchOpInterface::verifyTypes(*this); +} + +void AllocaScopeOp::getSuccessorRegions( + Optional index, ArrayRef operands, + SmallVectorImpl ®ions) { + if (index.hasValue()) { + regions.push_back(RegionSuccessor(getResults())); + return; + } + + regions.push_back(RegionSuccessor(&bodyRegion())); +} + +bool isStackAlloca(Operation *op) { + MemoryEffectOpInterface interface = dyn_cast(op); + if (!interface) + return false; + if (auto effect = interface.getEffectOnValue( + op->getResult(0))) { + if (isa( + effect->getResource())) + return true; + } + return false; +} + +// Inline an AllocaScopeOp if either the direct parent is an allocation scope +// or it contains no allocation. +struct AllocaScopeInliner : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllocaScopeOp op, + PatternRewriter &rewriter) const override { + if (!(op->getParentOp()->hasTrait() || + !op->walk([&](Operation *alloc) { + if (isStackAlloca(alloc)) + return WalkResult::interrupt(); + return WalkResult::skip(); + }).wasInterrupted())) + return failure(); + Block *block = &op.getRegion().front(); + Operation *terminator = block->getTerminator(); + ValueRange results = terminator->getOperands(); + rewriter.mergeBlockBefore(block, op); + rewriter.replaceOp(op, results); + rewriter.eraseOp(terminator); + return success(); + } +}; + +// Move allocations into an allocation scope, if it is legal to +// move them (e.g. 
their operands are available at the location +// the op would be moved to). +struct AllocaScopeHoister : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllocaScopeOp op, + PatternRewriter &rewriter) const override { + if (op->getParentOp()->hasTrait()) + return failure(); + + Operation *lastParentWithoutScope = op->getParentOp(); + while (!lastParentWithoutScope->getParentOp() + ->hasTrait()) { + lastParentWithoutScope = lastParentWithoutScope->getParentOp(); + } + Operation *scope = lastParentWithoutScope->getParentOp(); + assert(scope->hasTrait()); + + SmallVector toHoist; + op->walk([&](Operation *alloc) { + if (!isStackAlloca(alloc)) + return WalkResult::skip(); + + // If any operand is not defined before the location of + // lastParentWithoutScope (i.e. where we would hoist to), skip. + if (llvm::any_of(alloc->getOperands(), [&](Value v) { + return llvm::any_of( + lastParentWithoutScope->getRegions(), + [&](Region &r) { return r.isAncestor(v.getParentRegion()); }); + })) + return WalkResult::skip(); + toHoist.push_back(alloc); + return WalkResult::advance(); + }); + + if (!toHoist.size()) + return failure(); + rewriter.setInsertionPoint(lastParentWithoutScope); + for (auto op : toHoist) { + auto cloned = rewriter.clone(*op); + rewriter.replaceOp(op, cloned->getResults()); + } + return success(); + } +}; + +void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir --- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir @@ -21,12 +21,12 @@ %arg3 : index, %arg4 : index) { // 
CHECK: %[[CST:.*]] = arith.constant 0.0 // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 - // CHECK: llvm.intr.stacksave // CHECK: %[[BUF:.*]] = llvm.alloca %[[ONE]] x f32 // CHECK: llvm.store %[[CST]], %[[BUF]] %step = arith.constant 1 : index %zero = arith.constant 0.0 : f32 // CHECK: omp.parallel + // CHECK: memref.alloca_scope // CHECK: omp.wsloop // CHECK-SAME: reduction(@[[$REDF]] -> %[[BUF]] scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) @@ -43,7 +43,6 @@ } // CHECK: omp.terminator // CHECK: llvm.load %[[BUF]] - // CHECK: llvm.intr.stackrestore return } @@ -162,6 +161,7 @@ // CHECK: llvm.store %[[IONE]], %[[BUF2]] // CHECK: omp.parallel + // CHECK: memref.alloca_scope // CHECK: omp.wsloop // CHECK-SAME: reduction(@[[$REDF1]] -> %[[BUF1]] // CHECK-SAME: @[[$REDF2]] -> %[[BUF2]] diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir --- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir @@ -4,6 +4,7 @@ func @parallel(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: omp.parallel { + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { // CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> () @@ -20,9 +21,11 @@ func @nested_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: omp.parallel { + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) { scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) { // CHECK: omp.parallel + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) { scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) { // 
CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> () @@ -41,6 +44,7 @@ func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: omp.parallel { + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) { scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) { // CHECK: "test.payload1"(%[[LVAR_AL1]]) : (index) -> () @@ -52,6 +56,7 @@ // CHECK: } // CHECK: omp.parallel { + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) { scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) { // CHECK: "test.payload2"(%[[LVAR_AL2]]) : (index) -> () diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -552,3 +552,62 @@ // CHECK-LABEL: func @self_copy // CHECK-NEXT: return + +// ----- + +func @scopeMerge() { + memref.alloca_scope { + %cnt = "test.count"() : () -> index + %a = memref.alloca(%cnt) : memref + "test.use"(%a) : (memref) -> () + } + return +} +// CHECK: func @scopeMerge() { +// CHECK-NEXT: %[[cnt:.+]] = "test.count"() : () -> index +// CHECK-NEXT: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref +// CHECK-NEXT: "test.use"(%[[alloc]]) : (memref) -> () +// CHECK-NEXT: return +// CHECK-NEXT: } + +func @scopeMerge2() { + "test.region"() ({ + memref.alloca_scope { + %cnt = "test.count"() : () -> index + %a = memref.alloca(%cnt) : memref + "test.use"(%a) : (memref) -> () + } + }) : () -> () + return +} + +// CHECK: func @scopeMerge2() { +// CHECK-NEXT: "test.region"() ({ +// CHECK-NEXT: memref.alloca_scope { +// CHECK-NEXT: %[[cnt:.+]] = "test.count"() : () -> index +// CHECK-NEXT: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref +// CHECK-NEXT: "test.use"(%[[alloc]]) : (memref) -> () +// CHECK-NEXT: } +// CHECK-NEXT: }) : () -> () +// 
CHECK-NEXT: return +// CHECK-NEXT: } + +func @scopeMerge3() { + %cnt = "test.count"() : () -> index + "test.region"() ({ + memref.alloca_scope { + %a = memref.alloca(%cnt) : memref + "test.use"(%a) : (memref) -> () + } + }) : () -> () + return +} + +// CHECK-NEXT: func @scopeMerge3() { +// CHECK-NEXT: %[[cnt:.+]] = "test.count"() : () -> index +// CHECK-NEXT: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref +// CHECK-NEXT: "test.region"() ({ +// CHECK-NEXT: "test.use"(%[[alloc]]) : (memref) -> () +// CHECK-NEXT: }) : () -> () +// CHECK-NEXT: return +// CHECK-NEXT: }