Index: mlir/docs/Dialects/Affine.md =================================================================== --- mlir/docs/Dialects/Affine.md +++ mlir/docs/Dialects/Affine.md @@ -60,20 +60,25 @@ ### Restrictions on Dimensions and Symbols The affine dialect imposes certain restrictions on dimension and symbolic -identifiers to enable powerful analysis and transformation. A symbolic -identifier can be bound to an SSA value that is either an argument to the -function, a value defined at the top level of that function (outside of all -loops and if operations), the result of a -[`constant` operation](Standard.md#constant-operation), or the result of an -[`affine.apply` operation](#affineapply-operation) that recursively takes as -arguments any symbolic identifiers, or the result of a [`dim` +identifiers to enable powerful analysis and transformation. An SSA value is a +valid symbol if it is either (1) a region argument for an op that is either +"isolated from above" (like the FuncOp) or is an affine graybox op, (2) a value +defined at the top level (outside of all loops, if operations, or other +operations with regions) of an affine graybox op or an op "isolated from above", +(3) a value that dominates the closest enclosing affine graybox or an op +"isolated from above", (4) the result of a [`constant` +operation](Standard.md#constant-operation), (5) the result of an [`affine.apply` +operation](#affineapply-operation) that recursively takes as arguments any +symbolic identifiers, or (6) the result of a [`dim` operation](Standard.md#dim-operation) on either a memref that is a function argument or a memref where the corresponding dimension is either static or a -dynamic one in turn bound to a symbolic identifier. 
Dimensions may be bound not -only to anything that a symbol is bound to, but also to induction variables of -enclosing [`affine.for` operations](#affinefor-operation), and the result of an -[`affine.apply` operation](#affineapply-operation) (which recursively may use -other dimensions and symbols). +dynamic one in turn bound to a symbolic identifier. Note that as a result of +(3), symbol validity is sensitive to the location at which the value binds to +the symbol. Dimensions may be bound not only to anything that a symbol is bound +to, but also to induction variables of enclosing [`affine.for` +operations](#affinefor-operation), and the result of an [`affine.apply` +operation](#affineapply-operation) (which recursively may use other dimensions +and symbols). ### Affine Expressions Index: mlir/include/mlir/Dialect/AffineOps/AffineOps.h =================================================================== --- mlir/include/mlir/Dialect/AffineOps/AffineOps.h +++ mlir/include/mlir/Dialect/AffineOps/AffineOps.h @@ -29,9 +29,9 @@ class FlatAffineConstraints; class OpBuilder; -/// A utility function to check if a value is defined at the top level of a -/// function. A value of index type defined at the top level is always a valid -/// symbol. +/// A utility function to check if a value is defined at the top level of an +/// op isolated from above or an affine graybox. A value of index type defined +/// at the top level is always a valid symbol. bool isTopLevelValue(Value value); class AffineOpsDialect : public Dialect { @@ -73,12 +73,22 @@ return getAttrOfType("map").getValue(); } - /// Returns true if the result of this operation can be used as dimension id. + /// Returns true if the result of this operation can be used as dimension id + /// in its immediately surrounding affine scope. bool isValidDim(); - /// Returns true if the result of this operation is a symbol. 
+ /// Returns true if the result of this operation can be used as dimension id + /// within the region of the op 'opWithRegion'. + bool isValidDim(Operation *opWithRegion); + + /// Returns true if the result of this operation is a symbol in its + /// immediately surrounding affine scope. bool isValidSymbol(); + /// Returns true if the result of this operation is a symbol in the region of + /// 'opWithRegion'. + bool isValidSymbol(Operation *opWithRegion); + static StringRef getOperationName() { return "affine.apply"; } operand_range getMapOperands() { return getOperands(); } @@ -514,12 +524,24 @@ SmallVectorImpl &results); }; -/// Returns true if the given Value can be used as a dimension id. +/// Returns true if the given Value can be used as a dimension id in the closest +/// op that is isolated from above or an affine graybox enclosing this value's +/// definition or block argument appearance. bool isValidDim(Value value); -/// Returns true if the given Value can be used as a symbol. +/// Returns true if the given Value can be used as a dimension id for an op with +/// a region. +bool isValidDim(Value value, Operation *opWithRegion); + +/// Returns true if the given value can be used as a symbol in the closest +/// op that is isolated from above or an affine graybox op enclosing this +/// value's definition or block argument appearance. bool isValidSymbol(Value value); +/// Returns true if the given Value can be used as a symbol for an +/// op with a region. +bool isValidSymbol(Value value, Operation *opWithRegion); + /// Modifies both `map` and `operands` in-place so as to: /// 1. drop duplicate operands /// 2. 
drop unused dims and symbols from map Index: mlir/include/mlir/Dialect/AffineOps/AffineOps.td =================================================================== --- mlir/include/mlir/Dialect/AffineOps/AffineOps.td +++ mlir/include/mlir/Dialect/AffineOps/AffineOps.td @@ -347,4 +347,53 @@ let verifier = ?; } +def AffineGrayBoxOp : Affine_Op<"graybox">, + Arguments<(ins Variadic:$operands)> { + let summary = "graybox operation"; + let description = [{ + The affine graybox op introduces a new symbol context for affine + operations. It holds a single region, which can be a list of one or more + blocks. The op's region can have zero or more arguments, each of which can + only be a memref. The operands bind 1:1 to its region's arguments. The op + can't use any memrefs defined outside of it, but can use any other SSA + values that dominate it. Its region's blocks can have terminators the same + way as current MLIR functions (FuncOp) can. Control from any return ops + from the top level of its region returns to right after the affine.graybox + op. Its control flow thus conforms to the control flow semantics of + regions, i.e., control always returns to the immediate enclosing (parent) + op. The results of a graybox op match 1:1 with the return values from its + region's blocks; + + Ex: + + affine.for %i = 0 to 128 { + affine.graybox [%rI, %rM] = (%I, %M) : memref<128xi32>, memref<24xf32> { + %idx = affine.load %rI[%i] : memref<128xi32> + %index = index_cast %idx : i32 to index + affine.load %rM[%index]: memref<24xf32> + return + } + } + + affine.for %i = 0 to %n { + affine.graybox [] = () { + // %pow can now be used as a loop bound. + %pow = call @powi(%i) : (index) -> index + affine.for %j = 0 to %pow { + "foo"() : () -> () + } + return + } + } + + }]; + + let regions = (region AnyRegion:$region); + + // TODO: builders. + + // TODO: canonicalizations related to memrefs. 
+ let hasCanonicalizer = 0; +} + #endif // AFFINE_OPS Index: mlir/lib/Dialect/AffineOps/AffineOps.cpp =================================================================== --- mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -98,18 +98,47 @@ return builder.create(loc, type, value); } -/// A utility function to check if a given region is attached to a function. -static bool isFunctionRegion(Region *region) { - return llvm::isa(region->getParentOp()); +/// A utility function to check if a given region is attached to an op isolated +/// from above or an affine graybox. +static bool isIsolatedOrGrayBoxRegion(Region *region) { + return region->getParentOp()->isKnownIsolatedFromAbove() || + isa(region->getParentOp()); } -/// A utility function to check if a value is defined at the top level of a -/// function. A value of index type defined at the top level is always a valid -/// symbol. +/// A utility function to check if a value is defined at the top level of an +/// op isolated from above or an affine graybox. A value of index type defined +/// at the top level is always a valid symbol. bool mlir::isTopLevelValue(Value value) { if (auto arg = value.dyn_cast()) - return isFunctionRegion(arg->getOwner()->getParent()); - return isFunctionRegion(value->getDefiningOp()->getParentRegion()); + return isIsolatedOrGrayBoxRegion(arg->getOwner()->getParent()); + return isIsolatedOrGrayBoxRegion(value->getDefiningOp()->getParentRegion()); +} + +/// A utility function to check if a value is defined at the top level of +/// 'opWithRegion'. A value of index type defined at the top level is always a +/// valid symbol. 
+static bool isTopLevelValue(Value value, Operation *opWithRegion) { + assert(opWithRegion->getNumRegions() > 0 && + "only to be called on ops with regions"); + if (auto arg = value.dyn_cast()) + return arg->getOwner()->getParentOp() == opWithRegion; + return value->getDefiningOp()->getParentOp() == opWithRegion; +} + +/// Returns the closest op surrounding 'op' that is either an AffineGrayBoxOp or +/// an op isolated from above (e.g., FuncOp). Asserts if called on a top-level +/// op. +// TODO: getAffineScope should be publicly exposed for affine +// passes/utilities. +static Operation *getAffineScope(Operation *op) { + // TODO: make this compact by introducing a variadic pack on getParentOfType. + auto *curOp = op; + while ((curOp = curOp->getParentOp())) + if (llvm::isa(curOp) || curOp->isKnownIsolatedFromAbove()) + return curOp; + + assert(false && "op doesn't have a parent op"); + return nullptr; } // Value can be used as a dimension id if it is valid as a symbol, or @@ -120,40 +149,65 @@ if (!value->getType().isIndex()) return false; - if (auto *op = value->getDefiningOp()) { - // Top level operation or constant operation is ok. - if (isFunctionRegion(op->getParentRegion()) || isa(op)) - return true; - // Affine apply operation is ok if all of its operands are ok. - if (auto applyOp = dyn_cast(op)) - return applyOp.isValidDim(); - // The dim op is okay if its operand memref/tensor is defined at the top - // level. - if (auto dimOp = dyn_cast(op)) - return isTopLevelValue(dimOp.getOperand()); + if (auto *op = value->getDefiningOp()) + return isValidDim(value, getAffineScope(op)); + + // This value has to be a block argument for an op isolated from above or an + // affine.for. (A graybox can't have index type arguments.) 
+ auto *parentOp = value.cast()->getOwner()->getParentOp(); + return parentOp->isKnownIsolatedFromAbove() || isa(parentOp); +} + +// Value can be used as a dimension id if it is valid as a symbol, or it is an +// induction variable, or it is a result of affine apply operation with +// dimension id arguments. +bool mlir::isValidDim(Value value, Operation *opWithRegion) { + assert(opWithRegion->getNumRegions() > 0 && + "only to be called on ops with regions"); + // The value must be an index type. + if (!value->getType().isIndex()) return false; + + auto *op = value->getDefiningOp(); + if (!op) { + // This value has to be a block argument for a FuncOp or an affine.for. + auto *parentOp = value.cast()->getOwner()->getParentOp(); + return parentOp->isKnownIsolatedFromAbove() || isa(parentOp); } - // This value has to be a block argument for a FuncOp or an affine.for. - auto *parentOp = value.cast()->getOwner()->getParentOp(); - return isa(parentOp) || isa(parentOp); + + // Top level operation or constant operation is ok. + if (::isTopLevelValue(value, opWithRegion) || isa(op)) + return true; + // Affine apply operation is ok if all of its operands are ok. + if (auto applyOp = dyn_cast(op)) + return applyOp.isValidDim(opWithRegion); + // The dim op is okay if its operand memref/tensor is defined at the top + // level. + if (auto dimOp = dyn_cast(op)) + return isTopLevelValue(dimOp.getOperand()); + return false; } /// Returns true if the 'index' dimension of the `memref` defined by -/// `memrefDefOp` is a statically shaped one or defined using a valid symbol. +/// `memrefDefOp` is a statically shaped one or defined using a valid symbol +/// for 'op'. template -bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index) { +bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, + Operation *op) { + assert(op->getNumRegions() > 0 && "only to be called on ops with regions"); auto memRefType = memrefDefOp.getType(); // Statically shaped. 
if (!ShapedType::isDynamic(memRefType.getDimSize(index))) return true; // Get the position of the dimension among dynamic dimensions; unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index); - return isValidSymbol( - *(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos)); + return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos), + op); } /// Returns true if the result of the dim op is a valid symbol. -static bool isDimOpValidSymbol(DimOp dimOp) { +static bool isDimOpValidSymbol(DimOp dimOp, Operation *op) { + assert(op->getNumRegions() > 0 && "only to be called on ops with regions"); // The dim op is okay if its operand memref/tensor is defined at the top // level. if (isTopLevelValue(dimOp.getOperand())) @@ -163,43 +217,81 @@ // whose corresponding size is a valid symbol. unsigned index = dimOp.getIndex(); if (auto viewOp = dyn_cast(dimOp.getOperand()->getDefiningOp())) - return isMemRefSizeValidSymbol(viewOp, index); + return isMemRefSizeValidSymbol(viewOp, index, op); if (auto subViewOp = dyn_cast(dimOp.getOperand()->getDefiningOp())) - return isMemRefSizeValidSymbol(subViewOp, index); + return isMemRefSizeValidSymbol(subViewOp, index, op); if (auto allocOp = dyn_cast(dimOp.getOperand()->getDefiningOp())) - return isMemRefSizeValidSymbol(allocOp, index); + return isMemRefSizeValidSymbol(allocOp, index, op); return false; } // Value can be used as a symbol if it is a constant, or it is defined at -// the top level, or it is a result of affine apply operation with symbol -// arguments, or a result of the dim op on a memref satisfying certain -// constraints. +// the top level of the enclosing affine scope (graybox or func op) or dominates +// such a scope, or it is a result of affine apply operation with symbol +// arguments, or a result of the dim op on a memref whose corresponding size is +// a valid symbol. bool mlir::isValidSymbol(Value value) { // The value must be an index type. 
if (!value->getType().isIndex()) return false; - if (auto *op = value->getDefiningOp()) { - // Top level operation or constant operation is ok. - if (isFunctionRegion(op->getParentRegion()) || isa(op)) - return true; - // Affine apply operation is ok if all of its operands are ok. - if (auto applyOp = dyn_cast(op)) - return applyOp.isValidSymbol(); - if (auto dimOp = dyn_cast(op)) { - return isDimOpValidSymbol(dimOp); - } - } - // Otherwise, check that the value is a top level value. - return isTopLevelValue(value); + // Check that the value is a top level value. + if (isTopLevelValue(value)) + return true; + + if (auto *op = value->getDefiningOp()) + return isValidSymbol(value, getAffineScope(op)); + + return false; +} + +// Value can be used as a symbol in the region of an op 'opWithRegion' if it is +// a constant, or it is defined at the top level of 'opWithRegion' or dominates +// 'opWithRegion', or it is the result of an affine apply operation with symbol +// arguments, or a result of the dim op on a memref whose corresponding size is +// a valid symbol. +bool mlir::isValidSymbol(Value value, Operation *opWithRegion) { + assert(opWithRegion->getNumRegions() > 0 && + "only to be called on ops with regions"); + + // The value must be an index type. + if (!value->getType().isIndex()) + return false; + + // A top-level value is a valid symbol. + if (::isTopLevelValue(value, opWithRegion)) + return true; + + auto *defOp = value->getDefiningOp(); + if (!defOp) + // A block argument that is not a top-level value isn't a valid symbol. + return false; + + // Constant operation is ok. + if (isa(defOp)) + return true; + + // Affine apply operation is ok if all of its operands are ok. + if (auto applyOp = dyn_cast(defOp)) + return applyOp.isValidSymbol(opWithRegion); + + // Dim op results could be valid symbols at any level. + if (auto dimOp = dyn_cast(defOp)) + return isDimOpValidSymbol(dimOp, opWithRegion); + + // Check for values dominating 'opWithRegion'. 
+ if (auto *parentOp = opWithRegion->getParentOp()) + if (!parentOp->isKnownIsolatedFromAbove()) + return isValidSymbol(value, parentOp); + + return false; } // Returns true if 'value' is a valid index to an affine operation (e.g. -// affine.load, affine.store, affine.dma_start, affine.dma_wait). -// Returns false otherwise. -static bool isValidAffineIndexOperand(Value value) { - return isValidDim(value) || isValidSymbol(value); +// affine.load, affine.store, affine.dma_start, affine.dma_wait) inside the +// region of 'op'. Returns false otherwise. +static bool isValidAffineIndexOperand(Value value, Operation *op) { + return isValidDim(value, op) || isValidSymbol(value, op); } /// Utility function to verify that a set of operands are valid dimension and @@ -300,6 +392,14 @@ [](Value op) { return mlir::isValidDim(op); }); } +// The result of the affine apply operation can be used as a dimension id if all +// its operands are valid dimension ids. +bool AffineApplyOp::isValidDim(Operation *opWithRegion) { + return llvm::all_of(getOperands(), [&](Value op) { + return mlir::isValidDim(op, opWithRegion); + }); +} + // The result of the affine apply operation can be used as a symbol if all its // operands are symbols. bool AffineApplyOp::isValidSymbol() { @@ -307,6 +407,14 @@ [](Value op) { return mlir::isValidSymbol(op); }); } +// The result of the affine apply operation can be used as a symbol if all its +// operands are symbols. 
+bool AffineApplyOp::isValidSymbol(Operation *opWithRegion) { + return llvm::all_of(getOperands(), [&](Value operand) { + return mlir::isValidSymbol(operand, opWithRegion); + }); +} + OpFoldResult AffineApplyOp::fold(ArrayRef operands) { auto map = getAffineMap(); @@ -970,22 +1078,23 @@ return emitOpError("incorrect number of operands"); } + auto *scope = getAffineScope(*this); for (auto idx : getSrcIndices()) { if (!idx->getType().isIndex()) return emitOpError("src index to dma_start must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) + if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("src index must be a dimension or symbol identifier"); } for (auto idx : getDstIndices()) { if (!idx->getType().isIndex()) return emitOpError("dst index to dma_start must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) + if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("dst index must be a dimension or symbol identifier"); } for (auto idx : getTagIndices()) { if (!idx->getType().isIndex()) return emitOpError("tag index to dma_start must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) + if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("tag index must be a dimension or symbol identifier"); } return success(); @@ -1061,7 +1170,8 @@ for (auto idx : getTagIndices()) { if (!idx->getType().isIndex()) return emitOpError("index to dma_wait must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) + auto *scope = getAffineScope(*this); + if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("index must be a dimension or symbol identifier"); } return success(); @@ -1818,7 +1928,9 @@ for (auto idx : getMapOperands()) { if (!idx->getType().isIndex()) return emitOpError("index to load must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) + + auto *scope = getAffineScope(*this); + if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("index must be a dimension or symbol identifier"); } 
return success(); @@ -1916,7 +2028,8 @@ for (auto idx : getMapOperands()) { if (!idx->getType().isIndex()) return emitOpError("index to store must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) + auto *scope = getAffineScope(*this); + if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("index must be a dimension or symbol identifier"); } return success(); @@ -2090,7 +2203,8 @@ } for (auto idx : op.getMapOperands()) { - if (!isValidAffineIndexOperand(idx)) + auto *scope = getAffineScope(op); + if (!isValidAffineIndexOperand(idx, scope)) return op.emitOpError("index must be a dimension or symbol identifier"); } return success(); @@ -2108,6 +2222,118 @@ return foldMemRefCast(*this); } +//===----------------------------------------------------------------------===// +// AffineGrayBoxOp +//===----------------------------------------------------------------------===// +// + +static LogicalResult verify(AffineGrayBoxOp op) { + // All memref uses in the graybox region should be explicitly captured. + // FIXME: change this walk to an affine walk that doesn't walk inner + // grayboxes. + DenseSet memrefsUsed; + op.region().walk([&](Operation *innerOp) { + for (auto v : innerOp->getOperands()) + if (v->getType().isa()) + memrefsUsed.insert(v); + }); + + // For each memref use, ensure either a graybox argument or locally defined. + for (auto memref : memrefsUsed) { + if (auto arg = memref.dyn_cast()) + if (arg->getOwner()->getParent()->getParentOp() == op) + continue; + if (auto *defOp = memref->getDefiningOp()) + // FIXME: this will only work if the memrefs collected above didn't + // include any from inner grayboxes. + if (defOp->getParentOfType() == op) + continue; + return op.emitOpError("incoming memref not explicitly captured"); + } + return success(); +} + +// Custom form syntax. +// +// (ssa-id `=`)? 
`affine.graybox` `[` memref-region-arg-list `]` +// `=` `(` memref-use-list `)` +// `:` memref-type-list-parens `->` function-result-type `{` +// block+ +// `}` +// +// Ex: +// +// affine.graybox [%rI, %rM] = (%I, %M) : memref<128xi32>, memref<1024xf32> { +// %idx = affine.load %rI[%i] : memref<128xi32> +// %index = index_cast %idx : i32 to index +// affine.load %rM[%index]: memref<1024xf32> +// return +// } +// +static ParseResult parseAffineGrayBoxOp(OpAsmParser &parser, + OperationState &result) { + // Memref operands captured by the op. + SmallVector memrefs; + + // Region arguments to be created. + SmallVector regionMemRefs; + + auto argLoc = parser.getCurrentLocation(); + + // Parse the memref assignments. + if (parser.parseRegionArgumentList(regionMemRefs, + OpAsmParser::Delimiter::Square) || + parser.parseEqual() || + parser.parseOperandList(memrefs, OpAsmParser::Delimiter::Paren)) + return failure(); + + if (memrefs.size() != regionMemRefs.size()) + return parser.emitError(parser.getNameLoc(), + "incorrect number of memref captures"); + + SmallVector memrefTypes; + if (parser.parseOptionalColonTypeList(memrefTypes)) + return failure(); + + if (parser.resolveOperands(memrefs, memrefTypes, argLoc, result.operands)) + return failure(); + + // Introduce the body region and parse it. The region's entry block + // arguments are the memref region arguments parsed above; their types + // are the parsed memref types, matched 1:1 with the captured memref + // operands. + Region *body = result.addRegion(); + if (parser.parseRegion(*body, regionMemRefs, memrefTypes) || + parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + // Parse the optional attribute list. + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + // Set the operands list as resizable so that we can modify operands. 
+ result.setOperandListToResizable(); + return success(); +} + +static void print(OpAsmPrinter &p, AffineGrayBoxOp op) { + p << AffineGrayBoxOp::getOperationName() << " ["; + // TODO: consider shadowing region arguments. + p.printOperands(op.region().front().getArguments()); + p << "] = ("; + auto operands = op.getOperands(); + p.printOperands(operands); + p << ") "; + if (!operands.empty()) + p << ": " << operands.getTypes(); + + p.printRegion(op.region(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + + p.printOptionalAttrDict(op.getAttrs()); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// Index: mlir/test/AffineOps/graybox.mlir =================================================================== --- /dev/null +++ mlir/test/AffineOps/graybox.mlir @@ -0,0 +1,100 @@ +// RUN: mlir-opt %s | FileCheck %s + +// CHECK-LABEL: @arbitrary_bound +func @arbitrary_bound(%n : index) { + affine.for %i = 0 to %n { + affine.graybox [] = () { + // %pow can now be used as a loop bound. 
+ %pow = call @powi(%i) : (index) -> index + affine.for %j = 0 to %pow { + "foo"() : () -> () + } + return + } + // CHECK: affine.graybox [] = () { + // CHECK-NEXT: call @powi + // CHECK-NEXT: affine.for + // CHECK-NEXT: "foo"() + // CHECK-NEXT: } + // CHECK-NEXT: return + // CHECK-NEXT: } + } + return +} + +func @powi(index) -> index + +// CHECK-LABEL: func @arbitrary_mem_access +func @arbitrary_mem_access(%I: memref<128xi32>, %M: memref<1024xf32>) { + affine.for %i = 0 to 128 { + // CHECK: affine.graybox [{{.*}}] = ({{.*}}) : memref<128xi32>, memref<1024xf32> + affine.graybox [%rI, %rM] = (%I, %M) : memref<128xi32>, memref<1024xf32> { + %idx = affine.load %rI[%i] : memref<128xi32> + %index = index_cast %idx : i32 to index + affine.load %rM[%index]: memref<1024xf32> + return + } + } + return +} + +// CHECK-LABEL: @symbol_check +func @symbol_check(%B: memref<100xi32>, %A: memref<100xf32>) { + %cf1 = constant 1.0 : f32 + affine.for %i = 0 to 100 { + %v = affine.load %B[%i] : memref<100xi32> + %vo = index_cast %v : i32 to index + // CHECK: affine.graybox [%{{.*}}] = (%{{.*}}) : memref<100xf32> { + affine.graybox [%rA] = (%A) : memref<100xf32> { + // %vi is now a symbol here. + %vi = index_cast %v : i32 to index + affine.load %rA[%vi] : memref<100xf32> + // %vo is also a symbol (dominates the graybox). + affine.load %rA[%vo] : memref<100xf32> + return + } + // CHECK: index_cast + // CHECK-NEXT: affine.load + // CHECK-NEXT: affine.load + // CHECK-NEXT: return + // CHECK-NEXT: } + } + return +} + +// CHECK-LABEL: func @search +func @search(%A : memref, %S : memref, %key : i32) { + %ni = dim %A, 0 : memref + %c1 = constant 1 : index + // This loop can be parallelized. 
+ affine.for %i = 0 to %ni { + // CHECK: affine.graybox + affine.graybox [%rA, %rS] = (%A, %S) : memref, memref { + %c0 = constant 0 : index + %nj = dim %rA, 1 : memref + br ^bb1(%c0 : index) + + ^bb1(%j: index): + %p1 = cmpi "slt", %j, %nj : index + cond_br %p1, ^bb2(%j : index), ^bb5 + + ^bb2(%j_arg : index): + %v = affine.load %rA[%i, %j_arg] : memref + %p2 = cmpi "eq", %v, %key : i32 + cond_br %p2, ^bb3(%j_arg : index), ^bb4(%j_arg : index) + + ^bb3(%j_arg2: index): + %j_int = index_cast %j_arg2 : index to i32 + affine.store %j_int, %rS[%i] : memref + br ^bb5 + + ^bb4(%j_arg3 : index): + %jinc = addi %j_arg3, %c1 : index + br ^bb1(%jinc : index) + + ^bb5: + return + } + } + return +} Index: mlir/test/AffineOps/ops.mlir =================================================================== --- mlir/test/AffineOps/ops.mlir +++ mlir/test/AffineOps/ops.mlir @@ -99,3 +99,21 @@ } return } + +// ----- + +// Test symbol restrictions with ops isolated from above. + +// CHECK-LABEL: func @valid_symbol_isolated_region +func @valid_symbol_isolated_region(%n : index) { + test.isolated_region %n { + %c1 = constant 1 : index + %l = subi %n, %c1 : index + // %l, %n are valid symbols since test.isolated_region is known to be + // isolated from above. + affine.for %i = %l to %n { + } + return + } + return +}