Summary (free text ahead of the first "diff --git" header):

  VectorOps.cpp: hunks 1, 2, 5 and 6 add a verifier check that the number of
  indices equals the memref rank, reporting "requires <rank> indices"
  otherwise. Hunks 2 and 6 are verify(MaskedStoreOp) and
  verify(CompressStoreOp); hunks 1 and 5 are presumably the maskedload and
  expandload verifiers, judging by the new tests -- confirm against the file.
  Hunks 3 and 4 only hoist op.getMemRefType() into a local `memType`.

  invalid.mlir: adds one negative test per new check (maskedload,
  maskedstore, expandload, compressstore).

  NOTE(review): line breaks and indentation were reconstructed from the hunk
  line counts and LLVM two-space style. The `memref` types in the .mlir hunks
  carry no shape or element type, which looks like extraction loss -- confirm
  against the original patch before applying.

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2365,10 +2365,12 @@
   VectorType maskVType = op.getMaskVectorType();
   VectorType passVType = op.getPassThruVectorType();
   VectorType resVType = op.getResultVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (resVType.getElementType() != op.getMemRefType().getElementType())
+  if (resVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and result element type should match");
-
+  if (llvm::size(op.indices()) != memType.getRank())
+    return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected result dim to match mask dim");
   if (resVType != passVType)
@@ -2410,10 +2412,12 @@
 static LogicalResult verify(MaskedStoreOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType valueVType = op.getValueVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (valueVType.getElementType() != op.getMemRefType().getElementType())
+  if (valueVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and value element type should match");
-
+  if (llvm::size(op.indices()) != memType.getRank())
+    return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected value dim to match mask dim");
   return success();
@@ -2454,10 +2458,10 @@
   VectorType indicesVType = op.getIndicesVectorType();
   VectorType maskVType = op.getMaskVectorType();
   VectorType resVType = op.getResultVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (resVType.getElementType() != op.getMemRefType().getElementType())
+  if (resVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and result element type should match");
-
   if (resVType.getDimSize(0) != indicesVType.getDimSize(0))
     return op.emitOpError("expected result dim to match indices dim");
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
@@ -2500,10 +2504,10 @@
   VectorType indicesVType = op.getIndicesVectorType();
   VectorType maskVType = op.getMaskVectorType();
   VectorType valueVType = op.getValueVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (valueVType.getElementType() != op.getMemRefType().getElementType())
+  if (valueVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and value element type should match");
-
   if (valueVType.getDimSize(0) != indicesVType.getDimSize(0))
     return op.emitOpError("expected value dim to match indices dim");
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
@@ -2544,10 +2548,12 @@
   VectorType maskVType = op.getMaskVectorType();
   VectorType passVType = op.getPassThruVectorType();
   VectorType resVType = op.getResultVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (resVType.getElementType() != op.getMemRefType().getElementType())
+  if (resVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and result element type should match");
-
+  if (llvm::size(op.indices()) != memType.getRank())
+    return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected result dim to match mask dim");
   if (resVType != passVType)
@@ -2589,10 +2595,12 @@
 static LogicalResult verify(CompressStoreOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType valueVType = op.getValueVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (valueVType.getElementType() != op.getMemRefType().getElementType())
+  if (valueVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and value element type should match");
-
+  if (llvm::size(op.indices()) != memType.getRank())
+    return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected value dim to match mask dim");
   return success();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1222,6 +1222,13 @@
 
 // -----
 
+func @maskedload_memref_mismatch(%base: memref, %mask: vector<16xi1>, %pass: vector<16xf32>) {
+  // expected-error@+1 {{'vector.maskedload' op requires 1 indices}}
+  %0 = vector.maskedload %base[], %mask, %pass : memref, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func @maskedstore_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
   // expected-error@+1 {{'vector.maskedstore' op base and value element type should match}}
@@ -1238,6 +1245,14 @@
 
 // -----
 
+func @maskedstore_memref_mismatch(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
+  // expected-error@+1 {{'vector.maskedstore' op requires 1 indices}}
+  vector.maskedstore %base[%c0, %c0], %mask, %value : memref, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
 func @gather_base_type_mismatch(%base: memref, %indices: vector<16xi32>,
                                 %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   // expected-error@+1 {{'vector.gather' op base and result element type should match}}
@@ -1343,6 +1358,14 @@
 
 // -----
 
+func @expand_memref_mismatch(%base: memref, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = constant 0 : index
+  // expected-error@+1 {{'vector.expandload' op requires 2 indices}}
+  %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func @compress_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
   // expected-error@+1 {{'vector.compressstore' op base and value element type should match}}
@@ -1359,6 +1382,14 @@
 
 // -----
 
+func @compress_memref_mismatch(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
+  // expected-error@+1 {{'vector.compressstore' op requires 2 indices}}
+  vector.compressstore %base[%c0, %c0, %c0], %mask, %value : memref, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
 func @extract_map_rank(%v: vector<32xf32>, %id : index) {
   // expected-error@+1 {{'vector.extract_map' op expected source and destination vectors of same rank}}
   %0 = vector.extract_map %v[%id] : vector<32xf32> to vector<2x1xf32>