diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -9,6 +9,7 @@
 #ifndef MEMREF_OPS
 #define MEMREF_OPS
 
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Dialect/MemRef/IR/MemRefBase.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
@@ -198,6 +199,80 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// AllocaScopeOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
+      [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+       SingleBlockImplicitTerminator<"AllocaScopeReturnOp">,
+       RecursiveSideEffects,
+       NoRegionArguments]> {
+  let summary = "explicitly delimited scope for stack allocation";
+  let description = [{
+    The `memref.alloca_scope` operation represents an explicitly-delimited
+    scope for the alloca allocations. Any `memref.alloca` operations that are
+    used within this scope are going to be cleaned up automatically once
+    the control-flow exits the nested region. For example:
+
+    ```mlir
+    memref.alloca_scope {
+      %myalloca = memref.alloca(): memref<4x3xf32>
+      ...
+    }
+    ```
+
+    Here, `%myalloca` memref is valid within the explicitly delimited scope
+    and is automatically deallocated at the end of the given region.
+
+    `memref.alloca_scope` may also return results that are defined in the
+    nested region. To return a value, one should use the
+    `memref.alloca_scope.return` operation:
+
+    ```mlir
+    %result = memref.alloca_scope -> f32 {
+      ...
+      memref.alloca_scope.return %value : f32
+    }
+    ```
+
+    If `memref.alloca_scope` returns no value, the terminating
+    `memref.alloca_scope.return` can be left out and is inserted implicitly.
+  }];
+
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$bodyRegion);
+}
+
+//===----------------------------------------------------------------------===//
+// AllocaScopeReturnOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
+      [HasParent<"AllocaScopeOp">,
+       NoSideEffect,
+       ReturnLike,
+       Terminator]> {
+  let summary = "terminator for alloca_scope operation";
+  let description = [{
+    `memref.alloca_scope.return` operation returns zero or more SSA values
+    from the region within `memref.alloca_scope`. If no values are returned,
+    the return operation may be omitted. Otherwise, it has to be present
+    to indicate which values are going to be returned.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$results);
+  let builders = [OpBuilder<(ins), [{ /*nothing to do */ }]>];
+
+  let assemblyFormat =
+    [{ attr-dict ($results^ `:` type($results))? }];
+
+  // No custom verification needed.
+  let verifier = ?;
+}
+
+
+
 //===----------------------------------------------------------------------===//
 // BufferCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1998,6 +1998,66 @@
   }
 };
 
+/// Lowers `memref.alloca_scope` by inlining its body into the parent region.
+/// The stack pointer is saved with `llvm.intr.stacksave` before entering the
+/// body and restored with `llvm.intr.stackrestore` right before the branch
+/// that leaves it, releasing any stack memory allocated inside the scope.
+struct AllocaScopeOpLowering
+    : public ConvertOpToLLVMPattern<mlir::memref::AllocaScopeOp> {
+  using ConvertOpToLLVMPattern<
+      mlir::memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mlir::memref::AllocaScopeOp allocaScopeOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard guard(rewriter);
+    Location loc = allocaScopeOp.getLoc();
+
+    // Split the current block before the AllocaScopeOp to create the inlining
+    // point.
+    auto *currentBlock = rewriter.getInsertionBlock();
+    auto *remainingOpsBlock =
+        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
+    Block *continueBlock;
+    if (allocaScopeOp.getNumResults() == 0) {
+      continueBlock = remainingOpsBlock;
+    } else {
+      continueBlock = rewriter.createBlock(remainingOpsBlock,
+                                           allocaScopeOp.getResultTypes());
+      rewriter.create<BranchOp>(loc, remainingOpsBlock);
+    }
+
+    // Inline body region.
+    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
+    Block *afterBody = &allocaScopeOp.bodyRegion().back();
+    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);
+
+    // Save stack and then branch into the body of the region.
+    rewriter.setInsertionPointToEnd(currentBlock);
+    auto stackSaveOp =
+        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
+    rewriter.create<BranchOp>(loc, beforeBody);
+
+    // Replace the alloca_scope return with a branch that jumps out of the body.
+    // Stack restore before leaving the body region.
+    rewriter.setInsertionPointToEnd(afterBody);
+    auto returnOp =
+        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
+    auto branchOp = rewriter.replaceOpWithNewOp<BranchOp>(
+        returnOp, continueBlock, returnOp.results());
+
+    // Insert stack restore before jumping out of the body of the region.
+    rewriter.setInsertionPoint(branchOp);
+    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
+
+    // Replace the op with the values returned from the body region.
+    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
+
+    return success();
+  }
+};
+
 /// Copies the shaped descriptor part to (if `toDynamic` is set) or from
 /// (otherwise) the dynamically allocated memory for any operands that were
 /// unranked descriptors originally.
