diff --git a/mlir/docs/Dialects/Affine.md b/mlir/docs/Dialects/Affine.md
--- a/mlir/docs/Dialects/Affine.md
+++ b/mlir/docs/Dialects/Affine.md
@@ -1,4 +1,4 @@
-# 'affine' Dialect
+# `affine` Dialect
 
 This dialect provides a powerful abstraction for affine operations and analyses.
 
@@ -60,20 +60,25 @@
 ### Restrictions on Dimensions and Symbols
 
 The affine dialect imposes certain restrictions on dimension and symbolic
-identifiers to enable powerful analysis and transformation. A symbolic
-identifier can be bound to an SSA value that is either an argument to the
-function, a value defined at the top level of that function (outside of all
-loops and if operations), the result of a
-[`constant` operation](Standard.md#constant-operation), or the result of an
-[`affine.apply` operation](#affineapply-operation) that recursively takes as
-arguments any symbolic identifiers, or the result of a [`dim`
+identifiers to enable powerful analysis and transformation. An SSA value is a
+valid symbol if it is either (1) a region argument for an op that is either
+"isolated from above" (like the FuncOp) or is an affine graybox op, (2) a value
+defined at the top level (outside of all loops, if operations, or other
+operations with regions) of an affine graybox op or an op "isolated from above",
+(3) a value that dominates the closest enclosing affine graybox or an op
+"isolated from above", (4) the result of a [`constant`
+operation](Standard.md#constant-operation), (5) the result of an [`affine.apply`
+operation](#affineapply-operation) that recursively takes as arguments any
+symbolic identifiers, or (6) the result of a [`dim`
 operation](Standard.md#dim-operation) on either a memref that is a function
 argument or a memref where the corresponding dimension is either static or a
-dynamic one in turn bound to a symbolic identifier. Dimensions may be bound not
-only to anything that a symbol is bound to, but also to induction variables of
-enclosing [`affine.for` operations](#affinefor-operation), and the result of an
-[`affine.apply` operation](#affineapply-operation) (which recursively may use
-other dimensions and symbols).
+dynamic one in turn bound to a symbolic identifier. Note that as a result of
+(3), symbol validity is sensitive to the location at which the value binds to
+the symbol. Dimensions may be bound not only to anything that a symbol is bound
+to, but also to induction variables of enclosing [`affine.for`
+operations](#affinefor-operation), and the result of an [`affine.apply`
+operation](#affineapply-operation) (which recursively may use other dimensions
+and symbols).
 
 ### Affine Expressions
 
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -31,9 +31,9 @@
 class FlatAffineConstraints;
 class OpBuilder;
 
-/// A utility function to check if a value is defined at the top level of a
-/// function. A value of index type defined at the top level is always a valid
-/// symbol.
+/// A utility function to check if a value is defined at the top level of an
+/// op isolated from above or an affine graybox. A value of index type defined
+/// at the top level is always a valid symbol.
 bool isTopLevelValue(Value value);
 
 /// AffineDmaStartOp starts a non-blocking DMA operation that transfers data
@@ -457,12 +457,24 @@
                      SmallVectorImpl<OpFoldResult> &results);
 };
 
-/// Returns true if the given Value can be used as a dimension id.
+/// Returns true if the given Value can be used as a dimension id in the closest
+/// op that is isolated from above or an affine graybox enclosing this value's
+/// definition or block argument appearance.
 bool isValidDim(Value value);
 
-/// Returns true if the given Value can be used as a symbol.
+/// Returns true if the given Value can be used as a dimension id in the region
+/// of the op 'opWithRegion'.
+bool isValidDim(Value value, Operation *opWithRegion);
+
+/// Returns true if the given value can be used as a symbol in the closest
+/// op that is isolated from above or an affine graybox op enclosing this
+/// value's definition or block argument appearance.
 bool isValidSymbol(Value value);
 
+/// Returns true if the given Value can be used as a symbol in the region of
+/// the op 'opWithRegion'.
+bool isValidSymbol(Value value, Operation *opWithRegion);
+
 /// Modifies both `map` and `operands` in-place so as to:
 /// 1. drop duplicate operands
 /// 2. drop unused dims and symbols from map
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -83,12 +83,21 @@
     /// Returns the affine value map computed from this operation.
     AffineValueMap getAffineValueMap();
 
-    /// Returns true if the result of this operation can be used as dimension id.
+    /// Returns true if the result of this operation can be used as a dimension
+    /// id.
     bool isValidDim();
 
+    /// Returns true if the result of this operation can be used as dimension id
+    /// within the region of the op 'opWithRegion'.
+    bool isValidDim(Operation *opWithRegion);
+
     /// Returns true if the result of this operation is a symbol.
     bool isValidSymbol();
 
+    /// Returns true if the result of this operation is a symbol in the region
+    /// of 'opWithRegion'.
+    bool isValidSymbol(Operation *opWithRegion);
+
     operand_range getMapOperands() { return getOperands(); }
   }];
 
@@ -605,4 +614,65 @@
   let verifier = ?;
 }
 
+def AffineGrayBoxOp : Affine_Op<"graybox"> {
+  let summary = "graybox operation";
+  let description = [{
+    The affine graybox op introduces a new symbol context for affine
+    operations. It holds a single region, which can be a list of one or more
+    blocks. The op's region can have zero or more arguments, each of which can
+    only be a memref. The operands bind 1:1 to its region's arguments. The op
+    can't use any memrefs defined outside of it, but can use any other SSA
+    values that dominate it. Its region's blocks can have terminators the same
+    way as current MLIR functions (FuncOp) can. Control from any return ops
+    from the top level of its region returns to right after the affine.graybox
+    op. Its control flow thus conforms to the control flow semantics of
+    regions, i.e., control always returns to the immediate enclosing (parent)
+    op. The results of a graybox op match 1:1 with the return values from its
+    region's blocks.
+
+    Example:
+
+    ```mlir
+    affine.for %i = 0 to 128 {
+      affine.graybox [%rI, %rM] = (%I, %M)
+          : (memref<128xi32>, memref<24xf32>) -> () {
+        %idx = affine.load %rI[%i] : memref<128xi32>
+        %index = index_cast %idx : i32 to index
+        affine.load %rM[%index]: memref<24xf32>
+        return
+      }
+    }
+    ```
+
+    ```mlir
+    affine.for %i = 0 to %n {
+      affine.graybox [] = () : () -> () {
+        // %pow can now be used as a loop bound.
+        %pow = call @powi(%i) : (index) -> index
+        affine.for %j = 0 to %pow {
+          "foo"() : () -> ()
+        }
+        return
+      }
+    }
+    ```
+
+  }];
+
+  let arguments = (ins Variadic<AnyMemRef>:$operands);
+  let results = (outs Variadic<AnyType>);
+
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState &result, "
+              "ArrayRef<Value> memrefs">
+  ];
+
+  // TODO: canonicalizations related to memrefs.
+  let hasCanonicalizer = 0;
+}
+
 #endif // AFFINE_OPS
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -84,18 +84,46 @@
   return builder.create<ConstantOp>(loc, type, value);
 }
 
