diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1244,7 +1244,7 @@

     ```mlir
     vector.scatter %base, %indices, %mask, %value
-        : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<f32>
+        : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
     ```
   }];
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -147,28 +147,13 @@
       offset != 0 || memRefType.getMemorySpace() != 0)
     return failure();

-  // Base pointer.
+  // Create a vector of pointers from base and indices.
   MemRefDescriptor memRefDescriptor(memref);
   Value base = memRefDescriptor.alignedPtr(rewriter, loc);
-
-  // Create a vector of pointers from base and indices.
-  //
-  // TODO: this step serializes the address computations unfortunately,
-  //       ideally we would like to add splat(base) + index_vector
-  //       in SIMD form, but this does not match well with current
-  //       constraints of the standard and vector dialect....
-  //
   int64_t size = vType.getDimSize(0);
   auto pType = memRefDescriptor.getElementType();
   auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
-  auto idxType = typeConverter.convertType(iType);
-  ptrs = rewriter.create<LLVM::UndefOp>(loc, ptrsType);
-  for (int64_t i = 0; i < size; i++) {
-    Value off =
-        extractOne(rewriter, typeConverter, loc, indices, idxType, 1, i);
-    Value ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base, off);
-    ptrs = insertOne(rewriter, typeConverter, loc, ptrs, ptr, ptrsType, 1, i);
-  }
+  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
   return success();
 }

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -976,7 +976,8 @@
 }

 // CHECK-LABEL: func @gather_op
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm<"<3 x float*>">, !llvm<"<3 x i1>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+// CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm<"float*">, !llvm<"<3 x i32>">) -> !llvm<"<3 x float*>">
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm<"<3 x float*>">, !llvm<"<3 x i1>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
 // CHECK: llvm.return %[[G]] : !llvm<"<3 x float>">

 func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
@@ -985,5 +986,6 @@
 }

 // CHECK-LABEL: func @scatter_op
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : !llvm<"<3 x float>">, !llvm<"<3 x i1>"> into !llvm<"<3 x float*>">
+// CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm<"float*">, !llvm<"<3 x i32>">) -> !llvm<"<3 x float*>">
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm<"<3 x float>">, !llvm<"<3 x i1>"> into !llvm<"<3 x float*>">
 // CHECK: llvm.return