diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -286,6 +286,7 @@
   void print(OpAsmPrinter &p);
   LogicalResult fold(ArrayRef<Attribute> cstOperands,
                      SmallVectorImpl<OpFoldResult> &results);
+  LogicalResult verify();
 };
 
 /// Prints dimension and symbol list.
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1444,49 +1444,81 @@
       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
     return failure();
 
-  auto memrefType0 = types[0].dyn_cast<MemRefType>();
-  if (!memrefType0)
-    return parser.emitError(parser.getNameLoc(),
-                            "expected source to be of memref type");
-
-  auto memrefType1 = types[1].dyn_cast<MemRefType>();
-  if (!memrefType1)
-    return parser.emitError(parser.getNameLoc(),
-                            "expected destination to be of memref type");
-
-  auto memrefType2 = types[2].dyn_cast<MemRefType>();
-  if (!memrefType2)
-    return parser.emitError(parser.getNameLoc(),
-                            "expected tag to be of memref type");
-
   if (isStrided) {
     if (parser.resolveOperands(strideInfo, indexType, result.operands))
       return failure();
   }
 
-  // Check that source/destination index list size matches associated rank.
-  if (static_cast<int64_t>(srcIndexInfos.size()) != memrefType0.getRank() ||
-      static_cast<int64_t>(dstIndexInfos.size()) != memrefType1.getRank())
-    return parser.emitError(parser.getNameLoc(),
-                            "memref rank not equal to indices count");
-  if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType2.getRank())
-    return parser.emitError(parser.getNameLoc(),
-                            "tag memref rank not equal to indices count");
 
   return success();
 }
 
 LogicalResult DmaStartOp::verify() {
+  unsigned numOperands = getNumOperands();
+
+  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
+  // the number of elements.
+  if (numOperands < 4)
+    return emitOpError("expected at least 4 operands");
+
+  // Check types of operands. The order of these calls is important: the later
+  // calls rely on some type properties to compute the operand position.
+  // 1. Source memref.
+  if (!getSrcMemRef().getType().isa<MemRefType>())
+    return emitOpError("expected source to be of memref type");
+  if (numOperands < getSrcMemRefRank() + 4)
+    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
+                         << " operands";
+  if (!getSrcIndices().empty() &&
+      !llvm::all_of(getSrcIndices().getTypes(),
+                    [](Type t) { return t.isIndex(); }))
+    return emitOpError("expected source indices to be of index type");
+
+  // 2. Destination memref.
+  if (!getDstMemRef().getType().isa<MemRefType>())
+    return emitOpError("expected destination to be of memref type");
+  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
+  if (numOperands < numExpectedOperands)
+    return emitOpError() << "expected at least " << numExpectedOperands
+                         << " operands";
+  if (!getDstIndices().empty() &&
+      !llvm::all_of(getDstIndices().getTypes(),
+                    [](Type t) { return t.isIndex(); }))
+    return emitOpError("expected destination indices to be of index type");
+
+  // 3. Number of elements.
+  if (!getNumElements().getType().isIndex())
+    return emitOpError("expected num elements to be of index type");
+
+  // 4. Tag memref.
+  if (!getTagMemRef().getType().isa<MemRefType>())
+    return emitOpError("expected tag to be of memref type");
+  numExpectedOperands += getTagMemRefRank();
+  if (numOperands < numExpectedOperands)
+    return emitOpError() << "expected at least " << numExpectedOperands
+                         << " operands";
+  if (!getTagIndices().empty() &&
+      !llvm::all_of(getTagIndices().getTypes(),
+                    [](Type t) { return t.isIndex(); }))
+    return emitOpError("expected tag indices to be of index type");
+
   // DMAs from different memory spaces supported.
   if (getSrcMemorySpace() == getDstMemorySpace())
     return emitOpError("DMA should be between different memory spaces");
 
-  if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
-                              getDstMemRefRank() + 3 + 1 &&
-      getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
-                              getDstMemRefRank() + 3 + 1 + 2) {
+  // Stride-related operands are optional.
+  if (numOperands != numExpectedOperands &&
+      numOperands != numExpectedOperands + 2)
     return emitOpError("incorrect number of operands");
+
+  // 5. Strides.
+  if (isStrided()) {
+    if (!getStride().getType().isIndex() ||
+        !getNumElementsPerStride().getType().isIndex())
+      return emitOpError(
+          "expected stride and num elements per stride to be of type index");
   }
+
   return success();
 }
 
@@ -1536,15 +1568,6 @@
       parser.resolveOperand(numElementsInfo, indexType, result.operands))
     return failure();
 
-  auto memrefType = type.dyn_cast<MemRefType>();
-  if (!memrefType)
-    return parser.emitError(parser.getNameLoc(),
-                            "expected tag to be of memref type");
-
-  if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType.getRank())
-    return parser.emitError(parser.getNameLoc(),
-                            "tag memref rank not equal to indices count");
-
   return success();
 }
 
@@ -1554,6 +1577,32 @@
   return foldMemRefCast(*this);
 }
 