-/// A utility function to check if a given region is attached to a function.
-static bool isFunctionRegion(Region *region) {
-  return llvm::isa<FuncOp>(region->getParentOp());
+/// A utility function to check if a given region is attached to an op isolated
+/// from above or an affine graybox.
+static bool isIsolatedOrGrayBoxRegion(Region *region) {
+  return region->getParentOp()->isKnownIsolatedFromAbove() ||
+         isa<AffineGrayBoxOp>(region->getParentOp());
 }
 
-/// A utility function to check if a value is defined at the top level of a
-/// function. A value of index type defined at the top level is always a valid
-/// symbol.
+/// A utility function to check if a value is defined at the top level of an
+/// op isolated from above or an affine graybox. A value of index type defined
+/// at the top level is always a valid symbol.
 bool mlir::isTopLevelValue(Value value) {
   if (auto arg = value.dyn_cast<BlockArgument>())
-    return isFunctionRegion(arg.getOwner()->getParent());
-  return isFunctionRegion(value.getDefiningOp()->getParentRegion());
+    return isIsolatedOrGrayBoxRegion(arg.getOwner()->getParent());
+  return isIsolatedOrGrayBoxRegion(value.getDefiningOp()->getParentRegion());
+}
+
+/// A utility function to check if a value is defined at the top level of
+/// 'opWithRegion'. A value of index type defined at the top level is always a
+/// valid symbol.
+static bool isTopLevelValue(Value value, Operation *opWithRegion) {
+  assert(opWithRegion->getNumRegions() > 0 &&
+         "only to be called on ops with regions");
+  if (auto arg = value.dyn_cast<BlockArgument>())
+    return arg.getOwner()->getParentOp() == opWithRegion;
+  return value.getDefiningOp()->getParentOp() == opWithRegion;
+}
+
+/// Returns the closest op surrounding 'op' that is either an AffineGrayBoxOp or
+/// an op isolated from above (e.g., FuncOp). 'op' must have such an ancestor;
+/// calling this on a top-level op is a programming error.
+// TODO: getAffineScope should be publicly exposed for affine
+// passes/utilities.
+static Operation *getAffineScope(Operation *op) {
+  // TODO: make this compact by introducing a variadic pack on getParentOfType.
+  auto *curOp = op;
+  while ((curOp = curOp->getParentOp()))
+    if (llvm::isa<AffineGrayBoxOp>(curOp) || curOp->isKnownIsolatedFromAbove())
+      return curOp;
+
+  llvm_unreachable("op doesn't have a parent op");
 }
 
 // Value can be used as a dimension id if it is valid as a symbol, or
@@ -106,43 +134,68 @@
   if (!value.getType().isIndex())
     return false;
 
-  if (auto *op = value.getDefiningOp()) {
-    // Top level operation or constant operation is ok.
-    if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
-      return true;
-    // Affine apply operation is ok if all of its operands are ok.
-    if (auto applyOp = dyn_cast<AffineApplyOp>(op))
-      return applyOp.isValidDim();
-    // The dim op is okay if its operand memref/tensor is defined at the top
-    // level.
-    if (auto dimOp = dyn_cast<DimOp>(op))
-      return isTopLevelValue(dimOp.getOperand());
-    return false;
-  }
-  // This value has to be a block argument of a FuncOp, an 'affine.for', or an
-  // 'affine.parallel'.
+  if (auto *op = value.getDefiningOp())
+    return isValidDim(value, getAffineScope(op));
+
+  // This value has to be a block argument for an op isolated from above, an
+  // affine.for, or an affine.parallel. (A graybox has only memref arguments.)
   auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
-  return isa<FuncOp>(parentOp) || isa<AffineForOp>(parentOp) ||
+  return parentOp->isKnownIsolatedFromAbove() || isa<AffineForOp>(parentOp) ||
          isa<AffineParallelOp>(parentOp);
 }
 
