Index: mlir/lib/Dialect/Vector/VectorOps.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorOps.cpp +++ mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1503,26 +1503,13 @@ if (auto memrefVectorElementType = memrefElementType.dyn_cast()) { // Memref has vector element type. - // Check that 'memrefVectorElementType' and vector element types match. - if (memrefVectorElementType.getElementType() != vectorType.getElementType()) - return op->emitOpError( - "requires memref and vector types of the same elemental type"); - // Check that memref vector type is a suffix of 'vectorType. unsigned memrefVecEltRank = memrefVectorElementType.getRank(); unsigned resultVecRank = vectorType.getRank(); if (memrefVecEltRank > resultVecRank) return op->emitOpError( "requires memref vector element and vector result ranks to match."); - // TODO: Move this to isSuffix in Vector/Utils.h. unsigned rankOffset = resultVecRank - memrefVecEltRank; - auto memrefVecEltShape = memrefVectorElementType.getShape(); - auto resultVecShape = vectorType.getShape(); - for (unsigned i = 0; i < memrefVecEltRank; ++i) - if (memrefVecEltShape[i] != resultVecShape[rankOffset + i]) - return op->emitOpError( - "requires memref vector element shape to match suffix of " - "vector result shape."); // Check that permutation map results match 'rankOffset' of vector type. if (permutationMap.getNumResults() != rankOffset) return op->emitOpError("requires a permutation_map with result dims of " @@ -1530,11 +1517,6 @@ } else { // Memref has scalar element type. - // Check that memref and vector element types match. - if (memrefType.getElementType() != vectorType.getElementType()) - return op->emitOpError( - "requires memref and vector types of the same elemental type"); - // Check that permutation map results match rank of vector type. if (permutationMap.getNumResults() != vectorType.getRank()) return op->emitOpError("requires a permutation_map with result dims of " @@ -1563,7 +1545,7 @@ VectorType vector, Value memref, ValueRange indices, AffineMap permutationMap, ArrayRef maybeMasked) { - Type elemType = vector.cast().getElementType(); + Type elemType = memref.getType().cast().getElementType(); Value padding = builder.create(result.location, elemType, builder.getZeroAttr(elemType)); if (maybeMasked.empty()) @@ -1676,9 +1658,9 @@ return op.emitOpError("requires valid padding vector elemental type"); // Check that padding type and vector element types match. - if (paddingType != vectorType.getElementType()) + if (paddingType != memrefElementType) return op.emitOpError( - "requires formal padding and vector of the same elemental type"); + "requires formal padding and memref of the same elemental type"); } return verifyPermutationMap(permutationMap, Index: mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir =================================================================== --- mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -937,6 +937,25 @@ // 2. Rewrite as a load. // CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm<"<17 x float>*"> +func @transfer_read_1d_cast(%A : memref, %base: index) -> vector<17xi8> { + %c0 = constant 0: i32 + %v = vector.transfer_read %A[%base], %c0 {masked = [false]} : + memref, vector<17xi8> + return %v: vector<17xi8> +} +// CHECK-LABEL: func @transfer_read_1d_cast +// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm<"<17 x i8>"> +// +// 1. Bitcast to vector form. +// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} : +// CHECK-SAME: (!llvm<"i32*">, !llvm.i64) -> !llvm<"i32*"> +// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] : +// CHECK-SAME: !llvm<"i32*"> to !llvm<"<17 x i8>*"> +// +// 2. Rewrite as a load. +// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm<"<17 x i8>*"> + + func @genbool_1d() -> vector<8xi1> { %0 = vector.constant_mask [4] : vector<8xi1> return %0 : vector<8xi1> Index: mlir/test/Dialect/Vector/invalid.mlir =================================================================== --- mlir/test/Dialect/Vector/invalid.mlir +++ mlir/test/Dialect/Vector/invalid.mlir @@ -343,32 +343,12 @@ %c3 = constant 3 : index %f0 = constant 0.0 : f32 %vf0 = splat %f0 : vector<4x3xf32> - // expected-error@+1 {{requires memref and vector types of the same elemental type}} - %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref>, vector<1x1x4x3xi32> -} - -// ----- - -func @test_vector.transfer_read(%arg0: memref>) { - %c3 = constant 3 : index - %f0 = constant 0.0 : f32 - %vf0 = splat %f0 : vector<4x3xf32> // expected-error@+1 {{requires memref vector element and vector result ranks to match}} %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref>, vector<3xf32> } // ----- -func @test_vector.transfer_read(%arg0: memref>) { - %c3 = constant 3 : index - %f0 = constant 0.0 : f32 - %vf0 = splat %f0 : vector<4x3xf32> - // expected-error@+1 {{ requires memref vector element shape to match suffix of vector result shape}} - %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref>, vector<1x1x2x3xf32> -} - -// ----- - func @test_vector.transfer_read(%arg0: memref>) { %c3 = constant 3 : index %f0 = constant 0.0 : f32 Index: mlir/test/Dialect/Vector/ops.mlir =================================================================== --- mlir/test/Dialect/Vector/ops.mlir +++ mlir/test/Dialect/Vector/ops.mlir @@ -4,12 +4,15 @@ // CHECK-LABEL: func @vector_transfer_ops( func @vector_transfer_ops(%arg0: memref, - %arg1 : memref>) { + %arg1 : memref>, + %arg2 : memref>) { // CHECK: %[[C3:.*]] = constant 3 : index %c3 = constant 3 : index %cst = constant 3.0 : f32 %f0 = constant 0.0 : f32 + %c0 = constant 0 : i32 %vf0 = splat %f0 : vector<4x3xf32> + %v0 = splat %c0 : vector<4x3xi32> // // CHECK: vector.transfer_read @@ -24,6 +27,9 @@ %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref>, vector<1x1x4x3xf32> // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : memref>, vector<1x1x4x3xf32> %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : memref>, vector<1x1x4x3xf32> + // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref>, vector<5x4xi8> + %6 = vector.transfer_read %arg2[%c3, %c3], %v0 : memref>, vector<5x4xi8> + // CHECK: vector.transfer_write vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref @@ -33,6 +39,8 @@ vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, memref> // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref> vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, memref> + // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x4xi8>, memref> + vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x4xi8>, memref> return }