+LogicalResult DmaWaitOp::verify() {
+  // Mandatory non-variadic operands are tag and the number of elements.
+  if (getNumOperands() < 2)
+    return emitOpError() << "expected at least 2 operands";
+
+  // Check types of operands. The order of these calls is important: the later
+  // calls rely on some type properties to compute the operand position.
+  if (!getTagMemRef().getType().isa<MemRefType>())
+    return emitOpError() << "expected tag to be of memref type";
+
+  if (getNumOperands() != 2 + getTagMemRefRank())
+    return emitOpError() << "expected " << 2 + getTagMemRefRank()
+                         << " operands";
+
+  if (!getTagIndices().empty() &&
+      !llvm::all_of(getTagIndices().getTypes(),
+                    [](Type t) { return t.isIndex(); }))
+    return emitOpError() << "expected tag indices to be of index type";
+
+  if (!getNumElements().getType().isIndex())
+    return emitOpError()
+           << "expected the number of elements to be of index type";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -303,6 +303,13 @@
 
 // -----
 
+func @dma_start_not_enough_operands() {
+  // expected-error@+1 {{expected at least 4 operands}}
+  "std.dma_start"() : () -> ()
+}
+
+// -----
+
 func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
   // expected-error@+1 {{expected source to be of memref type}}
   dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32
@@ -310,6 +317,24 @@
 
 // -----
 
+func @dma_start_not_enough_operands_for_src(
+    %src: memref<2x2x2xf32>, %idx: index) {
+  // expected-error@+1 {{expected at least 7 operands}}
+  "std.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> ()
+}
+
+// -----
+
+func @dma_start_src_index_wrong_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>, %flt: f32) {
+  // expected-error@+1 {{expected source indices to be of index type}}
+  "std.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx)
+      : (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref<i32,2>, index) -> ()
+}
+
+// -----
+
 func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
   %mref = alloc() : memref<8 x f32>
   // expected-error@+1 {{expected destination to be of memref type}}
@@ -318,6 +343,36 @@
 
 // -----
 
+func @dma_start_not_enough_operands_for_dst(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>) {
+  // expected-error@+1 {{expected at least 7 operands}}
+  "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> ()
+}
+
+// -----
+
+func @dma_start_dst_index_wrong_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>, %flt: f32) {
+  // expected-error@+1 {{expected destination indices to be of index type}}
+  "std.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref<i32,2>, index) -> ()
+}
+
+// -----
+
+func @dma_start_num_elements_wrong_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>, %flt: f32) {
+  // expected-error@+1 {{expected num elements to be of index type}}
+  "std.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref<i32,2>) -> ()
+}
+
+// -----
+
 func @dma_no_tag_memref(%tag : f32, %c0 : index) {
   %mref = alloc() : memref<8 x f32>
   // expected-error@+1 {{expected tag to be of memref type}}
@@ -326,9 +381,79 @@
 
 // -----
 
+func @dma_start_not_enough_operands_for_tag(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<2xi32,2>) {
+  // expected-error@+1 {{expected at least 8 operands}}
+  "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> ()
+}
+
+// -----
+
+func @dma_start_tag_index_wrong_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<2xi32,2>, %flt: f32) {
+  // expected-error@+1 {{expected tag indices to be of index type}}
+  "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> ()
+}
+
+// -----
+
+func @dma_start_same_space(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32>,
+    %tag: memref<i32,2>) {
+  // expected-error@+1 {{DMA should be between different memory spaces}}
+  dma_start %src[%idx, %idx], %dst[%idx], %idx, %tag[] : memref<2x2xf32>, memref<2xf32>, memref<i32,2>
+}
+
+// -----
+
+func @dma_start_too_many_operands(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>) {
+  // expected-error@+1 {{incorrect number of operands}}
+  "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, index, index) -> ()
+}
+
+// -----
+
+func @dma_start_wrong_stride_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>, %flt: f32) {
+  // expected-error@+1 {{expected stride and num elements per stride to be of type index}}
+  "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, f32) -> ()
+}
+
+// -----
+
+func @dma_wait_not_enough_operands() {
+  // expected-error@+1 {{expected at least 2 operands}}
+  "std.dma_wait"() : () -> ()
+}
+
+// -----
+
 func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) {
   // expected-error@+1 {{expected tag to be of memref type}}
-  dma_wait %tag[%c0], %arg0 : f32
+  "std.dma_wait"(%tag, %c0, %c0) : (f32, index, index) -> ()
+}
+
+// -----
+
+func @dma_wait_wrong_index_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
+  // expected-error@+1 {{expected tag indices to be of index type}}
+  "std.dma_wait"(%tag, %flt, %idx) : (memref<2xi32>, f32, index) -> ()
+}
+
+// -----
+
+func @dma_wait_wrong_num_elements_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
+  // expected-error@+1 {{expected the number of elements to be of index type}}
+  "std.dma_wait"(%tag, %idx, %flt) : (memref<2xi32>, index, f32) -> ()
 }
 
 // -----