+// Value can be used as a dimension id if it is valid as a symbol, or it is an
+// induction variable, or it is a result of affine apply operation with
+// dimension id arguments.
+bool mlir::isValidDim(Value value, Operation *opWithRegion) {
+  assert(opWithRegion->getNumRegions() > 0 &&
+         "only to be called on ops with regions");
+  // The value must be an index type.
+  if (!value.getType().isIndex())
+    return false;
+
+  auto *op = value.getDefiningOp();
+  if (!op) {
+    // This value has to be a block argument for an op isolated from above,
+    // an affine.for, or an affine.parallel.
+    auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
+    return parentOp->isKnownIsolatedFromAbove() || isa<AffineForOp>(parentOp) ||
+           isa<AffineParallelOp>(parentOp);
+  }
+
+  // Top level operation or constant operation is ok.
+  if (::isTopLevelValue(value, opWithRegion) || isa<ConstantOp>(op))
+    return true;
+  // Affine apply operation is ok if all of its operands are ok.
+  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
+    return applyOp.isValidDim(opWithRegion);
+  // The dim op is okay if its operand memref/tensor is defined at the top
+  // level.
+  if (auto dimOp = dyn_cast<DimOp>(op))
+    return isTopLevelValue(dimOp.getOperand());
+  return false;
+}
+
 /// Returns true if the 'index' dimension of the `memref` defined by
-/// `memrefDefOp` is a statically shaped one or defined using a valid symbol.
+/// `memrefDefOp` is a statically shaped one or defined using a valid symbol
+/// for 'op'.
 template <typename AnyMemRefDefOp>
