diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -113,6 +113,74 @@
   let hasFolder = 1;
 }
 
+def AffineExecuteRegionOp : Affine_Op<"execute_region", [AffineScope]> {
+  let summary = "execute_region operation";
+  let description = [{
+    The `affine.execute_region` op introduces a new symbol context for affine
+    operations. It holds a single region, which can comprise one or more
+    blocks, and its semantics are to execute its region exactly once. The
+    op's region can have zero or more arguments, each of which can only be a
+    memref. The operands bind 1:1 to its region's arguments. The op can't use
+    any memrefs defined outside of it, but can use any other SSA values that
+    dominate it. The results of an execute_region op match 1:1 with the
+    values yielded from its region's blocks.
+
+    Examples:
+
+    ```mlir
+    affine.for %i = 0 to 128 {
+      affine.execute_region [%rI, %rM] = (%I, %M)
+          : (memref<128xi32>, memref<24xf32>) -> () {
+        %idx = affine.load %rI[%i] : memref<128xi32>
+        %index = index_cast %idx : i32 to index
+        affine.load %rM[%index] : memref<24xf32>
+        affine.yield
+      }
+    }
+    ```
+
+    ```mlir
+    affine.for %i = 0 to %n {
+      affine.execute_region [] = () : () -> () {
+        // %pow can now be used as a loop bound.
+        %pow = call @powi(%i) : (index) -> index
+        affine.for %j = 0 to %pow {
+          "foo"() : () -> ()
+        }
+        affine.yield
+      }
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyMemRef>:$operands);
+  let results = (outs Variadic<AnyType>);
+
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    OpBuilder<(ins "ValueRange":$memrefs)>
+  ];
+
+  // TODO: canonicalizations related to memrefs.
+  let hasCanonicalizer = 0;
+}
+
 def AffineForOp : Affine_Op<"for",
     [ImplicitAffineTerminator, RecursiveSideEffects,
      DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
@@ -883,8 +951,11 @@
   let hasFolder = 1;
 }
 
-def AffineYieldOp : Affine_Op<"yield", [NoSideEffect, Terminator, ReturnLike,
-    MemRefsNormalizable]> {
+def AffineYieldOp : Affine_Op<"yield", [
+    NoSideEffect, Terminator, ReturnLike, MemRefsNormalizable,
+    ParentOneOf<
+      ["AffineExecuteRegionOp", "AffineForOp", "AffineIfOp", "AffineParallelOp"]
+    >]> {
   let summary = "Yield values to parent operation";
   let description = [{
     "affine.yield" yields zero or more SSA values from an affine op region and
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2659,6 +2659,138 @@
   return foldMemRefCast(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// AffineExecuteRegionOp
+//===----------------------------------------------------------------------===//
+
+// TODO: missing region body.
+void AffineExecuteRegionOp::build(OpBuilder &builder, OperationState &result,
+                                  ValueRange memrefs) {
+  // Create a region and an empty entry block. The arguments of the region are
+  // the supplied memrefs.
+  Region *region = result.addRegion();
+  Block *body = new Block();
+  region->push_back(body);
+  body->addArguments(memrefs.getTypes());
+}
+
+static LogicalResult verify(AffineExecuteRegionOp op) {
+  // All memref uses in the execute_region's region should be explicitly
+  // captured.
+  // FIXME: change this walk to an affine walk that doesn't walk inner
+  // execute_regions.
+  DenseSet<Value> memrefsUsed;
+  op.region().walk([&](Operation *innerOp) {
+    for (Value v : innerOp->getOperands())
+      if (v.getType().isa<MemRefType>())
+        memrefsUsed.insert(v);
+  });
+
+  // For each memref use, ensure it is either an execute_region argument or a
+  // local def.
+  auto implicitUse = [&](Value memref) {
+    Operation *memrefOriginOp;
+    if (auto arg = memref.dyn_cast<BlockArgument>())
+      memrefOriginOp = arg.getOwner()->getParentOp();
+    else
+      memrefOriginOp = memref.getDefiningOp();
+    return !op.getOperation()->isAncestor(memrefOriginOp);
+  };
+  if (llvm::any_of(memrefsUsed, implicitUse))
+    return op.emitOpError("used memref not explicitly captured");
+
+  // Verify that the region arguments match the operands.
+  auto &entryBlock = op.region().front();
+  if (entryBlock.getNumArguments() != op.getNumOperands())
+    return op.emitOpError("region argument count does not match operand count");
+
+  for (auto argEn : llvm::enumerate(entryBlock.getArguments())) {
+    if (op.getOperand(argEn.index()).getType() != argEn.value().getType())
+      return op.emitOpError("region argument ")
+             << argEn.index() << " does not match corresponding operand";
+  }
+
+  return success();
+}
+
+// Custom form syntax.
+//
+//   (ssa-id `=`)? `affine.execute_region` `[` memref-region-arg-list `]`
+//       `=` `(` memref-use-list `)`
+//       `:` memref-type-list-parens `->` function-result-type `{`
+//     block+
+//   `}`
+//
+// Ex:
+//
+//   affine.execute_region [%rI, %rM] = (%I, %M)
+//       : (memref<128xi32>, memref<1024xf32>) -> () {
+//     %idx = affine.load %rI[%i] : memref<128xi32>
+//     %index = index_cast %idx : i32 to index
+//     affine.load %rM[%index] : memref<1024xf32>
+//     affine.yield
+//   }
+//
+static ParseResult parseAffineExecuteRegionOp(OpAsmParser &parser,
+                                              OperationState &result) {
+  // Memref operands.
+  SmallVector<OpAsmParser::OperandType, 4> memrefs;
+
+  // Region arguments to be created.
+  SmallVector<OpAsmParser::OperandType, 4> regionMemRefs;
+
+  // The execute_region op has the same type signature as a function.
+  FunctionType opType;
+
+  // Parse the memref assignments.
+  auto argLoc = parser.getCurrentLocation();
+  if (parser.parseRegionArgumentList(regionMemRefs,
+                                     OpAsmParser::Delimiter::Square) ||
+      parser.parseEqual() ||
+      parser.parseOperandList(memrefs, OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  if (memrefs.size() != regionMemRefs.size())
+    return parser.emitError(parser.getNameLoc(),
+                            "incorrect number of memref captures");
+
+  if (parser.parseColonType(opType) ||
+      parser.addTypesToList(opType.getResults(), result.types))
+    return failure();
+
+  auto memrefTypes = opType.getInputs();
+  if (parser.resolveOperands(memrefs, memrefTypes, argLoc, result.operands))
+    return failure();
+
+  // Introduce and parse the body region, and the optional attribute list.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionMemRefs, memrefTypes) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+static void print(OpAsmPrinter &p, AffineExecuteRegionOp op) {
+  p << AffineExecuteRegionOp::getOperationName() << " [";
+  // TODO: consider shadowing region arguments.
+  p.printOperands(op.region().front().getArguments());
+  p << "] = (";
+  p.printOperands(op.getOperands());
+  p << ")";
+
+  SmallVector<Type, 2> argTypes(op.getOperandTypes());
+  p << " : "
+    << FunctionType::get(op->getContext(), argTypes, op.getResultTypes());
+
+  p.printRegion(op.region(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/true);
+
+  p.printOptionalAttrDict(op->getAttrs());
+}
+
 //===----------------------------------------------------------------------===//
 // AffineParallelOp
 //===----------------------------------------------------------------------===//
@@ -3277,8 +3409,6 @@
   auto results = parentOp->getResults();
   auto operands = op.getOperands();
 
-  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
-    return op.emitOpError() << "only terminates affine.if/for/parallel regions";
   if (parentOp->getNumResults() != op.getNumOperands())
     return op.emitOpError() << "parent of yield must have same number of "
                                "results as the yield operands";
diff --git a/mlir/test/Dialect/Affine/execute-region.mlir b/mlir/test/Dialect/Affine/execute-region.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Affine/execute-region.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @arbitrary_bound
+func @arbitrary_bound(%n : index) {
+  affine.for %i = 0 to %n {
+    affine.execute_region [] = () : () -> () {
+      // %pow can now be used as a loop bound.
+      %pow = call @powi(%i) : (index) -> index
+      affine.for %j = 0 to %pow {
+        "test.foo"() : () -> ()
+      }
+      affine.yield
+    }
+    // CHECK:      affine.execute_region [] = () : () -> () {
+    // CHECK-NEXT:   call @powi
+    // CHECK-NEXT:   affine.for
+    // CHECK-NEXT:     "test.foo"()
+    // CHECK-NEXT:   }
+    // CHECK-NEXT:   affine.yield
+    // CHECK-NEXT: }
+  }
+  return
+}
+
+func private @powi(index) -> index
+
+// CHECK-LABEL: func @arbitrary_mem_access
+func @arbitrary_mem_access(%I: memref<128xi32>, %M: memref<1024xf32>) {
+  affine.for %i = 0 to 128 {
+    // CHECK: %{{.*}} = affine.execute_region [{{.*}}] = ({{.*}}) : (memref<128xi32>, memref<1024xf32>) -> f32
+    %ret = affine.execute_region [%rI, %rM] = (%I, %M) : (memref<128xi32>, memref<1024xf32>) -> f32 {
+      %idx = affine.load %rI[%i] : memref<128xi32>
+      %index = index_cast %idx : i32 to index
+      %v = affine.load %rM[%index] : memref<1024xf32>
+      affine.yield %v : f32
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: @symbol_check
+func @symbol_check(%B: memref<100xi32>, %A: memref<100xf32>) {
+  %cf1 = constant 1.0 : f32
+  affine.for %i = 0 to 100 {
+    %v = affine.load %B[%i] : memref<100xi32>
+    %vo = index_cast %v : i32 to index
+    // CHECK: affine.execute_region [%{{.*}}] = (%{{.*}}) : (memref<100xf32>) -> () {
+    affine.execute_region [%rA] = (%A) : (memref<100xf32>) -> () {
+      // %vi is now a symbol here.
+      %vi = index_cast %v : i32 to index
+      affine.load %rA[%vi] : memref<100xf32>
+      // %vo is also a symbol (it dominates the execute_region).
+      affine.load %rA[%vo] : memref<100xf32>
+      affine.yield
+    }
+    // CHECK:      index_cast
+    // CHECK-NEXT: affine.load
+    // CHECK-NEXT: affine.load
+    // CHECK-NEXT: affine.yield
+    // CHECK-NEXT: }
+  }
+  return
+}
+
+// CHECK-LABEL: func @test_more_symbol_validity
+func @test_more_symbol_validity(%A: memref<100xf32>, %pos : index) {
+  %c5 = constant 5 : index
+  affine.for %i = 0 to 100 {
+    %sym = call @external() : () -> (index)
+    affine.execute_region [%rA] = (%A) : (memref<100xf32>) -> () {
+      affine.load %rA[symbol(%pos) + symbol(%sym) + %c5] : memref<100xf32>
+      affine.yield
+    }
+  }
+  affine.execute_region [%rA] = (%A) : (memref<100xf32>) -> () {
+    affine.load %rA[symbol(%pos) + %c5] : memref<100xf32>
+    affine.yield
+  }
+  return
+}
+
+func private @external() -> (index)
+
+// CHECK-LABEL: func @search
+func @search(%A : memref<?x?xi32>, %S : memref<?xi32>, %key : i32) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %ni = memref.dim %A, %c0 : memref<?x?xi32>
+  // This loop can be parallelized.
+  affine.for %i = 0 to %ni {
+    // CHECK: affine.execute_region
+    affine.execute_region [%rA, %rS] = (%A, %S) : (memref<?x?xi32>, memref<?xi32>) -> () {
+      %nj = memref.dim %rA, %c1 : memref<?x?xi32>
+      br ^bb1(%c0 : index)
+
+    ^bb1(%j: index):
+      %p1 = cmpi "slt", %j, %nj : index
+      cond_br %p1, ^bb2(%j : index), ^bb5
+
+    ^bb2(%j_arg : index):
+      %v = affine.load %rA[%i, %j_arg] : memref<?x?xi32>
+      %p2 = cmpi "eq", %v, %key : i32
+      cond_br %p2, ^bb3(%j_arg : index), ^bb4(%j_arg : index)
+
+    ^bb3(%j_arg2: index):
+      %j_int = index_cast %j_arg2 : index to i32
+      affine.store %j_int, %rS[%i] : memref<?xi32>
+      br ^bb5
+
+    ^bb4(%j_arg3 : index):
+      %jinc = addi %j_arg3, %c1 : index
+      br ^bb1(%jinc : index)
+
+    ^bb5:
+      affine.yield
+    }
+  }
+  return
+}
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -142,6 +142,62 @@
 
 // -----
 
+// CHECK-LABEL: @affine.execute_region_missing_capture
+func @affine.execute_region_missing_capture(%M : memref<2xi32>) {
+  affine.for %i = 0 to 10 {
+    affine.execute_region [] = () : () -> () {
+    // expected-error@-1 {{used memref not explicitly captured}}
+      affine.load %M[%i] : memref<2xi32>
+      affine.yield
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @affine.execute_region_wrong_capture
+func @affine.execute_region_wrong_capture(%s : index) {
+  affine.execute_region [%rS] = (%s) : (index) -> () {
+  // expected-error@-1 {{operand #0 must be memref}}
+    "use"(%s) : (index) -> ()
+    affine.yield
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @affine.execute_region_wrong_capture
+func @affine.execute_region_wrong_capture(%A : memref<2xi32>) {
+  affine.execute_region [] = (%A) : (memref<2xi32>) -> () {
+  // expected-error@-1 {{incorrect number of memref captures}}
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @affine.execute_region_region_type_mismatch
+func @affine.execute_region_region_type_mismatch(%A : memref<2xi32>) {
+  "affine.execute_region"(%A) ({
+  // expected-error@-1 {{region argument 0 does not match corresponding operand}}
+    ^bb0(%rA : memref<4xi32>):
+      affine.yield
+  }) : (memref<2xi32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @affine.execute_region_region_arg_count_mismatch
+func @affine.execute_region_region_arg_count_mismatch(%A : memref<2xi32>) {
+  "affine.execute_region"(%A) ({
+  // expected-error@-1 {{region argument count does not match operand count}}
+    ^bb0:
+      affine.yield
+  }) : (memref<2xi32>) -> ()
+  return
+}
+
+// -----
+
 func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
   // expected-error@+1 {{operand count and affine map dimension and symbol count must match}}
   %0 = affine.min affine_map<(d0) -> (d0)> (%arg0, %arg1)