diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1147,6 +1147,26 @@
   Type i64Type = rewriter.getIntegerType(64);
   MemRefType memRefType = xferOp.getMemRefType();
 
+  if (auto memrefVectorElementType =
+          memRefType.getElementType().dyn_cast<VectorType>()) {
+    // Memref has vector element type.
+    if (memrefVectorElementType.getElementType() !=
+        xferOp.getVectorType().getElementType())
+      return failure();
+    // Check that memref vector type is a suffix of 'vectorType'.
+    unsigned memrefVecEltRank = memrefVectorElementType.getRank();
+    unsigned resultVecRank = xferOp.getVectorType().getRank();
+    assert(memrefVecEltRank <= resultVecRank);
+    // TODO: Move this to isSuffix in Vector/Utils.h.
+    unsigned rankOffset = resultVecRank - memrefVecEltRank;
+    auto memrefVecEltShape = memrefVectorElementType.getShape();
+    auto resultVecShape = xferOp.getVectorType().getShape();
+    for (unsigned i = 0; i < memrefVecEltRank; ++i)
+      assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
+             "memref vector element shape should match suffix of vector "
+             "result shape.");
+  }
+
   // 1. Get the source/dst address as an LLVM vector pointer.
   // The vector pointer would always be on address space 0, therefore
   // addrspacecast shall be used when source/dst memrefs are not on
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1500,37 +1500,33 @@
   if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
     // Memref has vector element type.
-    // Check that 'memrefVectorElementType' and vector element types match.
-    if (memrefVectorElementType.getElementType() != vectorType.getElementType())
+    unsigned memrefVecSize = memrefVectorElementType.getElementTypeBitWidth() *
+                             memrefVectorElementType.getShape().back();
+    unsigned resultVecSize =
+        vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
+    if (resultVecSize % memrefVecSize != 0)
       return op->emitOpError(
-          "requires memref and vector types of the same elemental type");
+          "requires the bitwidth of the minor 1-D vector to be an integral "
+          "multiple of the bitwidth of the minor 1-D vector of the memref");
 
-    // Check that memref vector type is a suffix of 'vectorType'.
     unsigned memrefVecEltRank = memrefVectorElementType.getRank();
     unsigned resultVecRank = vectorType.getRank();
     if (memrefVecEltRank > resultVecRank)
       return op->emitOpError(
           "requires memref vector element and vector result ranks to match.");
-    // TODO: Move this to isSuffix in Vector/Utils.h.
     unsigned rankOffset = resultVecRank - memrefVecEltRank;
-    auto memrefVecEltShape = memrefVectorElementType.getShape();
-    auto resultVecShape = vectorType.getShape();
-    for (unsigned i = 0; i < memrefVecEltRank; ++i)
-      if (memrefVecEltShape[i] != resultVecShape[rankOffset + i])
-        return op->emitOpError(
-            "requires memref vector element shape to match suffix of "
-            "vector result shape.");
     // Check that permutation map results match 'rankOffset' of vector type.
     if (permutationMap.getNumResults() != rankOffset)
       return op->emitOpError("requires a permutation_map with result dims of "
                              "the same rank as the vector type");
   } else {
     // Memref has scalar element type.
-
-    // Check that memref and vector element types match.
-    if (memrefType.getElementType() != vectorType.getElementType())
+    unsigned resultVecSize =
+        vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
+    if (resultVecSize % memrefElementType.getIntOrFloatBitWidth() != 0)
       return op->emitOpError(
-          "requires memref and vector types of the same elemental type");
+          "requires the bitwidth of the minor 1-D vector to be an integral "
+          "multiple of the bitwidth of the memref element type");
 
     // Check that permutation map results match rank of vector type.
     if (permutationMap.getNumResults() != vectorType.getRank())
@@ -1560,7 +1556,7 @@
                            VectorType vector, Value memref, ValueRange indices,
                            AffineMap permutationMap, ArrayRef<bool> maybeMasked) {
-  Type elemType = vector.cast<VectorType>().getElementType();
+  Type elemType = memref.getType().cast<MemRefType>().getElementType();
   Value padding = builder.create<ConstantOp>(result.location, elemType,
                                              builder.getZeroAttr(elemType));
   if (maybeMasked.empty())
@@ -1673,9 +1669,9 @@
       return op.emitOpError("requires valid padding vector elemental type");
 
     // Check that padding type and vector element types match.
-    if (paddingType != vectorType.getElementType())
+    if (paddingType != memrefElementType)
       return op.emitOpError(
-          "requires formal padding and vector of the same elemental type");
+          "requires formal padding and memref of the same elemental type");
   }
 
   return verifyPermutationMap(permutationMap,
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -906,6 +906,24 @@
 // 2. Rewrite as a load.
 // CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm.ptr<vec<17 x float>>
 
+func @transfer_read_1d_cast(%A : memref<?xi32>, %base: index) -> vector<12xi8> {
+  %c0 = constant 0: i32
+  %v = vector.transfer_read %A[%base], %c0 {masked = [false]} :
+    memref<?xi32>, vector<12xi8>
+  return %v: vector<12xi8>
+}
+// CHECK-LABEL: func @transfer_read_1d_cast
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm.vec<12 x i8>
+//
+// 1. Bitcast to vector form.
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm.ptr<i32>, !llvm.i64) -> !llvm.ptr<i32>
+// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
+// CHECK-SAME: !llvm.ptr<i32> to !llvm.ptr<vec<12 x i8>>
+//
+// 2. Rewrite as a load.
+// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm.ptr<vec<12 x i8>>
+
 func @genbool_1d() -> vector<8xi1> {
   %0 = vector.constant_mask [4] : vector<8xi1>
   return %0 : vector<8xi1>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -339,16 +339,6 @@
 
 // -----
 
-func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
-  %c3 = constant 3 : index
-  %f0 = constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
-  // expected-error@+1 {{requires memref and vector types of the same elemental type}}
-  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xi32>
-}
-
-// -----
-
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
   %c3 = constant 3 : index
   %f0 = constant 0.0 : f32
@@ -359,12 +349,12 @@
 
 // -----
 
-func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
+func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
   %c3 = constant 3 : index
   %f0 = constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
-  // expected-error@+1 {{ requires memref vector element shape to match suffix of vector result shape}}
-  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x2x3xf32>
+  %vf0 = splat %f0 : vector<6xf32>
+  // expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the memref}}
+  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -4,12 +4,15 @@
 
 // CHECK-LABEL: func @vector_transfer_ops(
 func @vector_transfer_ops(%arg0: memref<?x?xf32>,
-                          %arg1 : memref<?x?xvector<4x3xf32>>) {
+                          %arg1 : memref<?x?xvector<4x3xf32>>,
+                          %arg2 : memref<?x?xvector<4x3xi32>>) {
   // CHECK: %[[C3:.*]] = constant 3 : index
   %c3 = constant 3 : index
   %cst = constant 3.0 : f32
   %f0 = constant 0.0 : f32
+  %c0 = constant 0 : i32
   %vf0 = splat %f0 : vector<4x3xf32>
+  %v0 = splat %c0 : vector<4x3xi32>
 
   //
   // CHECK: vector.transfer_read
@@ -24,6 +27,9 @@
   %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
   // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
   %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+  // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
+  %6 = vector.transfer_read %arg2[%c3, %c3], %v0 : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
+
   // CHECK: vector.transfer_write
   vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
@@ -33,6 +39,8 @@
   vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
   // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
   vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+  // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
+  vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
 
   return
 }
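Worked example (illustration only, not part of the patch; the function name
@read_i32_as_i8 is hypothetical, mirroring the transfer_read_1d_cast test
above): the relaxed verifier no longer requires matching element types.
Instead, the bitwidth of the minor 1-D vector of the result must be an
integral multiple of the bitwidth of the memref element type (or of the
memref's minor 1-D vector, when the memref has vector elements). Here
12 x 8 = 96 bits is exactly 3 x 32 bits, so the op verifies and lowers to a
bitcast of the i32 pointer followed by a plain load:

  func @read_i32_as_i8(%A : memref<?xi32>, %base : index) -> vector<12xi8> {
    // Padding now matches the memref element type (i32), not the result
    // element type (i8).
    %c0 = constant 0 : i32
    // Minor 1-D bitwidth of the result: 12 * 8 = 96 bits, an integral
    // multiple of the 32-bit memref element, so this op verifies.
    %v = vector.transfer_read %A[%base], %c0 {masked = [false]}
        : memref<?xi32>, vector<12xi8>
    return %v : vector<12xi8>
  }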