-static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp,
-                                    unsigned index) {
+static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
+                                    Operation *op) {
+  assert(op->getNumRegions() > 0 && "only to be called on ops with regions");
   auto memRefType = memrefDefOp.getType();
   // Statically shaped.
   if (!ShapedType::isDynamic(memRefType.getDimSize(index)))
     return true;
   // Get the position of the dimension among dynamic dimensions;
   unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
-  return isValidSymbol(
-      *(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos));
+  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
+                       op);
 }
 
 /// Returns true if the result of the dim op is a valid symbol.
-static bool isDimOpValidSymbol(DimOp dimOp) {
+static bool isDimOpValidSymbol(DimOp dimOp, Operation *op) {
+  assert(op->getNumRegions() > 0 && "only to be called on ops with regions");
   // The dim op is okay if its operand memref/tensor is defined at the top
   // level.
   if (isTopLevelValue(dimOp.getOperand()))
@@ -152,43 +206,86 @@
   // whose corresponding size is a valid symbol.
   unsigned index = dimOp.getIndex();
-  if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand().getDefiningOp()))
-    return isMemRefSizeValidSymbol<ViewOp>(viewOp, index);
-  if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand().getDefiningOp()))
-    return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index);
-  if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand().getDefiningOp()))
-    return isMemRefSizeValidSymbol<AllocOp>(allocOp, index);
+  // A memref that is a block argument but not a top-level value (e.g., a
+  // region argument of some non-affine op) has no defining op to inspect;
+  // conservatively treat its dynamic sizes as non-symbols.
+  Operation *memrefDefOp = dimOp.getOperand().getDefiningOp();
+  if (!memrefDefOp)
+    return false;
+  if (auto viewOp = dyn_cast<ViewOp>(memrefDefOp))
+    return isMemRefSizeValidSymbol<ViewOp>(viewOp, index, op);
+  if (auto subViewOp = dyn_cast<SubViewOp>(memrefDefOp))
+    return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index, op);
+  if (auto allocOp = dyn_cast<AllocOp>(memrefDefOp))
+    return isMemRefSizeValidSymbol<AllocOp>(allocOp, index, op);
   return false;
 }
 
 // Value can be used as a symbol if it is a constant, or it is defined at
-// the top level, or it is a result of affine apply operation with symbol
-// arguments, or a result of the dim op on a memref satisfying certain
-// constraints.
+// the top level of the affine scope (a graybox or an op isolated from above)
+// enclosing its definition, or it is a result of affine apply operation with
+// symbol arguments, or a result of the dim op on a memref whose corresponding
+// size is a valid symbol.
 bool mlir::isValidSymbol(Value value) {
   // The value must be an index type.
   if (!value.getType().isIndex())
     return false;
 
-  if (auto *op = value.getDefiningOp()) {
-    // Top level operation or constant operation is ok.
-    if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
-      return true;
-    // Affine apply operation is ok if all of its operands are ok.
-    if (auto applyOp = dyn_cast<AffineApplyOp>(op))
-      return applyOp.isValidSymbol();
-    if (auto dimOp = dyn_cast<DimOp>(op)) {
-      return isDimOpValidSymbol(dimOp);
-    }
-  }
-  // Otherwise, check that the value is a top level value.
-  return isTopLevelValue(value);
+  // Check that the value is a top level value.
+  if (isTopLevelValue(value))
+    return true;
+
+  if (auto *op = value.getDefiningOp())
+    return isValidSymbol(value, getAffineScope(op));
+
+  return false;
+}
+
+// Value can be used as a symbol in the region of an op 'opWithRegion' if it is
+// a constant, or it is defined at the top level of 'opWithRegion' or dominates
+// 'opWithRegion' (without crossing an op isolated from above), or it is the
+// result of an affine apply operation with symbol arguments, or a result of the
+// dim op on a memref whose corresponding size is a valid symbol.
+bool mlir::isValidSymbol(Value value, Operation *opWithRegion) {
+  assert(opWithRegion->getNumRegions() > 0 &&
+         "only to be called on ops with regions");
+
+  // The value must be an index type.
+  if (!value.getType().isIndex())
+    return false;
+
+  // A top-level value is a valid symbol.
+  if (::isTopLevelValue(value, opWithRegion))
+    return true;
+
+  if (auto *defOp = value.getDefiningOp()) {
+    // Constant operation is ok.
+    if (isa<ConstantOp>(defOp))
+      return true;
+
+    // Affine apply operation is ok if all of its operands are ok.
+    if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
+      return applyOp.isValidSymbol(opWithRegion);
+
+    // Dim op results could be valid symbols at any level.
+    if (auto dimOp = dyn_cast<DimOp>(defOp))
+      return isDimOpValidSymbol(dimOp, opWithRegion);
+  }
+
+  // Any other value (including a block argument) is valid if it dominates
+  // 'opWithRegion' as a valid symbol of an enclosing region. Stop at an op
+  // isolated from above: it starts a fresh symbol scope.
+  if (!opWithRegion->isKnownIsolatedFromAbove())
+    if (auto *parentOp = opWithRegion->getParentOp())
+      return isValidSymbol(value, parentOp);
+
+  return false;
 }
 
 // Returns true if 'value' is a valid index to an affine operation (e.g.