@@ -3855,6 +3915,7 @@
       AddFOpLowering,
       AddIOpLowering,
       AllocaOpLowering,
+      AllocaScopeOpLowering,
       AndOpLowering,
       AssertOpLowering,
       AtomicRMWOpLowering,
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -228,6 +228,71 @@
       context);
 }
 
+//===----------------------------------------------------------------------===//
+// AllocaScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Prints `memref.alloca_scope` in its custom form. The region terminator is
+/// printed only when the op has results; otherwise it is implied.
+static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
+  bool printBlockTerminators = false;
+
+  p << AllocaScopeOp::getOperationName() << " ";
+  if (!op.results().empty()) {
+    p << "-> (" << op.getResultTypes() << ") ";
+    printBlockTerminators = true;
+  }
+  p.printRegion(op.bodyRegion(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/printBlockTerminators);
+  p.printOptionalAttrDict(op->getAttrs());
+}
+
+/// Parses the custom form emitted by `print` above. An omitted
+/// `memref.alloca_scope.return` terminator is inserted implicitly.
+static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
+                                      OperationState &result) {
+  // Create a region for the body.
+  result.regions.reserve(1);
+  Region *bodyRegion = result.addRegion();
+
+  // Parse optional results type list.
+  if (parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+
+  // Parse the body region.
+  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
+                                  result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+static LogicalResult verify(AllocaScopeOp op) {
+  if (failed(RegionBranchOpInterface::verifyTypes(op)))
+    return failure();
+
+  return success();
+}
+
+/// The body region is always entered from the parent op; leaving the body
+/// returns control to the parent, whose results receive the yielded values.
+void AllocaScopeOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  if (index.hasValue()) {
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  regions.push_back(RegionSuccessor(&bodyRegion()));
+}
+
 //===----------------------------------------------------------------------===//
 // AssumeAlignmentOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-alloca-scope.mlir b/mlir/test/Conversion/StandardToLLVM/convert-alloca-scope.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/StandardToLLVM/convert-alloca-scope.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
+
+// CHECK-LABEL: llvm.func @empty
+func @empty() {
+// CHECK: llvm.intr.stacksave
+// CHECK: llvm.br
+  memref.alloca_scope {
+    memref.alloca_scope.return
+  }
+// CHECK: llvm.intr.stackrestore
+// CHECK: llvm.br
+// CHECK: llvm.return
+  return
+}
+
+// CHECK-LABEL: llvm.func @returns_nothing
+func @returns_nothing(%b: f32) {
+  %a = constant 10.0 : f32
+// CHECK: llvm.intr.stacksave
+  memref.alloca_scope {
+    %c = std.addf %a, %b : f32
+    memref.alloca_scope.return
+  }
+// CHECK: llvm.intr.stackrestore
+// CHECK: llvm.return
+  return
+}
+
+// CHECK-LABEL: llvm.func @returns_one_value
+func @returns_one_value(%b: f32) -> f32 {
+  %a = constant 10.0 : f32
+// CHECK: llvm.intr.stacksave
+  %result = memref.alloca_scope -> f32 {
+    %c = std.addf %a, %b : f32
+    memref.alloca_scope.return %c: f32
+  }
+// CHECK: llvm.intr.stackrestore
+// CHECK: llvm.return
+  return %result : f32
+}
+
+// CHECK-LABEL: llvm.func @returns_multiple_values
+func @returns_multiple_values(%b: f32) -> f32 {
+  %a = constant 10.0 : f32
+// CHECK: llvm.intr.stacksave
+  %result1, %result2 = memref.alloca_scope -> (f32, f32) {
+    %c = std.addf %a, %b : f32
+    %d = std.subf %a, %b : f32
+    memref.alloca_scope.return %c, %d: f32, f32
+  }
+// CHECK: llvm.intr.stackrestore
+// CHECK: llvm.return
+  %result = std.addf %result1, %result2 : f32
+  return %result : f32
+}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -76,3 +76,12 @@
   memref.dealloc %1 : memref<*xf32>
   return
 }
+
+
+// CHECK-LABEL: func @memref_alloca_scope
+func @memref_alloca_scope() {
+  memref.alloca_scope {
+    memref.alloca_scope.return
+  }
+  return
+}