diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -31,9 +31,9 @@ class FlatAffineConstraints; class OpBuilder; -/// A utility function to check if a value is defined at the top level of a -/// function. A value of index type defined at the top level is always a valid -/// symbol. +/// A utility function to check if a value is defined at the top level of an +/// op isolated from above or an affine execute_region. A value of index type +/// defined at the top level is always a valid symbol. bool isTopLevelValue(Value value); /// AffineDmaStartOp starts a non-blocking DMA operation that transfers data @@ -457,12 +457,22 @@ SmallVectorImpl &results); }; -/// Returns true if the given Value can be used as a dimension id. +/// Returns true if the given Value can be used as a dimension id in the closest +/// surrounding op that is isolated from above or an affine execute_region +/// enclosing this value's definition or block argument appearance. bool isValidDim(Value value); -/// Returns true if the given Value can be used as a symbol. +/// Returns true if the given Value can be used as a dimension id in `region`. +bool isValidDim(Value value, Region *region); + +/// Returns true if the given value can be used as a symbol in the closest +/// op that is isolated from above or an affine execute_region op enclosing this +/// value's definition or block argument appearance. bool isValidSymbol(Value value); +/// Returns true if the given Value can be used as a symbol for `region`. +bool isValidSymbol(Value value, Region *region); + /// Modifies both `map` and `operands` in-place so as to: /// 1. drop duplicate operands /// 2. drop unused dims and symbols from map diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -83,12 +83,20 @@ /// Returns the affine value map computed from this operation. AffineValueMap getAffineValueMap(); - /// Returns true if the result of this operation can be used as dimension id. + /// Returns true if the result of this operation can be used as dimension + /// id. bool isValidDim(); + /// Returns true if the result of this operation can be used as dimension id + /// within the region 'region'. + bool isValidDim(Region *region); + /// Returns true if the result of this operation is a symbol. bool isValidSymbol(); + /// Returns true if the result of this operation is a symbol in 'region'. + bool isValidSymbol(Region *region); + operand_range getMapOperands() { return getOperands(); } }]; @@ -612,4 +620,72 @@ let verifier = ?; } +def AffineExecuteRegionOp : Affine_Op<"execute_region"> { + let summary = "execute_region operation"; + let description = [{ + The `affine.execute_region` op introduces a new symbol context for affine + operations. It holds a single region, which can be a list of one or more + blocks, and its semantics are to execute its region exactly once. The op's + region can have zero or more arguments, each of which can only be a + memref. The operands bind 1:1 to its region's arguments. The op can't use + any memrefs defined outside of it, but can use any other SSA values that + dominate it. The results of a execute_region op match 1:1 with the return + values from its region's blocks; + + Examples: + + ```mlir + affine.for %i = 0 to 128 { + affine.execute_region [%rI, %rM] = (%I, %M) + : (memref<128xi32>, memref<24xf32>) -> () { + %idx = affine.load %rI[%i] : memref<128xi32> + %index = index_cast %idx : i32 to index + affine.load %rM[%index]: memref<24xf32> + return + } + } + ``` + + ```mlir + affine.for %i = 0 to %n { + affine.execute_region [] = () : () -> () { + // %pow can now be used as a loop bound. + %pow = call @powi(%i) : (index) -> index + affine.for %j = 0 to %pow { + "foo"() : () -> () + } + return + } + } + ``` + + ```mlir + affine.for %i = 0 to %n { + affine.execute_region : () -> () { + // %pow can now be used as a loop bound. + %pow = call @powi(%i) : (index) -> index + affine.for %j = 0 to %pow { + "foo"() : () -> () + } + return + } + } + ``` + }]; + + let arguments = (ins Variadic:$operands); + let results = (outs Variadic); + + let regions = (region AnyRegion:$region); + + let skipDefaultBuilders = 1; + + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, " + "ValueRange memrefs"> + ]; + + // TODO: canonicalizations related to memrefs. + let hasCanonicalizer = 0; +} #endif // AFFINE_OPS diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -84,18 +84,46 @@ return builder.create(loc, type, value); } -/// A utility function to check if a given region is attached to a function. -static bool isFunctionRegion(Region *region) { - return llvm::isa(region->getParentOp()); +/// A utility function to check if a given region is attached to an op isolated +/// from above or an affine scope. +static bool isIsolatedOrAffineExecuteRegion(Region *region) { + return region->getParentOp()->isKnownIsolatedFromAbove() || + isa(region->getParentOp()); } -/// A utility function to check if a value is defined at the top level of a -/// function. A value of index type defined at the top level is always a valid -/// symbol. +/// A utility function to check if a value is defined at the top level of an +/// op isolated from above or an affine.execute_region op. A value of index type +/// defined at the top level is always a valid symbol. bool mlir::isTopLevelValue(Value value) { if (auto arg = value.dyn_cast()) - return isFunctionRegion(arg.getOwner()->getParent()); - return isFunctionRegion(value.getDefiningOp()->getParentRegion()); + return isIsolatedOrAffineExecuteRegion(arg.getOwner()->getParent()); + return isIsolatedOrAffineExecuteRegion( + value.getDefiningOp()->getParentRegion()); +} + +/// A utility function to check if a value is defined at the top level of +/// `region`. A value of index type defined at the top level is always a +/// valid symbol. +static bool isTopLevelValue(Value value, Region *region) { + if (auto arg = value.dyn_cast()) + return arg.getOwner()->getParentOp() == region->getParentOp(); + return value.getDefiningOp()->getParentOp() == region->getParentOp(); +} + +/// Returns the closest region surrounding 'op' that is held either by an +/// AffineExecuteRegionOp or an op isolated from above (eg. FuncOp). Asserts if +/// called on a top-level op. +// TODO: getAffineScope should be publicly exposed for affine passes/utilities. +static Region *getAffineScope(Operation *op) { + auto *curOp = op; + while (auto *parentOp = curOp->getParentOp()) { + if (llvm::isa(parentOp) || + parentOp->isKnownIsolatedFromAbove()) + return curOp->getParentRegion(); + curOp = parentOp; + } + + llvm_unreachable("op doesn't have an enclosing affine scope"); } // Value can be used as a dimension id if it is valid as a symbol, or @@ -106,43 +134,64 @@ if (!value.getType().isIndex()) return false; - if (auto *op = value.getDefiningOp()) { - // Top level operation or constant operation is ok. - if (isFunctionRegion(op->getParentRegion()) || isa(op)) - return true; - // Affine apply operation is ok if all of its operands are ok. - if (auto applyOp = dyn_cast(op)) - return applyOp.isValidDim(); - // The dim op is okay if its operand memref/tensor is defined at the top - // level. - if (auto dimOp = dyn_cast(op)) - return isTopLevelValue(dimOp.getOperand()); - return false; - } - // This value has to be a block argument of a FuncOp, an 'affine.for', or an - // 'affine.parallel'. + if (auto *op = value.getDefiningOp()) + return isValidDim(value, getAffineScope(op)); + + // This value has to be a block argument for an op isolated from above or an + // affine.for. (An affine.execute_region can't have index type arguments.) auto *parentOp = value.cast().getOwner()->getParentOp(); - return isa(parentOp) || isa(parentOp) || + return parentOp->isKnownIsolatedFromAbove() || isa(parentOp) || isa(parentOp); } +// Value can be used as a dimension id if it is valid as a symbol, or it is an +// induction variable, or it is a result of an affine apply operation with +// dimension id arguments. +bool mlir::isValidDim(Value value, Region *region) { + // The value must be an index type. + if (!value.getType().isIndex()) + return false; + + // All valid symbols are okay. + if (isValidSymbol(value, region)) + return true; + + auto *op = value.getDefiningOp(); + if (!op) { + // This value has to be a block argument for a affine.for or + // affine.parallel. + auto *parentOp = value.cast().getOwner()->getParentOp(); + return isa(parentOp) || isa(parentOp); + } + + // Affine apply operation is ok if all of its operands are ok. + if (auto applyOp = dyn_cast(op)) + return applyOp.isValidDim(region); + // The dim op is okay if its operand memref/tensor is defined at the top + // level. + if (auto dimOp = dyn_cast(op)) + return isTopLevelValue(dimOp.getOperand()); + return false; +} + /// Returns true if the 'index' dimension of the `memref` defined by -/// `memrefDefOp` is a statically shaped one or defined using a valid symbol. +/// `memrefDefOp` is a statically shaped one or defined using a valid symbol +/// for 'op'. template -static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, - unsigned index) { +bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, + Region *region) { auto memRefType = memrefDefOp.getType(); // Statically shaped. if (!memRefType.isDynamicDim(index)) return true; // Get the position of the dimension among dynamic dimensions; unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index); - return isValidSymbol( - *(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos)); + return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos), + region); } /// Returns true if the result of the dim op is a valid symbol. -static bool isDimOpValidSymbol(DimOp dimOp) { +static bool isDimOpValidSymbol(DimOp dimOp, Region *region) { // The dim op is okay if its operand memref/tensor is defined at the top // level. if (isTopLevelValue(dimOp.getOperand())) @@ -152,43 +201,82 @@ // whose corresponding size is a valid symbol. unsigned index = dimOp.getIndex(); if (auto viewOp = dyn_cast(dimOp.getOperand().getDefiningOp())) - return isMemRefSizeValidSymbol(viewOp, index); + return isMemRefSizeValidSymbol(viewOp, index, region); if (auto subViewOp = dyn_cast(dimOp.getOperand().getDefiningOp())) - return isMemRefSizeValidSymbol(subViewOp, index); + return isMemRefSizeValidSymbol(subViewOp, index, region); if (auto allocOp = dyn_cast(dimOp.getOperand().getDefiningOp())) - return isMemRefSizeValidSymbol(allocOp, index); + return isMemRefSizeValidSymbol(allocOp, index, region); return false; } // Value can be used as a symbol if it is a constant, or it is defined at -// the top level, or it is a result of affine apply operation with symbol -// arguments, or a result of the dim op on a memref satisfying certain -// constraints. +// the top level of the enclosing affine scope (affine.execute_region or an op +// isolated from above) or dominates such a scope, or it is a result of affine +// apply operation with symbol arguments, or a result of the dim op on a memref +// whose corresponding size is a valid symbol. bool mlir::isValidSymbol(Value value) { // The value must be an index type. if (!value.getType().isIndex()) return false; - if (auto *op = value.getDefiningOp()) { - // Top level operation or constant operation is ok. - if (isFunctionRegion(op->getParentRegion()) || isa(op)) - return true; - // Affine apply operation is ok if all of its operands are ok. - if (auto applyOp = dyn_cast(op)) - return applyOp.isValidSymbol(); - if (auto dimOp = dyn_cast(op)) { - return isDimOpValidSymbol(dimOp); - } + // Check that the value is a top level value. + if (isTopLevelValue(value)) + return true; + + if (auto *op = value.getDefiningOp()) + return isValidSymbol(value, getAffineScope(op)); + + return false; +} + +// Value can be used as a symbol in `region` if it is a constant, or it is +// defined at the top level of 'region' or dominates 'region's parent op, or it +// is the result of an affine apply operation with symbol arguments, or a result +// of the dim op on a memref whose corresponding size is a valid symbol. +bool mlir::isValidSymbol(Value value, Region *region) { + // The value must be an index type. + if (!value.getType().isIndex()) + return false; + + // A top-level value is a valid symbol. + if (::isTopLevelValue(value, region)) + return true; + + auto *defOp = value.getDefiningOp(); + if (!defOp) { + // A block argument that is not a top-level value is a valid symbol if it + // dominates region's parent op. + if (!region->getParentOp()->isKnownIsolatedFromAbove()) + if (auto *parentOpRegion = region->getParentOp()->getParentRegion()) + return isValidSymbol(value, parentOpRegion); + return false; } - // Otherwise, check that the value is a top level value. - return isTopLevelValue(value); + + // Constant operation is ok. + if (isa(defOp)) + return true; + + // Affine apply operation is ok if all of its operands are ok. + if (auto applyOp = dyn_cast(defOp)) + return applyOp.isValidSymbol(region); + + // Dim op results could be valid symbols at any level. + if (auto dimOp = dyn_cast(defOp)) + return isDimOpValidSymbol(dimOp, region); + + // Check for values dominating `region`'s parent op. + if (!region->getParentOp()->isKnownIsolatedFromAbove()) + if (auto *parentRegion = region->getParentOp()->getParentRegion()) + return isValidSymbol(value, parentRegion); + + return false; } // Returns true if 'value' is a valid index to an affine operation (e.g. -// affine.load, affine.store, affine.dma_start, affine.dma_wait). -// Returns false otherwise. -static bool isValidAffineIndexOperand(Value value) { - return isValidDim(value) || isValidSymbol(value); +// affine.load, affine.store, affine.dma_start, affine.dma_wait) inside +// 'region'. Returns false otherwise. +static bool isValidAffineIndexOperand(Value value, Region *region) { + return isValidDim(value, region) || isValidSymbol(value, region); } /// Utility function to verify that a set of operands are valid dimension and @@ -273,6 +361,13 @@ [](Value op) { return mlir::isValidDim(op); }); } +// The result of the affine apply operation can be used as a dimension id if all +// its operands are valid dimension ids. +bool AffineApplyOp::isValidDim(Region *region) { + return llvm::all_of(getOperands(), + [&](Value op) { return mlir::isValidDim(op, region); }); +} + // The result of the affine apply operation can be used as a symbol if all its // operands are symbols. bool AffineApplyOp::isValidSymbol() { @@ -280,6 +375,14 @@ [](Value op) { return mlir::isValidSymbol(op); }); } +// The result of the affine apply operation can be used as a symbol if all its +// operands are symbols. +bool AffineApplyOp::isValidSymbol(Region *region) { + return llvm::all_of(getOperands(), [&](Value operand) { + return mlir::isValidSymbol(operand, region); + }); +} + OpFoldResult AffineApplyOp::fold(ArrayRef operands) { auto map = getAffineMap(); @@ -948,22 +1051,23 @@ return emitOpError("incorrect number of operands"); } + Region *scope = getAffineScope(*this); for (auto idx : getSrcIndices()) { if (!idx.getType().isIndex()) return emitOpError("src index to dma_start must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) + if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("src index must be a dimension or symbol identifier"); } for (auto idx : getDstIndices()) { if (!idx.getType().isIndex()) return emitOpError("dst index to dma_start must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) + if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("dst index must be a dimension or symbol identifier"); } for (auto idx : getTagIndices()) { if (!idx.getType().isIndex()) return emitOpError("tag index to dma_start must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) + if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("tag index must be a dimension or symbol identifier"); } return success(); @@ -1039,7 +1143,8 @@ for (auto idx : getTagIndices()) { if (!idx.getType().isIndex()) return emitOpError("index to dma_wait must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) + Region *scope = getAffineScope(*this); + if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("index must be a dimension or symbol identifier"); } return success(); @@ -1822,7 +1927,8 @@ for (auto idx : getMapOperands()) { if (!idx.getType().isIndex()) return emitOpError("index to load must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) + Region *scope = getAffineScope(*this); + if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("index must be a dimension or symbol identifier"); } return success(); @@ -1920,7 +2026,8 @@ for (auto idx : getMapOperands()) { if (!idx.getType().isIndex()) return emitOpError("index to store must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) + Region *scope = getAffineScope(*this); + if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("index must be a dimension or symbol identifier"); } return success(); @@ -2140,7 +2247,8 @@ } for (auto idx : op.getMapOperands()) { - if (!isValidAffineIndexOperand(idx)) + Region *scope = getAffineScope(op); + if (!isValidAffineIndexOperand(idx, scope)) return op.emitOpError("index must be a dimension or symbol identifier"); } return success(); @@ -2158,6 +2266,145 @@ return foldMemRefCast(*this); } +//===----------------------------------------------------------------------===// +// AffineExecuteRegionOp +//===----------------------------------------------------------------------===// +// + +// TODO: missing region body. +void AffineExecuteRegionOp::build(Builder *builder, OperationState &result, + ValueRange memrefs) { + // Create a region and an empty entry block. The arguments of the region are + // the supplied memrefs. + Region *region = result.addRegion(); + Block *body = new Block(); + region->push_back(body); + body->addArguments(memrefs.getTypes()); + + // Set the operands list as resizable so that we can add memrefs. + result.setOperandListToResizable(); +} + +static LogicalResult verify(AffineExecuteRegionOp op) { + // All memref uses in the execute_region region should be explicitly captured. + // FIXME: change this walk to an affine walk that doesn't walk inner + // execute_regions. + DenseSet memrefsUsed; + op.region().walk([&](Operation *innerOp) { + for (auto v : innerOp->getOperands()) + if (v.getType().isa()) + memrefsUsed.insert(v); + }); + + // For each memref use, ensure either an execute_region argument or a local + // def. + for (auto memref : memrefsUsed) { + if (auto arg = memref.dyn_cast()) + if (arg.getOwner()->getParent()->getParentOp() == op) + continue; + if (auto *defOp = memref.getDefiningOp()) + // FIXME: this will only work if the memrefs collected above didn't + // include any from inner execute_regions. + if (defOp->getParentOfType() == op) + continue; + return op.emitOpError("incoming memref not explicitly captured"); + } + + // Verify that the region arguments match operands. + auto &entryBlock = op.region().front(); + if (entryBlock.getNumArguments() != op.getNumOperands()) + return op.emitOpError("region argument count does not match operand count"); + + for (auto argEn : llvm::enumerate(entryBlock.getArguments())) { + if (op.getOperand(argEn.index()).getType() != argEn.value().getType()) + return op.emitOpError("region argument ") + << argEn.index() << " does not match corresponding operand"; + } + + return success(); +} + +// Custom form syntax. +// +// (ssa-id `=`)? `affine.execute_region` (`[` memref-region-arg-list `]` +// `=` `(` memref-use-list `)`)? +// `:` memref-type-list-parens `->` function-result-type `{` +// block+ +// `}` +// +// Ex: +// +// affine.execute_region [%rI, %rM] = (%I, %M) +// : (memref<128xi32>, memref<1024xf32>) -> () { +// %idx = affine.load %rI[%i] : memref<128xi32> +// %index = index_cast %idx : i32 to index +// affine.load %rM[%index]: memref<1024xf32> +// return +// } +// +static ParseResult parseAffineExecuteRegionOp(OpAsmParser &parser, + OperationState &result) { + // Memref operands. + SmallVector memrefs; + + // Region arguments to be created. + SmallVector regionMemRefs; + + // The execute_region op has the same type signature as a function. + FunctionType opType; + + // Parse the memref assignments. + auto argLoc = parser.getCurrentLocation(); + if (parser.parseRegionArgumentList(regionMemRefs, + OpAsmParser::Delimiter::Square) || + parser.parseEqual() || + parser.parseOperandList(memrefs, OpAsmParser::Delimiter::Paren)) + return failure(); + + if (memrefs.size() != regionMemRefs.size()) + return parser.emitError(parser.getNameLoc(), + "incorrect number of memref captures"); + + if (parser.parseColonType(opType) || + parser.addTypesToList(opType.getResults(), result.types)) + return failure(); + + auto memrefTypes = opType.getInputs(); + if (parser.resolveOperands(memrefs, memrefTypes, argLoc, result.operands)) + return failure(); + + // Introduce and parse body region, and the optional attribute list. + Region *body = result.addRegion(); + if (parser.parseRegion(*body, regionMemRefs, memrefTypes) || + parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + // Set the operands list as resizable so that we can modify operands. + result.setOperandListToResizable(); + return success(); +} + +static void print(OpAsmPrinter &p, AffineExecuteRegionOp op) { + p << AffineExecuteRegionOp::getOperationName() << " ["; + // TODO: consider shadowing region arguments. + p.printOperands(op.region().front().getArguments()); + p << "] = ("; + auto operands = op.getOperands(); + p.printOperands(operands); + p << ") "; + + SmallVector argTypes(op.getOperandTypes()); + p << " : " + << FunctionType::get(argTypes, op.getResultTypes(), + op.getOperation()->getContext()); + + p.printRegion(op.region(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + + p.printOptionalAttrDict(op.getAttrs()); +} + //===----------------------------------------------------------------------===// // AffineParallelOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Affine/execute-region.mlir b/mlir/test/Dialect/Affine/execute-region.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Affine/execute-region.mlir @@ -0,0 +1,119 @@ +// RUN: mlir-opt %s | FileCheck %s + +// CHECK-LABEL: @arbitrary_bound +func @arbitrary_bound(%n : index) { + affine.for %i = 0 to %n { + affine.execute_region [] = () : () -> () { + // %pow can now be used as a loop bound. + %pow = call @powi(%i) : (index) -> index + affine.for %j = 0 to %pow { + "test.foo"() : () -> () + } + return + } + // CHECK: affine.execute_region [] = () : () -> () { + // CHECK-NEXT: call @powi + // CHECK-NEXT: affine.for + // CHECK-NEXT: "test.foo"() + // CHECK-NEXT: } + // CHECK-NEXT: return + // CHECK-NEXT: } + } + return +} + +func @powi(index) -> index + +// CHECK-LABEL: func @arbitrary_mem_access +func @arbitrary_mem_access(%I: memref<128xi32>, %M: memref<1024xf32>) { + affine.for %i = 0 to 128 { + // CHECK: %{{.*}} = affine.execute_region [{{.*}}] = ({{.*}}) : (memref<128xi32>, memref<1024xf32>) -> f32 + %ret = affine.execute_region [%rI, %rM] = (%I, %M) : (memref<128xi32>, memref<1024xf32>) -> f32 { + %idx = affine.load %rI[%i] : memref<128xi32> + %index = index_cast %idx : i32 to index + %v = affine.load %rM[%index]: memref<1024xf32> + return %v : f32 + } + } + return +} + +// CHECK-LABEL: @symbol_check +func @symbol_check(%B: memref<100xi32>, %A: memref<100xf32>) { + %cf1 = constant 1.0 : f32 + affine.for %i = 0 to 100 { + %v = affine.load %B[%i] : memref<100xi32> + %vo = index_cast %v : i32 to index + // CHECK: affine.execute_region [%{{.*}}] = (%{{.*}}) : (memref<100xf32>) -> () { + affine.execute_region [%rA] = (%A) : (memref<100xf32>) -> () { + // %vi is now a symbol here. + %vi = index_cast %v : i32 to index + affine.load %rA[%vi] : memref<100xf32> + // %vo is also a symbol (dominates the execute_region). + affine.load %rA[%vo] : memref<100xf32> + return + } + // CHECK: index_cast + // CHECK-NEXT: affine.load + // CHECK-NEXT: affine.load + // CHECK-NEXT: return + // CHECK-NEXT: } + } + return +} + +// CHECK-LABEL: func @test_more_symbol_validity +func @test_more_symbol_validity(%A: memref<100xf32>, %pos : index) { + %c5 = constant 5 : index + affine.for %i = 0 to 100 { + %sym = call @external() : () -> (index) + affine.execute_region [%rA] = (%A) : (memref<100xf32>) -> () { + affine.load %rA[symbol(%pos) + symbol(%sym) + %c5] : memref<100xf32> + return + } + } + affine.execute_region [%rA] = (%A) : (memref<100xf32>) -> () { + affine.load %rA[symbol(%pos) + %c5] : memref<100xf32> + return + } + return +} + +func @external() -> (index) + +// CHECK-LABEL: func @search +func @search(%A : memref, %S : memref, %key : i32) { + %ni = dim %A, 0 : memref + %c1 = constant 1 : index + // This loop can be parallelized. + affine.for %i = 0 to %ni { + // CHECK: affine.execute_region + affine.execute_region [%rA, %rS] = (%A, %S) : (memref, memref) -> () { + %c0 = constant 0 : index + %nj = dim %rA, 1 : memref + br ^bb1(%c0 : index) + + ^bb1(%j: index): + %p1 = cmpi "slt", %j, %nj : index + cond_br %p1, ^bb2(%j : index), ^bb5 + + ^bb2(%j_arg : index): + %v = affine.load %rA[%i, %j_arg] : memref + %p2 = cmpi "eq", %v, %key : i32 + cond_br %p2, ^bb3(%j_arg : index), ^bb4(%j_arg : index) + + ^bb3(%j_arg2: index): + %j_int = index_cast %j_arg2 : index to i32 + affine.store %j_int, %rS[%i] : memref + br ^bb5 + + ^bb4(%j_arg3 : index): + %jinc = addi %j_arg3, %c1 : index + br ^bb1(%jinc : index) + + ^bb5: + return + } + } + return +} diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir --- a/mlir/test/Dialect/Affine/invalid.mlir +++ b/mlir/test/Dialect/Affine/invalid.mlir @@ -124,7 +124,7 @@ %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref %dim = dim %0, 0 : memref - // expected-error@+1 {{operand cannot be used as a dimension id}} + // expected-error@+1 {{operand cannot be used as a symbol}} affine.if #set0(%dim)[%n0] {} } return @@ -171,6 +171,62 @@ // ----- +// CHECK-LABEL: @affine.execute_region_missing_capture +func @affine.execute_region_missing_capture(%M : memref<2xi32>) { + affine.for %i = 0 to 10 { + affine.execute_region [] = () : () -> () { + // expected-error@-1 {{incoming memref not explicitly captured}} + affine.load %M[%i] : memref<2xi32> + } + } + return +} + +// ----- + +// CHECK-LABEL: @affine.execute_region_wrong_capture +func @affine.execute_region_wrong_capture(%s : index) { + affine.execute_region [%rS] = (%s) : (index) -> () { + // expected-error@-1 {{operand #0 must be memref}} + "use"(%s) : (index) -> () + } +} + +// ----- + +// CHECK-LABEL: @affine.execute_region_wrong_capture +func @affine.execute_region_wrong_capture(%A : memref<2xi32>) { + affine.execute_region [] = (%A) : (memref<2xi32>) -> () { + // expected-error@-1 {{incorrect number of memref captures}} + } + return +} + +// ----- + +// CHECK-LABEL: @affine.execute_region_region_type_mismatch +func @affine.execute_region_region_type_mismatch(%A : memref<2xi32>) { + "affine.execute_region"(%A) ({ + // expected-error@-1 {{region argument 0 does not match corresponding operand}} + ^bb0(%rA : memref<4xi32>): + return + }) : (memref<2xi32>) -> () +} + +// ----- + +// CHECK-LABEL: @affine.execute_region_region_arg_count_mismatch +func @affine.execute_region_region_arg_count_mismatch(%A : memref<2xi32>) { + "affine.execute_region"(%A) ({ + // expected-error@-1 {{region argument count does not match operand count}} + ^bb0: + return + }) : (memref<2xi32>) -> () + return +} + +// ----- + // CHECK-LABEL: @affine_max func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) { // expected-error@+1 {{operand count and affine map dimension and symbol count must match}}