Index: mlir/lib/Dialect/Vector/VectorOps.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorOps.cpp
+++ mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1503,26 +1503,34 @@
   if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
     // Memref has vector element type.
 
-    // Check that 'memrefVectorElementType' and vector element types match.
-    if (memrefVectorElementType.getElementType() != vectorType.getElementType())
-      return op->emitOpError(
-          "requires memref and vector types of the same elemental type");
+    // The data may be reinterpreted with a different element type, but only
+    // in whole units: the minor 1-D vector of 'vectorType' must span an
+    // integral number of minor 1-D vectors of the memref element type.
+    Type memrefScalarType = memrefVectorElementType.getElementType();
+    Type resultScalarType = vectorType.getElementType();
+    if (memrefScalarType.isIntOrFloat() && resultScalarType.isIntOrFloat()) {
+      int64_t memrefVecSize = memrefScalarType.getIntOrFloatBitWidth() *
+                              memrefVectorElementType.getShape().back();
+      int64_t resultVecSize = resultScalarType.getIntOrFloatBitWidth() *
+                              vectorType.getShape().back();
+      if (resultVecSize % memrefVecSize != 0)
+        return op->emitOpError(
+            "requires the bitwidth of the minor 1-D vector to be an integral "
+            "multiple of the bitwidth of the minor 1-D vector of the memref");
+    } else if (memrefScalarType != resultScalarType) {
+      // Types without a fixed bitwidth (e.g. index) cannot be reinterpreted.
+      return op->emitOpError(
+          "requires memref and vector types of the same elemental type");
+    }
 
-    // Check that memref vector type is a suffix of 'vectorType.
+    // Check that the memref vector element is not of higher rank than
+    // 'vectorType'.
     unsigned memrefVecEltRank = memrefVectorElementType.getRank();
     unsigned resultVecRank = vectorType.getRank();
     if (memrefVecEltRank > resultVecRank)
       return op->emitOpError(
           "requires memref vector element and vector result ranks to match.");
-    // TODO: Move this to isSuffix in Vector/Utils.h.
     unsigned rankOffset = resultVecRank - memrefVecEltRank;
-    auto memrefVecEltShape = memrefVectorElementType.getShape();
-    auto resultVecShape = vectorType.getShape();
-    for (unsigned i = 0; i < memrefVecEltRank; ++i)
-      if (memrefVecEltShape[i] != resultVecShape[rankOffset + i])
-        return op->emitOpError(
-            "requires memref vector element shape to match suffix of "
-            "vector result shape.");
     // Check that permutation map results match 'rankOffset' of vector type.
     if (permutationMap.getNumResults() != rankOffset)
       return op->emitOpError("requires a permutation_map with result dims of "
@@ -1530,11 +1538,23 @@
   } else {
     // Memref has scalar element type.
 
-    // Check that memref and vector element types match.
-    if (memrefType.getElementType() != vectorType.getElementType())
-      return op->emitOpError(
-          "requires memref and vector types of the same elemental type");
+    // The data may be reinterpreted with a different element type, but only
+    // in whole units: the minor 1-D vector of 'vectorType' must span an
+    // integral number of memref elements.
+    Type resultScalarType = vectorType.getElementType();
+    if (memrefElementType.isIntOrFloat() && resultScalarType.isIntOrFloat()) {
+      int64_t resultVecSize = resultScalarType.getIntOrFloatBitWidth() *
+                              vectorType.getShape().back();
+      if (resultVecSize % memrefElementType.getIntOrFloatBitWidth() != 0)
+        return op->emitOpError(
+            "requires the bitwidth of the minor 1-D vector to be an integral "
+            "multiple of the bitwidth of the memref element type");
+    } else if (memrefElementType != resultScalarType) {
+      // Types without a fixed bitwidth (e.g. index) cannot be reinterpreted.
+      return op->emitOpError(
+          "requires memref and vector types of the same elemental type");
+    }
 
     // Check that permutation map results match rank of vector type.
     if (permutationMap.getNumResults() != vectorType.getRank())
       return op->emitOpError("requires a permutation_map with result dims of "
@@ -1563,7 +1583,8 @@
                            VectorType vector, Value memref,
                            ValueRange indices, AffineMap permutationMap,
                            ArrayRef<bool> maybeMasked) {
-  Type elemType = vector.cast<VectorType>().getElementType();
+  // Padding uses the memref element type (itself a vector type for memrefs of
+  // vectors), which may differ from the element type of 'vector'.
+  Type elemType = memref.getType().cast<MemRefType>().getElementType();
   Value padding = builder.create<ConstantOp>(result.location, elemType,
                                              builder.getZeroAttr(elemType));
   if (maybeMasked.empty())
Index: mlir/test/Dialect/Vector/invalid.mlir
===================================================================
--- mlir/test/Dialect/Vector/invalid.mlir
+++ mlir/test/Dialect/Vector/invalid.mlir
@@ -343,32 +343,32 @@
   %c3 = constant 3 : index
   %f0 = constant 0.0 : f32
   %vf0 = splat %f0 : vector<4x3xf32>
-  // expected-error@+1 {{requires memref and vector types of the same elemental type}}
-  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xi32>
+  // expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the memref}}
+  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x2xi32>
 }
 
 // -----
 
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
   %c3 = constant 3 : index
   %f0 = constant 0.0 : f32
   %vf0 = splat %f0 : vector<4x3xf32>
   // expected-error@+1 {{requires memref vector element and vector result ranks to match}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
 }
 
 // -----
 
-func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
+func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
   %c3 = constant 3 : index
-  %f0 = constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
-  // expected-error@+1 {{ requires memref vector element shape to match suffix of vector result shape}}
-  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x2x3xf32>
+  %i0 = constant 0 : i8
+  %vi0 = splat %i0 : vector<3xi8>
+  // expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the memref element type}}
+  vector.transfer_write %vi0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<3xi8>, memref<?x?xf32>
 }
 
 // -----
 
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
   %c3 = constant 3 : index
   %f0 = constant 0.0 : f32
Index: mlir/test/Dialect/Vector/ops.mlir
===================================================================
--- mlir/test/Dialect/Vector/ops.mlir
+++ mlir/test/Dialect/Vector/ops.mlir
@@ -4,12 +4,15 @@
 
 // CHECK-LABEL: func @vector_transfer_ops(
 func @vector_transfer_ops(%arg0: memref<?x?xf32>,
-                          %arg1 : memref<?x?xvector<4x3xf32>>) {
+                          %arg1 : memref<?x?xvector<4x3xf32>>,
+                          %arg2 : memref<?x?xvector<4x3xi32>>) {
   // CHECK: %[[C3:.*]] = constant 3 : index
   %c3 = constant 3 : index
   %cst = constant 3.0 : f32
   %f0 = constant 0.0 : f32
+  %c0 = constant 0 : i32
   %vf0 = splat %f0 : vector<4x3xf32>
+  %v0 = splat %c0 : vector<4x3xi32>
 
   //
   // CHECK: vector.transfer_read
@@ -24,6 +27,9 @@
   %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
   // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
   %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+  // Element type reinterpretation: the minor 24xi8 (192 bits) spans two 3xi32.
+  // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
+  %6 = vector.transfer_read %arg2[%c3, %c3], %v0 : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
 
   // CHECK: vector.transfer_write
   vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
@@ -33,6 +39,8 @@
   vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
   // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
   vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+  // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
+  vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
 
   return
 }