-// affine.load, affine.store, affine.dma_start, affine.dma_wait).
-// Returns false otherwise.
-static bool isValidAffineIndexOperand(Value value) {
-  return isValidDim(value) || isValidSymbol(value);
+// affine.load, affine.store, affine.dma_start, affine.dma_wait) inside the
+// region of 'op'. Returns false otherwise.
+static bool isValidAffineIndexOperand(Value value, Operation *op) {
+  return isValidDim(value, op) || isValidSymbol(value, op);
 }
 
 /// Utility function to verify that a set of operands are valid dimension and
@@ -273,6 +370,14 @@
                       [](Value op) { return mlir::isValidDim(op); });
 }
 
+// The result of the affine apply operation can be used as a dimension id if all
+// its operands are valid dimension ids.
+bool AffineApplyOp::isValidDim(Operation *opWithRegion) {
+  return llvm::all_of(getOperands(), [&](Value op) {
+    return mlir::isValidDim(op, opWithRegion);
+  });
+}
+
 // The result of the affine apply operation can be used as a symbol if all its
 // operands are symbols.
 bool AffineApplyOp::isValidSymbol() {
@@ -280,6 +385,14 @@
                       [](Value op) { return mlir::isValidSymbol(op); });
 }
 
+// The result of the affine apply operation can be used as a symbol if all its
+// operands are symbols.
+bool AffineApplyOp::isValidSymbol(Operation *opWithRegion) {
+  return llvm::all_of(getOperands(), [&](Value operand) {
+    return mlir::isValidSymbol(operand, opWithRegion);
+  });
+}
+
 OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
   auto map = getAffineMap();
 
@@ -947,22 +1060,23 @@
     return emitOpError("incorrect number of operands");
   }
 
+  auto *scope = getAffineScope(*this);
   for (auto idx : getSrcIndices()) {
     if (!idx.getType().isIndex())
       return emitOpError("src index to dma_start must have 'index' type");
-    if (!isValidAffineIndexOperand(idx))
+    if (!isValidAffineIndexOperand(idx, scope))
       return emitOpError("src index must be a dimension or symbol identifier");
   }
   for (auto idx : getDstIndices()) {
     if (!idx.getType().isIndex())
       return emitOpError("dst index to dma_start must have 'index' type");
-    if (!isValidAffineIndexOperand(idx))
+    if (!isValidAffineIndexOperand(idx, scope))
       return emitOpError("dst index must be a dimension or symbol identifier");
   }
   for (auto idx : getTagIndices()) {
     if (!idx.getType().isIndex())
       return emitOpError("tag index to dma_start must have 'index' type");
-    if (!isValidAffineIndexOperand(idx))
+    if (!isValidAffineIndexOperand(idx, scope))
       return emitOpError("tag index must be a dimension or symbol identifier");
   }
   return success();
@@ -1038,7 +1147,8 @@
+  auto *scope = getAffineScope(*this);
   for (auto idx : getTagIndices()) {
     if (!idx.getType().isIndex())
       return emitOpError("index to dma_wait must have 'index' type");
-    if (!isValidAffineIndexOperand(idx))
+    if (!isValidAffineIndexOperand(idx, scope))
       return emitOpError("index must be a dimension or symbol identifier");
   }
   return success();
@@ -1794,7 +1904,8 @@
+  auto *scope = getAffineScope(*this);
   for (auto idx : getMapOperands()) {
     if (!idx.getType().isIndex())
       return emitOpError("index to load must have 'index' type");
-    if (!isValidAffineIndexOperand(idx))
+    if (!isValidAffineIndexOperand(idx, scope))
       return emitOpError("index must be a dimension or symbol identifier");
   }
   return success();
