diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -313,16 +313,6 @@ computeRankReductionMask(ArrayRef originalShape, ArrayRef reducedShape); -/// Prints dimension and symbol list. -void printDimAndSymbolList(Operation::operand_iterator begin, - Operation::operand_iterator end, unsigned numDims, - OpAsmPrinter &p); - -/// Parses dimension and symbol list and returns true if parsing failed. -ParseResult parseDimAndSymbolList(OpAsmParser &parser, - SmallVectorImpl &operands, - unsigned &numDims); - /// Determines whether MemRefCastOp casts to a more dynamic version of the /// source memref. This is useful to to fold a memref_cast into a consuming op /// and implement canonicalization patterns for ops in different dialects that diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -157,20 +157,38 @@ class AllocLikeOp traits = []> : - Std_Op]>], traits)> { - - let arguments = (ins Variadic:$value, + Std_Op]>, + AttrSizedOperandSegments + ], traits)> { + + let arguments = (ins Variadic:$dynamicSizes, + // The symbolic operands (the ones in square brackets) bind + // to the symbols of the memref's layout map. 
+ Variadic:$symbolOperands, Confined, [IntMinValue<0>]>:$alignment); - let results = (outs Res]>); + let results = (outs Res]>:$memref); let builders = [ - OpBuilderDAG<(ins "MemRefType":$memrefType), [{ - $_state.types.push_back(memrefType); + OpBuilderDAG<(ins "MemRefType":$memrefType, + CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{ + return build($_builder, $_state, memrefType, {}, alignment); }]>, - OpBuilderDAG<(ins "MemRefType":$memrefType, "ValueRange":$operands, - CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{ - $_state.addOperands(operands); + OpBuilderDAG<(ins "MemRefType":$memrefType, "ValueRange":$dynamicSizes, + CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{ + return build($_builder, $_state, memrefType, dynamicSizes, {}, alignment); + }]>, + OpBuilderDAG<(ins "MemRefType":$memrefType, "ValueRange":$dynamicSizes, + "ValueRange":$symbolOperands, + CArg<"IntegerAttr", "{}">:$alignment), [{ $_state.types.push_back(memrefType); + $_state.addOperands(dynamicSizes); + $_state.addOperands(symbolOperands); + $_state.addAttribute(getOperandSegmentSizeAttr(), + $_builder.getI32VectorAttr({ + static_cast(dynamicSizes.size()), + static_cast(symbolOperands.size())})); if (alignment) $_state.addAttribute(getAlignmentAttrName(), alignment); }]>]; @@ -180,23 +198,13 @@ MemRefType getType() { return getResult().getType().cast(); } - /// Returns the number of symbolic operands (the ones in square brackets), - /// which bind to the symbols of the memref's layout map. - unsigned getNumSymbolicOperands() { - return getNumOperands() - getType().getNumDynamicDims(); - } - - /// Returns the symbolic operands (the ones in square brackets), which bind - /// to the symbols of the memref's layout map. - operand_range getSymbolicOperands() { - return {operand_begin() + getType().getNumDynamicDims(), operand_end()}; - } - /// Returns the dynamic sizes for this alloc operation if specified. 
- operand_range getDynamicSizes() { return getOperands(); } + operand_range getDynamicSizes() { return dynamicSizes(); } }]; - let parser = [{ return ::parseAllocLikeOp(parser, result); }]; + let assemblyFormat = [{ + `(`$dynamicSizes`)` (`[` $symbolOperands^ `]`)? attr-dict `:` type($memref) + }]; let hasCanonicalizer = 1; } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -297,6 +297,33 @@ return isValidDim(value, region) || isValidSymbol(value, region); } +/// Prints dimension and symbol list. +static void printDimAndSymbolList(Operation::operand_iterator begin, + Operation::operand_iterator end, + unsigned numDims, OpAsmPrinter &printer) { + OperandRange operands(begin, end); + printer << '(' << operands.take_front(numDims) << ')'; + if (operands.size() > numDims) + printer << '[' << operands.drop_front(numDims) << ']'; +} + +/// Parses dimension and symbol list and returns failure if parsing failed. +static ParseResult parseDimAndSymbolList(OpAsmParser &parser, + SmallVectorImpl &operands, + unsigned &numDims) { + SmallVector opInfos; + if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) + return failure(); + // Store number of dimensions for validation by caller. + numDims = opInfos.size(); + + // Parse the optional symbol operands. + auto indexTy = parser.getBuilder().getIndexType(); + return failure(parser.parseOperandList( + opInfos, OpAsmParser::Delimiter::OptionalSquare) || + parser.resolveOperands(opInfos, indexTy, operands)); +} + /// Utility function to verify that a set of operands are valid dimension and /// symbol identifiers. The operands should be laid out such that the dimension /// operands are before the symbol operands.
This function returns failure if diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -169,36 +169,6 @@ return builder.create(loc, type, value); } -void mlir::printDimAndSymbolList(Operation::operand_iterator begin, - Operation::operand_iterator end, - unsigned numDims, OpAsmPrinter &p) { - Operation::operand_range operands(begin, end); - p << '(' << operands.take_front(numDims) << ')'; - if (operands.size() != numDims) - p << '[' << operands.drop_front(numDims) << ']'; -} - -// Parses dimension and symbol list, and sets 'numDims' to the number of -// dimension operands parsed. -// Returns 'false' on success and 'true' on error. -ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser, - SmallVectorImpl &operands, - unsigned &numDims) { - SmallVector opInfos; - if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) - return failure(); - // Store number of dimensions for validation by caller. - numDims = opInfos.size(); - - // Parse the optional symbol operands. - auto indexTy = parser.getBuilder().getIndexType(); - if (parser.parseOperandList(opInfos, - OpAsmParser::Delimiter::OptionalSquare) || - parser.resolveOperands(opInfos, indexTy, operands)) - return failure(); - return success(); -} - /// Matches a ConstantIndexOp. /// TODO: This should probably just be a general matcher that uses m_Constant /// and checks the operation for an index type. @@ -392,90 +362,37 @@ //===----------------------------------------------------------------------===// template -static void printAllocLikeOp(OpAsmPrinter &p, AllocLikeOp op, StringRef name) { - static_assert(llvm::is_one_of::value, - "applies to only alloc or alloca"); - p << name; - - // Print dynamic dimension operands. 
- MemRefType type = op.getType(); - printDimAndSymbolList(op.operand_begin(), op.operand_end(), - type.getNumDynamicDims(), p); - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); - p << " : " << type; -} - -static void print(OpAsmPrinter &p, AllocOp op) { - printAllocLikeOp(p, op, "alloc"); -} - -static void print(OpAsmPrinter &p, AllocaOp op) { - printAllocLikeOp(p, op, "alloca"); -} - -static ParseResult parseAllocLikeOp(OpAsmParser &parser, - OperationState &result) { - MemRefType type; - - // Parse the dimension operands and optional symbol operands, followed by a - // memref type. - unsigned numDimOperands; - if (parseDimAndSymbolList(parser, result.operands, numDimOperands) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type)) - return failure(); - - // Check numDynamicDims against number of question marks in memref type. - // Note: this check remains here (instead of in verify()), because the - // partition between dim operands and symbol operands is lost after parsing. - // Verification still checks that the total number of operands matches - // the number of symbols in the affine map, plus the number of dynamic - // dimensions in the memref. - if (numDimOperands != type.getNumDynamicDims()) - return parser.emitError(parser.getNameLoc()) - << "dimension operand count does not equal memref dynamic dimension " - "count"; - result.types.push_back(type); - return success(); -} - -template -static LogicalResult verify(AllocLikeOp op) { +static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { static_assert(llvm::is_one_of::value, "applies to only alloc or alloca"); auto memRefType = op.getResult().getType().template dyn_cast(); if (!memRefType) return op.emitOpError("result must be a memref"); - unsigned numSymbols = 0; - if (!memRefType.getAffineMaps().empty()) { - // Store number of symbols used in affine map (used in subsequent check). 
- AffineMap affineMap = memRefType.getAffineMaps()[0]; - numSymbols = affineMap.getNumSymbols(); - } + if (static_cast(op.dynamicSizes().size()) != + memRefType.getNumDynamicDims()) + return op.emitOpError("dimension operand count does not equal memref " + "dynamic dimension count"); - // Check that the total number of operands matches the number of symbols in - // the affine map, plus the number of dynamic dimensions specified in the - // memref type. - unsigned numDynamicDims = memRefType.getNumDynamicDims(); - if (op.getNumOperands() != numDynamicDims + numSymbols) + unsigned numSymbols = 0; + if (!memRefType.getAffineMaps().empty()) + numSymbols = memRefType.getAffineMaps().front().getNumSymbols(); + if (op.symbolOperands().size() != numSymbols) return op.emitOpError( - "operand count does not equal dimension plus symbol operand count"); + "symbol operand count does not equal memref symbol count"); - // Verify that all operands are of type Index. - for (auto operandType : op.getOperandTypes()) - if (!operandType.isIndex()) - return op.emitOpError("requires operands to be of type Index"); + return success(); +} - if (std::is_same::value) - return success(); +static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } +static LogicalResult verify(AllocaOp op) { // An alloca op needs to have an ancestor with an allocation scope trait. - if (!op.template getParentWithTrait()) + if (!op.getParentWithTrait()) return op.emitOpError( "requires an ancestor op with AutomaticAllocationScope trait"); - return success(); + return verifyAllocLikeOp(op); } namespace { diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -401,7 +401,7 @@ // Fetch a new memref type after normalizing the old memref to have an // identity map layout. 
MemRefType newMemRefType = - normalizeMemRefType(memrefType, b, allocOp.getNumSymbolicOperands()); + normalizeMemRefType(memrefType, b, allocOp.symbolOperands().size()); if (newMemRefType == memrefType) // Either memrefType already had an identity map or the map couldn't be // transformed to an identity map. @@ -409,9 +409,9 @@ Value oldMemRef = allocOp.getResult(); - SmallVector symbolOperands(allocOp.getSymbolicOperands()); + SmallVector symbolOperands(allocOp.symbolOperands()); AllocOp newAlloc = b.create(allocOp.getLoc(), newMemRefType, - llvm::None, allocOp.alignmentAttr()); + allocOp.alignmentAttr()); AffineMap layoutMap = memrefType.getAffineMaps().front(); // Replace all uses of the old memref. if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc, diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -87,7 +87,8 @@ ^bb0: %0 = constant 7 : index // Test alloc with wrong number of dynamic dimensions. 
- %1 = alloc(%0)[%1] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> // expected-error {{op 'std.alloc' dimension operand count does not equal memref dynamic dimension count}} + // expected-error@+1 {{dimension operand count does not equal memref dynamic dimension count}} + %1 = alloc(%0) [%0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> return } @@ -97,7 +98,8 @@ ^bb0: %0 = constant 7 : index // Test alloc with wrong number of symbols - %1 = alloc(%0) : memref<2x?xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> // expected-error {{operand count does not equal dimension plus symbol operand count}} + // expected-error@+1 {{symbol operand count does not equal memref symbol count}} + %1 = alloc(%0) : memref<2x?xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> return } diff --git a/mlir/test/IR/memory-ops.mlir b/mlir/test/IR/memory-ops.mlir --- a/mlir/test/IR/memory-ops.mlir +++ b/mlir/test/IR/memory-ops.mlir @@ -17,12 +17,12 @@ %1 = alloc(%c0, %c1) : memref (d0, d1)>, 1> // Test alloc with no dynamic dimensions and one symbol. - // CHECK: %2 = alloc()[%c0] : memref<2x4xf32, #map0, 1> - %2 = alloc()[%c0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> + // CHECK: %2 = alloc() [%c0] : memref<2x4xf32, #map0, 1> + %2 = alloc() [%c0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> // Test alloc with dynamic dimensions and one symbol. - // CHECK: %3 = alloc(%c1)[%c0] : memref<2x?xf32, #map0, 1> - %3 = alloc(%c1)[%c0] : memref<2x?xf32, affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>, 1> + // CHECK: %3 = alloc(%c1) [%c0] : memref<2x?xf32, #map0, 1> + %3 = alloc(%c1) [%c0] : memref<2x?xf32, affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>, 1> // Alloc with no mappings. // b/116054838 Parser crash while parsing ill-formed AllocOp @@ -48,12 +48,12 @@ %1 = alloca(%c0, %c1) : memref (d0, d1)>, 1> // Test alloca with no dynamic dimensions and one symbol. 
- // CHECK: %2 = alloca()[%c0] : memref<2x4xf32, #map0, 1> - %2 = alloca()[%c0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> + // CHECK: %2 = alloca() [%c0] : memref<2x4xf32, #map0, 1> + %2 = alloca() [%c0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> // Test alloca with dynamic dimensions and one symbol. - // CHECK: %3 = alloca(%c1)[%c0] : memref<2x?xf32, #map0, 1> - %3 = alloca(%c1)[%c0] : memref<2x?xf32, affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>, 1> + // CHECK: %3 = alloca(%c1) [%c0] : memref<2x?xf32, #map0, 1> + %3 = alloca(%c1) [%c0] : memref<2x?xf32, affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>, 1> // Alloca with no mappings, but with alignment. // CHECK: %4 = alloca() {alignment = 64 : i64} : memref<2xi32>