diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -46,206 +46,4 @@ #define GET_OP_CLASSES #include "mlir/Dialect/MemRef/IR/MemRefOps.h.inc" -namespace mlir { -namespace memref { -// DmaStartOp starts a non-blocking DMA operation that transfers data from a -// source memref to a destination memref. The source and destination memref need -// not be of the same dimensionality, but need to have the same elemental type. -// The operands include the source and destination memref's each followed by its -// indices, size of the data transfer in terms of the number of elements (of the -// elemental type of the memref), a tag memref with its indices, and optionally -// at the end, a stride and a number_of_elements_per_stride arguments. The tag -// location is used by a DmaWaitOp to check for completion. The indices of the -// source memref, destination memref, and the tag memref have the same -// restrictions as any load/store. The optional stride arguments should be of -// 'index' type, and specify a stride for the slower memory space (memory space -// with a lower memory space id), transferring chunks of -// number_of_elements_per_stride every stride until %num_elements are -// transferred. Either both or no stride arguments should be specified. If the -// source and destination locations overlap the behavior of this operation is -// not defined. 
-// -// For example, a DmaStartOp operation that transfers 256 elements of a memref -// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space -// 1 at indices [%k, %l], would be specified as follows: -// -// %num_elements = constant 256 -// %idx = constant 0 : index -// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4> -// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] : -// memref<40 x 128 x f32>, (d0) -> (d0), 0>, -// memref<2 x 1024 x f32>, (d0) -> (d0), 1>, -// memref<1 x i32>, (d0) -> (d0), 2> -// -// If %stride and %num_elt_per_stride are specified, the DMA is expected to -// transfer %num_elt_per_stride elements every %stride elements apart from -// memory space 0 until %num_elements are transferred. -// -// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride, -// %num_elt_per_stride : -// -// TODO: add additional operands to allow source and destination striding, and -// multiple stride levels. -// TODO: Consider replacing src/dst memref indices with view memrefs. -class DmaStartOp - : public Op { -public: - using Op::Op; - static ArrayRef getAttributeNames() { return {}; } - - static void build(OpBuilder &builder, OperationState &result, Value srcMemRef, - ValueRange srcIndices, Value destMemRef, - ValueRange destIndices, Value numElements, Value tagMemRef, - ValueRange tagIndices, Value stride = nullptr, - Value elementsPerStride = nullptr); - - // Returns the source MemRefType for this DMA operation. - Value getSrcMemRef() { return getOperand(0); } - // Returns the rank (number of indices) of the source MemRefType. - unsigned getSrcMemRefRank() { - return getSrcMemRef().getType().cast().getRank(); - } - // Returns the source memref indices for this DMA operation. - operand_range getSrcIndices() { - return {(*this)->operand_begin() + 1, - (*this)->operand_begin() + 1 + getSrcMemRefRank()}; - } - - // Returns the destination MemRefType for this DMA operations. 
- Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); } - // Returns the rank (number of indices) of the destination MemRefType. - unsigned getDstMemRefRank() { - return getDstMemRef().getType().cast().getRank(); - } - unsigned getSrcMemorySpace() { - return getSrcMemRef().getType().cast().getMemorySpaceAsInt(); - } - unsigned getDstMemorySpace() { - return getDstMemRef().getType().cast().getMemorySpaceAsInt(); - } - - // Returns the destination memref indices for this DMA operation. - operand_range getDstIndices() { - return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1, - (*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 + - getDstMemRefRank()}; - } - - // Returns the number of elements being transferred by this DMA operation. - Value getNumElements() { - return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); - } - - // Returns the Tag MemRef for this DMA operation. - Value getTagMemRef() { - return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); - } - // Returns the rank (number of indices) of the tag MemRefType. - unsigned getTagMemRefRank() { - return getTagMemRef().getType().cast().getRank(); - } - - // Returns the tag memref index for this DMA operation. - operand_range getTagIndices() { - unsigned tagIndexStartPos = - 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; - return {(*this)->operand_begin() + tagIndexStartPos, - (*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()}; - } - - /// Returns true if this is a DMA from a faster memory space to a slower one. - bool isDestMemorySpaceFaster() { - return (getSrcMemorySpace() < getDstMemorySpace()); - } - - /// Returns true if this is a DMA from a slower memory space to a faster one. - bool isSrcMemorySpaceFaster() { - // Assumes that a lower number is for a slower memory space. 
- return (getDstMemorySpace() < getSrcMemorySpace()); - } - - /// Given a DMA start operation, returns the operand position of either the - /// source or destination memref depending on the one that is at the higher - /// level of the memory hierarchy. Asserts failure if neither is true. - unsigned getFasterMemPos() { - assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); - return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1; - } - - static StringRef getOperationName() { return "memref.dma_start"; } - static ParseResult parse(OpAsmParser &parser, OperationState &result); - void print(OpAsmPrinter &p); - LogicalResult verify(); - - LogicalResult fold(ArrayRef cstOperands, - SmallVectorImpl &results); - - bool isStrided() { - return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + - 1 + 1 + getTagMemRefRank(); - } - - Value getStride() { - if (!isStrided()) - return nullptr; - return getOperand(getNumOperands() - 1 - 1); - } - - Value getNumElementsPerStride() { - if (!isStrided()) - return nullptr; - return getOperand(getNumOperands() - 1); - } -}; - -// DmaWaitOp blocks until the completion of a DMA operation associated with the -// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index -// with the same restrictions as any load/store index. %num_elements is the -// number of elements associated with the DMA operation. For example: -// -// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] : -// memref<2048 x f32>, (d0) -> (d0), 0>, -// memref<256 x f32>, (d0) -> (d0), 1> -// memref<1 x i32>, (d0) -> (d0), 2> -// ... -// ... 
-// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2> -// -class DmaWaitOp - : public Op { -public: - using Op::Op; - static ArrayRef getAttributeNames() { return {}; } - - static void build(OpBuilder &builder, OperationState &result, Value tagMemRef, - ValueRange tagIndices, Value numElements); - - static StringRef getOperationName() { return "memref.dma_wait"; } - - // Returns the Tag MemRef associated with the DMA operation being waited on. - Value getTagMemRef() { return getOperand(0); } - - // Returns the tag memref index for this DMA operation. - operand_range getTagIndices() { - return {(*this)->operand_begin() + 1, - (*this)->operand_begin() + 1 + getTagMemRefRank()}; - } - - // Returns the rank (number of indices) of the tag memref. - unsigned getTagMemRefRank() { - return getTagMemRef().getType().cast().getRank(); - } - - // Returns the number of elements transferred in the associated DMA operation. - Value getNumElements() { return getOperand(1 + getTagMemRefRank()); } - - static ParseResult parse(OpAsmParser &parser, OperationState &result); - void print(OpAsmPrinter &p); - LogicalResult fold(ArrayRef cstOperands, - SmallVectorImpl &results); - LogicalResult verify(); -}; -} // namespace memref -} // namespace mlir - #endif // MLIR_DIALECT_MEMREF_IR_MEMREF_H_ diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -284,8 +284,6 @@ let verifier = ?; } - - //===----------------------------------------------------------------------===// // BufferCastOp //===----------------------------------------------------------------------===// @@ -568,6 +566,217 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// DmaStartOp +//===----------------------------------------------------------------------===// + +def 
MemRef_DmaStartOp : MemRef_Op<"dma_start"> { + let summary = "non-blocking DMA operation that starts a transfer"; + let description = [{ + DmaStartOp starts a non-blocking DMA operation that transfers data from a + source memref to a destination memref. The source and destination memref + need not be of the same dimensionality, but need to have the same elemental + type. The operands include the source and destination memrefs, each followed + by its indices, size of the data transfer in terms of the number of elements + (of the elemental type of the memref), a tag memref with its indices, and + optionally at the end, a stride and a number_of_elements_per_stride + arguments. The tag location is used by a DmaWaitOp to check for completion. + The indices of the source memref, destination memref, and the tag memref + have the same restrictions as any load/store. The optional stride arguments + should be of 'index' type, and specify a stride for the slower memory space + (memory space with a lower memory space id), transferring chunks of + number_of_elements_per_stride every stride until %num_elements are + transferred. Either both or no stride arguments should be specified. If the + source and destination locations overlap, the behavior of this operation is + not defined. 
+ + For example, a DmaStartOp operation that transfers 256 elements of a memref + '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory + space 1 at indices [%k, %l], would be specified as follows: + + ```mlir + %num_elements = constant 256 : index + %idx = constant 0 : index + %tag = memref.alloc() : memref<1 x i32, 2> + memref.dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] : + memref<40 x 128 x f32>, + memref<2 x 1024 x f32, 1>, + memref<1 x i32, 2> + ``` + + If %stride and %num_elt_per_stride are specified, the DMA is expected to + transfer %num_elt_per_stride elements every %stride elements apart from + memory space 0 until %num_elements are transferred. + + ```mlir + memref.dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], + %stride, %num_elt_per_stride : memref<40 x 128 x f32>, memref<2 x 1024 x f32, 1>, memref<1 x i32, 2> + ``` + + TODO: add additional operands to allow source and destination striding, and + multiple stride levels. + TODO: Consider replacing src/dst memref indices with view memrefs. + }]; + let arguments = (ins Variadic<AnyType>:$operands); + + let builders = [ + OpBuilder<(ins "Value":$srcMemRef, "ValueRange":$srcIndices, + "Value":$destMemRef, "ValueRange":$destIndices, + "Value":$numElements, "Value":$tagMemRef, + "ValueRange":$tagIndices, CArg<"Value", "{}">:$stride, + CArg<"Value", "{}">:$elementsPerStride)> + ]; + + let extraClassDeclaration = [{ + // Returns the source memref for this DMA operation. + Value getSrcMemRef() { return getOperand(0); } + // Returns the rank (number of indices) of the source MemRefType. + unsigned getSrcMemRefRank() { + return getSrcMemRef().getType().cast<MemRefType>().getRank(); + } + // Returns the source memref indices for this DMA operation. + operand_range getSrcIndices() { + return {(*this)->operand_begin() + 1, + (*this)->operand_begin() + 1 + getSrcMemRefRank()}; + } + + // Returns the destination memref for this DMA operation. 
+ Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); } + // Returns the rank (number of indices) of the destination MemRefType. + unsigned getDstMemRefRank() { + return getDstMemRef().getType().cast<MemRefType>().getRank(); + } + unsigned getSrcMemorySpace() { + return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt(); + } + unsigned getDstMemorySpace() { + return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt(); + } + + // Returns the destination memref indices for this DMA operation. + operand_range getDstIndices() { + return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1, + (*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 + + getDstMemRefRank()}; + } + + // Returns the number of elements being transferred by this DMA operation. + Value getNumElements() { + return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); + } + + // Returns the Tag MemRef for this DMA operation. + Value getTagMemRef() { + return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); + } + // Returns the rank (number of indices) of the tag MemRefType. + unsigned getTagMemRefRank() { + return getTagMemRef().getType().cast<MemRefType>().getRank(); + } + + // Returns the tag memref indices for this DMA operation. + operand_range getTagIndices() { + unsigned tagIndexStartPos = + 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; + return {(*this)->operand_begin() + tagIndexStartPos, + (*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()}; + } + + /// Returns true if this is a DMA from a slower memory space to a faster + /// one, i.e. the destination memory space is the faster one. + bool isDestMemorySpaceFaster() { + return (getSrcMemorySpace() < getDstMemorySpace()); + } + + /// Returns true if this is a DMA from a faster memory space to a slower + /// one, i.e. the source memory space is the faster one. + bool isSrcMemorySpaceFaster() { + // Assumes that a lower number is for a slower memory space. 
+ return (getDstMemorySpace() < getSrcMemorySpace()); + } + + /// Given a DMA start operation, returns the operand position of either the + /// source or destination memref depending on the one that is at the higher + /// level of the memory hierarchy. Asserts failure if neither is true. + unsigned getFasterMemPos() { + assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); + return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1; + } + + bool isStrided() { + return getNumOperands() != 1 + getSrcMemRefRank() + 1 + + getDstMemRefRank() + 1 + 1 + + getTagMemRefRank(); + } + + Value getStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1 - 1); + } + + Value getNumElementsPerStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1); + } + }]; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// DmaWaitOp +//===----------------------------------------------------------------------===// + +def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> { + let summary = "blocking DMA operation that waits for transfer completion"; + let description = [{ + DmaWaitOp blocks until the completion of a DMA operation associated with the + tag element '%tag[%index]'. %tag is a memref, and %index has to be an index + with the same restrictions as any load/store index. %num_elements is the + number of elements associated with the DMA operation. + + Example: + + ```mlir + memref.dma_start %src[%i], %dst[%k], %num_elements, %tag[%index] : + memref<2048 x f32>, + memref<256 x f32, 1>, + memref<1 x i32, 2> + ... + ... 
+ memref.dma_wait %tag[%index], %num_elements : memref<1 x i32, 2> + ``` + }]; + let arguments = (ins + AnyMemRef:$tagMemRef, + Variadic<Index>:$tagIndices, + Index:$numElements + ); + let assemblyFormat = [{ + $tagMemRef `[` $tagIndices `]` `,` $numElements attr-dict `:` + type($tagMemRef) + }]; + let extraClassDeclaration = [{ + /// Returns the Tag MemRef associated with the DMA operation being waited + /// on. + Value getTagMemRef() { return tagMemRef(); } + + /// Returns the tag memref indices for this DMA operation. + operand_range getTagIndices() { return tagIndices(); } + + /// Returns the rank (number of indices) of the tag memref. + unsigned getTagMemRefRank() { + return getTagMemRef().getType().cast<MemRefType>().getRank(); + } + + /// Returns the number of elements transferred in the associated DMA + /// operation. + Value getNumElements() { return numElements(); } + }]; + let hasFolder = 1; +} + //===----------------------------------------------------------------------===// // GetGlobalOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp @@ -34,9 +34,9 @@ } // end anonymous namespace void mlir::memref::MemRefDialect::initialize() { - addOperations(); + >(); addInterfaces(); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -909,16 +909,17 @@ result.addOperands({stride, elementsPerStride}); } -void DmaStartOp::print(OpAsmPrinter &p) { - p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], " - << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() - << ", " << getTagMemRef() << '[' << getTagIndices() << ']'; - if (isStrided()) - p << ", " << getStride() << ", " << 
getNumElementsPerStride(); +static void print(OpAsmPrinter &p, DmaStartOp op) { + p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], " + << op.getDstMemRef() << '[' << op.getDstIndices() << "], " + << op.getNumElements() << ", " << op.getTagMemRef() << '[' + << op.getTagIndices() << ']'; + if (op.isStrided()) + p << ", " << op.getStride() << ", " << op.getNumElementsPerStride(); - p.printOptionalAttrDict((*this)->getAttrs()); - p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() - << ", " << getTagMemRef().getType(); + p.printOptionalAttrDict(op->getAttrs()); + p << " : " << op.getSrcMemRef().getType() << ", " + << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType(); } // Parse DmaStartOp. @@ -929,7 +930,8 @@ // memref<1024 x f32, 2>, // memref<1 x i32> // -ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { +static ParseResult parseDmaStartOp(OpAsmParser &parser, + OperationState &result) { OpAsmParser::OperandType srcMemRefInfo; SmallVector srcIndexInfos; OpAsmParser::OperandType dstMemRefInfo; @@ -989,66 +991,67 @@ return success(); } -LogicalResult DmaStartOp::verify() { - unsigned numOperands = getNumOperands(); +static LogicalResult verify(DmaStartOp op) { + unsigned numOperands = op.getNumOperands(); // Mandatory non-variadic operands are: src memref, dst memref, tag memref and // the number of elements. if (numOperands < 4) - return emitOpError("expected at least 4 operands"); + return op.emitOpError("expected at least 4 operands"); // Check types of operands. The order of these calls is important: the later // calls rely on some type properties to compute the operand position. // 1. Source memref. 
- if (!getSrcMemRef().getType().isa()) - return emitOpError("expected source to be of memref type"); - if (numOperands < getSrcMemRefRank() + 4) - return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 - << " operands"; - if (!getSrcIndices().empty() && - !llvm::all_of(getSrcIndices().getTypes(), + if (!op.getSrcMemRef().getType().isa()) + return op.emitOpError("expected source to be of memref type"); + if (numOperands < op.getSrcMemRefRank() + 4) + return op.emitOpError() + << "expected at least " << op.getSrcMemRefRank() + 4 << " operands"; + if (!op.getSrcIndices().empty() && + !llvm::all_of(op.getSrcIndices().getTypes(), [](Type t) { return t.isIndex(); })) - return emitOpError("expected source indices to be of index type"); + return op.emitOpError("expected source indices to be of index type"); // 2. Destination memref. - if (!getDstMemRef().getType().isa()) - return emitOpError("expected destination to be of memref type"); - unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; + if (!op.getDstMemRef().getType().isa()) + return op.emitOpError("expected destination to be of memref type"); + unsigned numExpectedOperands = + op.getSrcMemRefRank() + op.getDstMemRefRank() + 4; if (numOperands < numExpectedOperands) - return emitOpError() << "expected at least " << numExpectedOperands - << " operands"; - if (!getDstIndices().empty() && - !llvm::all_of(getDstIndices().getTypes(), + return op.emitOpError() + << "expected at least " << numExpectedOperands << " operands"; + if (!op.getDstIndices().empty() && + !llvm::all_of(op.getDstIndices().getTypes(), [](Type t) { return t.isIndex(); })) - return emitOpError("expected destination indices to be of index type"); + return op.emitOpError("expected destination indices to be of index type"); // 3. Number of elements. 
- if (!getNumElements().getType().isIndex()) - return emitOpError("expected num elements to be of index type"); + if (!op.getNumElements().getType().isIndex()) + return op.emitOpError("expected num elements to be of index type"); // 4. Tag memref. - if (!getTagMemRef().getType().isa()) - return emitOpError("expected tag to be of memref type"); - numExpectedOperands += getTagMemRefRank(); + if (!op.getTagMemRef().getType().isa()) + return op.emitOpError("expected tag to be of memref type"); + numExpectedOperands += op.getTagMemRefRank(); if (numOperands < numExpectedOperands) - return emitOpError() << "expected at least " << numExpectedOperands - << " operands"; - if (!getTagIndices().empty() && - !llvm::all_of(getTagIndices().getTypes(), + return op.emitOpError() + << "expected at least " << numExpectedOperands << " operands"; + if (!op.getTagIndices().empty() && + !llvm::all_of(op.getTagIndices().getTypes(), [](Type t) { return t.isIndex(); })) - return emitOpError("expected tag indices to be of index type"); + return op.emitOpError("expected tag indices to be of index type"); // Optional stride-related operands must be either both present or both // absent. if (numOperands != numExpectedOperands && numOperands != numExpectedOperands + 2) - return emitOpError("incorrect number of operands"); + return op.emitOpError("incorrect number of operands"); // 5. Strides. 
- if (isStrided()) { - if (!getStride().getType().isIndex() || - !getNumElementsPerStride().getType().isIndex()) - return emitOpError( + if (op.isStrided()) { + if (!op.getStride().getType().isIndex() || + !op.getNumElementsPerStride().getType().isIndex()) + return op.emitOpError( "expected stride and num elements per stride to be of type index"); } @@ -1065,74 +1068,20 @@ // DmaWaitOp // --------------------------------------------------------------------------- -void DmaWaitOp::build(OpBuilder &builder, OperationState &result, - Value tagMemRef, ValueRange tagIndices, - Value numElements) { - result.addOperands(tagMemRef); - result.addOperands(tagIndices); - result.addOperands(numElements); -} - -void DmaWaitOp::print(OpAsmPrinter &p) { - p << " " << getTagMemRef() << '[' << getTagIndices() << "], " - << getNumElements(); - p.printOptionalAttrDict((*this)->getAttrs()); - p << " : " << getTagMemRef().getType(); -} - -// Parse DmaWaitOp. -// Eg: -// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> -// -ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType tagMemrefInfo; - SmallVector tagIndexInfos; - Type type; - auto indexType = parser.getBuilder().getIndexType(); - OpAsmParser::OperandType numElementsInfo; - - // Parse tag memref, its indices, and dma size. 
- if (parser.parseOperand(tagMemrefInfo) || - parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) || - parser.parseComma() || parser.parseOperand(numElementsInfo) || - parser.parseColonType(type) || - parser.resolveOperand(tagMemrefInfo, type, result.operands) || - parser.resolveOperands(tagIndexInfos, indexType, result.operands) || - parser.resolveOperand(numElementsInfo, indexType, result.operands)) - return failure(); - - return success(); -} - LogicalResult DmaWaitOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// dma_wait(memrefcast) -> dma_wait return foldMemRefCast(*this); } -LogicalResult DmaWaitOp::verify() { - // Mandatory non-variadic operands are tag and the number of elements. - if (getNumOperands() < 2) - return emitOpError() << "expected at least 2 operands"; - - // Check types of operands. The order of these calls is important: the later - // calls rely on some type properties to compute the operand position. - if (!getTagMemRef().getType().isa()) - return emitOpError() << "expected tag to be of memref type"; - - if (getNumOperands() != 2 + getTagMemRefRank()) - return emitOpError() << "expected " << 2 + getTagMemRefRank() - << " operands"; - - if (!getTagIndices().empty() && - !llvm::all_of(getTagIndices().getTypes(), - [](Type t) { return t.isIndex(); })) - return emitOpError() << "expected tag indices to be of index type"; - - if (!getNumElements().getType().isIndex()) - return emitOpError() - << "expected the number of elements to be of index type"; - +static LogicalResult verify(DmaWaitOp op) { + // ODS verifies operand types; check tag index count vs. tag memref rank. 
+ unsigned numTagIndices = op.tagIndices().size(); + unsigned tagMemRefRank = op.getTagMemRefRank(); + if (numTagIndices != tagMemRefRank) + return op.emitOpError() << "expected tagIndices to have the same number of " + "elements as the tagMemRef rank, expected " + << tagMemRefRank << ", but got " << numTagIndices; return success(); } diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -1,5 +1,132 @@ // RUN: mlir-opt -split-input-file %s -verify-diagnostics +func @dma_start_not_enough_operands() { + // expected-error@+1 {{expected at least 4 operands}} + "memref.dma_start"() : () -> () +} + +// ----- + +func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) { + // expected-error@+1 {{expected source to be of memref type}} + memref.dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32 +} + +// ----- + +func @dma_start_not_enough_operands_for_src( + %src: memref<2x2x2xf32>, %idx: index) { + // expected-error@+1 {{expected at least 7 operands}} + "memref.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> () +} + +// ----- + +func @dma_start_src_index_wrong_type( + %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, + %tag: memref, %flt: f32) { + // expected-error@+1 {{expected source indices to be of index type}} + "memref.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx) + : (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref, index) -> () +} + +// ----- + +func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) { + %mref = memref.alloc() : memref<8 x f32> + // expected-error@+1 {{expected destination to be of memref type}} + memref.dma_start %mref[%c0], %m[%c0], %c0, %tag[%c0] : memref<8 x f32>, f32, f32 +} + +// ----- + +func @dma_start_not_enough_operands_for_dst( + %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, + %tag: memref) { + // expected-error@+1 {{expected 
at least 7 operands}} + "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx) + : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> () +} + +// ----- + +func @dma_start_dst_index_wrong_type( + %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, + %tag: memref, %flt: f32) { + // expected-error@+1 {{expected destination indices to be of index type}} + "memref.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx) + : (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref, index) -> () +} + +// ----- + +func @dma_start_num_elements_wrong_type( + %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, + %tag: memref, %flt: f32) { + // expected-error@+1 {{expected num elements to be of index type}} + "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag) + : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref) -> () +} + +// ----- + +func @dma_no_tag_memref(%tag : f32, %c0 : index) { + %mref = memref.alloc() : memref<8 x f32> + // expected-error@+1 {{expected tag to be of memref type}} + memref.dma_start %mref[%c0], %mref[%c0], %c0, %tag[%c0] : memref<8 x f32>, memref<8 x f32>, f32 +} + +// ----- + +func @dma_start_not_enough_operands_for_tag( + %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, + %tag: memref<2xi32,2>) { + // expected-error@+1 {{expected at least 8 operands}} + "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag) + : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> () +} + +// ----- + +func @dma_start_tag_index_wrong_type( + %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, + %tag: memref<2xi32,2>, %flt: f32) { + // expected-error@+1 {{expected tag indices to be of index type}} + "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt) + : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> () +} + +// ----- + +func @dma_start_too_many_operands( + %src: memref<2x2xf32>, %idx: index, %dst: 
memref<2xf32,1>, + %tag: memref) { + // expected-error@+1 {{incorrect number of operands}} + "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx) + : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref, index, index, index) -> () +} + + +// ----- + +func @dma_start_wrong_stride_type( + %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, + %tag: memref, %flt: f32) { + // expected-error@+1 {{expected stride and num elements per stride to be of type index}} + "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt) + : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref, index, f32) -> () +} + +// ----- + +func @dma_wait_wrong_num_tag_indices(%tag : memref<2x2xi32>, %idx: index) { + // expected-error@+1 {{expected tagIndices to have the same number of elements as the tagMemRef rank, expected 2, but got 1}} + "memref.dma_wait"(%tag, %idx, %idx) : (memref<2x2xi32>, index, index) -> () + return +} + +// ----- + func @transpose_not_permutation(%v : memref(off + M * i + j)>>) { // expected-error @+1 {{expected a permutation map}} memref.transpose %v (i, j) -> (i, i) : memref(off + M * i + j)>> to memref(off + M * i + j)>> diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -290,153 +290,6 @@ // ----- -func @dma_start_not_enough_operands() { - // expected-error@+1 {{expected at least 4 operands}} - "memref.dma_start"() : () -> () -} - -// ----- - -func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) { - // expected-error@+1 {{expected source to be of memref type}} - memref.dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32 -} - -// ----- - -func @dma_start_not_enough_operands_for_src( - %src: memref<2x2x2xf32>, %idx: index) { - // expected-error@+1 {{expected at least 7 operands}} - "memref.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> () 
-} - -// ----- - -func @dma_start_src_index_wrong_type( - %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, - %tag: memref, %flt: f32) { - // expected-error@+1 {{expected source indices to be of index type}} - "memref.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx) - : (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref, index) -> () -} - -// ----- - -func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) { - %mref = memref.alloc() : memref<8 x f32> - // expected-error@+1 {{expected destination to be of memref type}} - memref.dma_start %mref[%c0], %m[%c0], %c0, %tag[%c0] : memref<8 x f32>, f32, f32 -} - -// ----- - -func @dma_start_not_enough_operands_for_dst( - %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, - %tag: memref) { - // expected-error@+1 {{expected at least 7 operands}} - "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx) - : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> () -} - -// ----- - -func @dma_start_dst_index_wrong_type( - %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, - %tag: memref, %flt: f32) { - // expected-error@+1 {{expected destination indices to be of index type}} - "memref.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx) - : (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref, index) -> () -} - -// ----- - -func @dma_start_dst_index_wrong_type( - %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, - %tag: memref, %flt: f32) { - // expected-error@+1 {{expected num elements to be of index type}} - "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag) - : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref) -> () -} - -// ----- - -func @dma_no_tag_memref(%tag : f32, %c0 : index) { - %mref = memref.alloc() : memref<8 x f32> - // expected-error@+1 {{expected tag to be of memref type}} - memref.dma_start %mref[%c0], %mref[%c0], %c0, %tag[%c0] : memref<8 x f32>, memref<8 x f32>, f32 -} - -// ----- - -func 
@dma_start_not_enough_operands_for_tag( - %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, - %tag: memref<2xi32,2>) { - // expected-error@+1 {{expected at least 8 operands}} - "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag) - : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> () -} - -// ----- - -func @dma_start_dst_index_wrong_type( - %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, - %tag: memref<2xi32,2>, %flt: f32) { - // expected-error@+1 {{expected tag indices to be of index type}} - "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt) - : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> () -} - -// ----- - -func @dma_start_too_many_operands( - %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, - %tag: memref) { - // expected-error@+1 {{incorrect number of operands}} - "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx) - : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref, index, index, index) -> () -} - - -// ----- - -func @dma_start_wrong_stride_type( - %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, - %tag: memref, %flt: f32) { - // expected-error@+1 {{expected stride and num elements per stride to be of type index}} - "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt) - : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref, index, f32) -> () -} - -// ----- - -func @dma_wait_not_enough_operands() { - // expected-error@+1 {{expected at least 2 operands}} - "memref.dma_wait"() : () -> () -} - -// ----- - -func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) { - // expected-error@+1 {{expected tag to be of memref type}} - "memref.dma_wait"(%tag, %c0, %c0) : (f32, index, index) -> () -} - -// ----- - -func @dma_wait_wrong_index_type(%tag : memref<2xi32>, %idx: index, %flt: f32) { - // expected-error@+1 {{expected tag indices to be of 
index type}} - "memref.dma_wait"(%tag, %flt, %idx) : (memref<2xi32>, f32, index) -> () -} - -// ----- - -func @dma_wait_wrong_num_elements_type(%tag : memref<2xi32>, %idx: index, %flt: f32) { - // expected-error@+1 {{expected the number of elements to be of index type}} - "memref.dma_wait"(%tag, %idx, %flt) : (memref<2xi32>, index, f32) -> () -} - -// ----- - func @invalid_cmp_attr(%idx : i32) { // expected-error@+1 {{expected string or keyword containing one of the following enum values}} %cmp = cmpi i1, %idx, %idx : i32