@@ -1892,7 +2003,8 @@
+  auto *scope = getAffineScope(*this);
   for (auto idx : getMapOperands()) {
     if (!idx.getType().isIndex())
       return emitOpError("index to store must have 'index' type");
-    if (!isValidAffineIndexOperand(idx))
+    if (!isValidAffineIndexOperand(idx, scope))
       return emitOpError("index must be a dimension or symbol identifier");
   }
   return success();
@@ -2112,7 +2224,8 @@
   }
 
+  auto *scope = getAffineScope(op);
   for (auto idx : op.getMapOperands()) {
-    if (!isValidAffineIndexOperand(idx))
+    if (!isValidAffineIndexOperand(idx, scope))
       return op.emitOpError("index must be a dimension or symbol identifier");
   }
   return success();
@@ -2130,6 +2243,152 @@
   return foldMemRefCast(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// AffineGrayBoxOp
+//===----------------------------------------------------------------------===//
+//
+
+// Builds a graybox op that captures 'memrefs': they become the op's operands
+// and, 1:1, the arguments of the entry block of its single region.
+// TODO: missing region body.
+void AffineGrayBoxOp::build(Builder *builder, OperationState &result,
+                            ArrayRef<Value> memrefs) {
+  result.addOperands(memrefs);
+
+  // Create a region and an empty entry block. The arguments of the region are
+  // the supplied memrefs.
+  Region *region = result.addRegion();
+  Block *body = new Block();
+  region->push_back(body);
+
+  SmallVector<Type, 4> memrefTypes;
+  memrefTypes.reserve(memrefs.size());
+  for (auto v : memrefs)
+    memrefTypes.push_back(v.getType());
+  body->addArguments(memrefTypes);
+
+  // Set the operands list as resizable so that we can add memrefs.
+  result.setOperandListToResizable();
+}
+
+static LogicalResult verify(AffineGrayBoxOp op) {
+  // All memref uses in the graybox region should be explicitly captured: each
+  // memref used must be defined (by an op or as a block argument) somewhere
+  // inside the region. Ops nested in inner grayboxes are skipped; the inner
+  // graybox's own verifier checks those.
+  DenseSet<Value> memrefsUsed;
+  op.region().walk([&](Operation *innerOp) {
+    if (innerOp->getParentOfType<AffineGrayBoxOp>() != op)
+      return;
+    for (auto v : innerOp->getOperands())
+      if (v.getType().isa<MemRefType>())
+        memrefsUsed.insert(v);
+  });
+
+  for (auto memref : memrefsUsed)
+    if (!op.region().isAncestor(memref.getParentRegion()))
+      return op.emitOpError("incoming memref not explicitly captured");
+
+  // Verify that the region arguments match operands.
+  if (op.region().empty())
+    return op.emitOpError("expects a non-empty region");
+  auto &entryBlock = op.region().front();
+  if (entryBlock.getNumArguments() != op.getNumOperands())
+    return op.emitOpError("region argument count does not match operand count");
+
+  for (auto &argEn : llvm::enumerate(entryBlock.getArguments())) {
+    if (op.getOperand(argEn.index()).getType() != argEn.value().getType())
+      return op.emitOpError("type of one or more region arguments does not "
+                            "match corresponding operand");
+  }
+
+  return success();
+}
+
+// Custom form syntax.
+//
+// (ssa-id `=`)? `affine.graybox` `[` memref-region-arg-list `]`
+//    `=` `(` memref-use-list `)`
+//    `:` memref-type-list-parens `->` function-result-type `{`
+//    block+
+//    `}`
+//
+// Ex:
+//
+//  affine.graybox [%rI, %rM] = (%I, %M)
+//      : (memref<128xi32>, memref<1024xf32>) -> () {
+//    %idx = affine.load %rI[%i] : memref<128xi32>
+//    %index = index_cast %idx : i32 to index
+//    affine.load %rM[%index]: memref<1024xf32>
+//    return
+//  }
+//
+static ParseResult parseAffineGrayBoxOp(OpAsmParser &parser,
+                                        OperationState &result) {
+  // Memref operands.
+  SmallVector<OpAsmParser::OperandType, 4> memrefs;
+
+  // Region arguments to be created.
+  SmallVector<OpAsmParser::OperandType, 4> regionMemRefs;
+
+  // The graybox op has the same type signature as a function.
+  FunctionType opType;
+
+  // Parse the memref assignments.
+  auto argLoc = parser.getCurrentLocation();
+  if (parser.parseRegionArgumentList(regionMemRefs,
+                                     OpAsmParser::Delimiter::Square) ||
+      parser.parseEqual() ||
+      parser.parseOperandList(memrefs, OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  if (memrefs.size() != regionMemRefs.size())
+    return parser.emitError(parser.getNameLoc(),
+                            "incorrect number of memref captures");
+
+  if (parser.parseColonType(opType) ||
+      parser.addTypesToList(opType.getResults(), result.types))
+    return failure();
+
+  auto memrefTypes = opType.getInputs();
+  if (parser.resolveOperands(memrefs, memrefTypes, argLoc, result.operands))
+    return failure();
+
+  // Introduce the body region and parse it.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionMemRefs, memrefTypes))
+    return failure();
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  // Set the operands list as resizable so that we can modify operands.
+  result.setOperandListToResizable();
+  return success();
+}
+
+static void print(OpAsmPrinter &p, AffineGrayBoxOp op) {
+  p << AffineGrayBoxOp::getOperationName() << " [";
+  // TODO: consider shadowing region arguments.
+  p.printOperands(op.region().front().getArguments());
+  p << "] = (";
+  auto operands = op.getOperands();
+  p.printOperands(operands);
+  p << ")";
+
+  SmallVector<Type, 8> argTypes(op.getOperandTypes());
+  p << " : "
+    << FunctionType::get(argTypes, op.getResultTypes(),
+                         op.getOperation()->getContext());
+
+  p.printRegion(op.region(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/true);
+
+  p.printOptionalAttrDict(op.getAttrs());
+}
+
 //===----------------------------------------------------------------------===//
 // AffineParallelOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/graybox.mlir b/mlir/test/Dialect/Affine/graybox.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Affine/graybox.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK-LABEL: @arbitrary_bound
+func @arbitrary_bound(%n : index) {
+  affine.for %i = 0 to %n {
+    affine.graybox [] = () : () -> () {
+      // %pow can now be used as a loop bound.
+      %pow = call @powi(%i) : (index) -> index
+      affine.for %j = 0 to %pow {
+        "foo"() : () -> ()
+      }
+      return
+    }
+    // CHECK:      affine.graybox [] = () : () -> () {
+    // CHECK-NEXT:   call @powi
+    // CHECK-NEXT:   affine.for
+    // CHECK-NEXT:     "foo"()
+    // CHECK-NEXT:   }
+    // CHECK-NEXT:   return
+    // CHECK-NEXT: }
+  }
+  return
+}
+
+func @powi(index) -> index
+
+// CHECK-LABEL: func @arbitrary_mem_access
+func @arbitrary_mem_access(%I: memref<128xi32>, %M: memref<1024xf32>) {
+  affine.for %i = 0 to 128 {
+    // CHECK: %{{.*}} = affine.graybox [{{.*}}] = ({{.*}}) : (memref<128xi32>, memref<1024xf32>) -> f32
+    affine.graybox [%rI, %rM] = (%I, %M) : (memref<128xi32>, memref<1024xf32>) -> f32 {
+      %idx = affine.load %rI[%i] : memref<128xi32>
+      %index = index_cast %idx : i32 to index
+      %v = affine.load %rM[%index]: memref<1024xf32>
+      return %v : f32
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: @symbol_check
+func @symbol_check(%B: memref<100xi32>, %A: memref<100xf32>) {
+  %cf1 = constant 1.0 : f32
+  affine.for %i = 0 to 100 {
+    %v = affine.load %B[%i] : memref<100xi32>
+    %vo = index_cast %v : i32 to index
+    // CHECK: affine.graybox [%{{.*}}] = (%{{.*}}) : (memref<100xf32>) -> () {
+    affine.graybox [%rA] = (%A) : (memref<100xf32>) -> () {
+      // %vi is now a symbol here.
+      %vi = index_cast %v : i32 to index
+      affine.load %rA[%vi] : memref<100xf32>
+      // %vo is also a symbol (dominates the graybox).
+      affine.load %rA[%vo] : memref<100xf32>
+      return
+    }
+    // CHECK:       index_cast
+    // CHECK-NEXT:  affine.load
+    // CHECK-NEXT:  affine.load
+    // CHECK-NEXT:  return
+    // CHECK-NEXT: }
+  }
+  return
+}
+
+// CHECK-LABEL: func @search
+func @search(%A : memref<?x?xi32>, %S : memref<?xi32>, %key : i32) {
+  %ni = dim %A, 0 : memref<?x?xi32>
+  %c1 = constant 1 : index
+  // This loop can be parallelized.
+  affine.for %i = 0 to %ni {
+    // CHECK: affine.graybox
+    affine.graybox [%rA, %rS] = (%A, %S) : (memref<?x?xi32>, memref<?xi32>) -> () {
+      %c0 = constant 0 : index
+      %nj = dim %rA, 1 : memref<?x?xi32>
+      br ^bb1(%c0 : index)
+
+    ^bb1(%j: index):
+      %p1 = cmpi "slt", %j, %nj : index
+      cond_br %p1, ^bb2(%j : index), ^bb5
+
+    ^bb2(%j_arg : index):
+      %v = affine.load %rA[%i, %j_arg] : memref<?x?xi32>
+      %p2 = cmpi "eq", %v, %key : i32
+      cond_br %p2, ^bb3(%j_arg : index), ^bb4(%j_arg : index)
+
+    ^bb3(%j_arg2: index):
+      %j_int = index_cast %j_arg2 : index to i32
+      affine.store %j_int, %rS[%i] : memref<?xi32>
+      br ^bb5
+
+    ^bb4(%j_arg3 : index):
+      %jinc = addi %j_arg3, %c1 : index
+      br ^bb1(%jinc : index)
+
+    ^bb5:
+      return
+    }
+  }
+  return
+}
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -171,6 +171,64 @@
 
 // -----
 
+// CHECK-LABEL: @affine_graybox_missing_capture
+func @affine_graybox_missing_capture(%M : memref<2xi32>) {
+  affine.for %i = 0 to 10 {
+    affine.graybox [] = () : () -> () {
+      // expected-error@-1 {{incoming memref not explicitly captured}}
+      affine.load %M[%i] : memref<2xi32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @affine_graybox_wrong_capture
+func @affine_graybox_wrong_capture(%s : index) {
+  affine.graybox [%rS] = (%s) : (index) -> () {
+    // expected-error@-1 {{operand #0 must be memref}}
+    "use"(%s) : (index) -> ()
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @affine_graybox_capture_count_mismatch
+func @affine_graybox_capture_count_mismatch(%A : memref<2xi32>) {
+  affine.graybox [] = (%A) : (memref<2xi32>) -> () {
+    // expected-error@-1 {{incorrect number of memref captures}}
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @affine_graybox_region_type_mismatch
+func @affine_graybox_region_type_mismatch(%A : memref<2xi32>) {
+  "affine.graybox"(%A) ({
+  // expected-error@-1 {{type of one or more region arguments does not match corresponding operand}}
+    ^bb0(%rA : memref<4xi32>):
+      return
+  }) : (memref<2xi32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @affine_graybox_region_arg_count_mismatch
+func @affine_graybox_region_arg_count_mismatch(%A : memref<2xi32>) {
+  "affine.graybox"(%A) ({
+  // expected-error@-1 {{region argument count does not match operand count}}
+    ^bb0:
+      return
+  }) : (memref<2xi32>) -> ()
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @affine_max
 func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
   // expected-error@+1 {{operand count and affine map dimension and symbol count must match}}
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -115,6 +115,24 @@
 
 // -----
 
+// Test symbol restrictions with ops isolated from above.
+
+// CHECK-LABEL: func @valid_symbol_isolated_region
+func @valid_symbol_isolated_region(%n : index) {
+  test.isolated_region %n {
+    %c1 = constant 1 : index
+    %l = subi %n, %c1 : index
+    // %l, %n are valid symbols since test.isolated_region is known to be
+    // isolated from above.
+    affine.for %i = %l to %n {
+    }
+    return
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @parallel
 // CHECK-SAME: (%[[N:.*]]: index)
 func @parallel(%N : index) {