diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -991,6 +991,35 @@
   let assemblyFormat = "$value `,` $data `,` $mask attr-dict `:` "
     "type($value) `,` type($mask) `into` type($data)";
 }
+
+/// Create a call to Masked Gather intrinsic.
+def LLVM_masked_gather
+    : LLVM_OneResultOp<"intr.masked.gather">,
+      Arguments<(ins LLVM_Type:$ptrs, LLVM_Type:$mask,
+                 Variadic<LLVM_Type>:$pass_thru, I32Attr:$alignment)> {
+  string llvmBuilder = [{
+    $res = $pass_thru.empty() ? builder.CreateMaskedGather(
+        $ptrs, llvm::Align($alignment.getZExtValue()), $mask) :
+      builder.CreateMaskedGather(
+        $ptrs, llvm::Align($alignment.getZExtValue()), $mask, $pass_thru[0]);
+  }];
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+}
+
+/// Create a call to Masked Scatter intrinsic.
+def LLVM_masked_scatter
+    : LLVM_ZeroResultOp<"intr.masked.scatter">,
+      Arguments<(ins LLVM_Type:$value, LLVM_Type:$ptrs, LLVM_Type:$mask,
+                 I32Attr:$alignment)> {
+  string llvmBuilder = [{
+    builder.CreateMaskedScatter(
+        $value, $ptrs, llvm::Align($alignment.getZExtValue()), $mask);
+  }];
+  let assemblyFormat = "$value `,` $ptrs `,` $mask attr-dict `:` "
+    "type($value) `,` type($mask) `into` type($ptrs)";
+}
+
 //
 // Atomic operations.
 //
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1150,6 +1150,121 @@
   let hasFolder = 1;
 }
+
+def Vector_GatherOp :
+    Vector_Op<"gather">,
+    Arguments<(ins AnyMemRef:$base,
+               VectorOfRankAndType<[1], [AnyInteger]>:$indices,
+               VectorOfRankAndType<[1], [I1]>:$mask,
+               Variadic<VectorOfRank<[1]>>:$pass_thru)>,
+    Results<(outs VectorOfRank<[1]>:$result)> {
+
+  let summary = "gathers elements from memory into a vector as defined by an index vector";
+
+  let description = [{
+    The gather operation gathers elements from memory into a 1-D vector as
+    defined by a base and a 1-D index vector, but only if the corresponding
+    bit is set in a 1-D mask vector. Otherwise, the element is taken from a
+    1-D pass-through vector, if provided, or left undefined. Informally the
+    semantics are:
+    ```
+    if (!defined(pass_thru)) pass_thru = [undef, .., undef]
+    result[0] := mask[0] ? MEM[base + index[0]] : pass_thru[0]
+    result[1] := mask[1] ? MEM[base + index[1]] : pass_thru[1]
+    etc.
+    ```
+    The vector dialect leaves out-of-bounds behavior undefined.
+
+    The gather operation can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+    hardware ISA support for a gather. The semantics of the operation closely
+    correspond to those of the `llvm.masked.gather`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-gather-intrinsics).
+
+    Example:
+
+    ```mlir
+    %g = vector.gather %base, %indices, %mask, %pass_thru
+        : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+    ```
+
+  }];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+    VectorType getIndicesVectorType() {
+      return indices().getType().cast<VectorType>();
+    }
+    VectorType getMaskVectorType() {
+      return mask().getType().cast<VectorType>();
+    }
+    VectorType getPassThruVectorType() {
+      return (llvm::size(pass_thru()) == 0)
+          ? VectorType()
+          : (*pass_thru().begin()).getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() {
+      return result().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
+}
+
+def Vector_ScatterOp :
+    Vector_Op<"scatter">,
+    Arguments<(ins AnyMemRef:$base,
+               VectorOfRankAndType<[1], [AnyInteger]>:$indices,
+               VectorOfRankAndType<[1], [I1]>:$mask,
+               VectorOfRank<[1]>:$value)> {
+
+  let summary = "scatters elements from a vector into memory as defined by an index vector";
+
+  let description = [{
+    The scatter operation scatters elements from a 1-D vector into memory as
+    defined by a base and a 1-D index vector, but only if the corresponding
+    bit in a 1-D mask vector is set. Otherwise, no action is taken for that
+    element. Informally the semantics are:
+    ```
+    if (mask[0]) MEM[base + index[0]] = value[0]
+    if (mask[1]) MEM[base + index[1]] = value[1]
+    etc.
+    ```
+    The vector dialect leaves out-of-bounds and repeated index behavior
+    undefined. Underlying implementations may enforce strict sequential
+    semantics for the latter, though.
+    TODO: enforce the latter always?
+
+    The scatter operation can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+    hardware ISA support for a scatter. The semantics of the operation closely
+    correspond to those of the `llvm.masked.scatter`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics).
+
+    Example:
+
+    ```mlir
+    vector.scatter %base, %indices, %mask, %value
+        : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+    VectorType getIndicesVectorType() {
+      return indices().getType().cast<VectorType>();
+    }
+    VectorType getMaskVectorType() {
+      return mask().getType().cast<VectorType>();
+    }
+    VectorType getValueVectorType() {
+      return value().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$base `,` $indices `,` $mask `,` $value attr-dict `:` "
+    "type($indices) `,` type($mask) `,` type($value) `into` type($base)";
+}
+
 def Vector_ShapeCastOp :
   Vector_Op<"shape_cast", [NoSideEffect]>,
     Arguments<(ins AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$source)>,
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir
@@ -0,0 +1,97 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @gather8(%base: memref<?xf32>,
+              %indices: vector<8xi32>, %mask: vector<8xi1>) -> vector<8xf32> {
+  %g = vector.gather %base, %indices, %mask
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>) -> vector<8xf32>
+  return %g : vector<8xf32>
+}
+
+func @gather_pass_thru8(%base: memref<?xf32>,
+                        %indices: vector<8xi32>, %mask: vector<8xi1>,
+                        %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %g = vector.gather %base, %indices, %mask, %pass_thru
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> vector<8xf32>
+  return %g : vector<8xf32>
+}
+
+func @entry() {
+  // Set up memory.
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c10 = constant 10: index
+  %A = alloc(%c10) : memref<?xf32>
+  scf.for %i = %c0 to %c10 step %c1 {
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    store %fi, %A[%i] : memref<?xf32>
+  }
+
+  // Set up idx vector.
+  %i0 = constant 0: i32
+  %i1 = constant 1: i32
+  %i2 = constant 2: i32
+  %i3 = constant 3: i32
+  %i4 = constant 4: i32
+  %i5 = constant 5: i32
+  %i6 = constant 6: i32
+  %i9 = constant 9: i32
+  %0 = vector.broadcast %i0 : i32 to vector<8xi32>
+  %1 = vector.insert %i6, %0[1] : i32 into vector<8xi32>
+  %2 = vector.insert %i1, %1[2] : i32 into vector<8xi32>
+  %3 = vector.insert %i3, %2[3] : i32 into vector<8xi32>
+  %4 = vector.insert %i5, %3[4] : i32 into vector<8xi32>
+  %5 = vector.insert %i4, %4[5] : i32 into vector<8xi32>
+  %6 = vector.insert %i9, %5[6] : i32 into vector<8xi32>
+  %idx = vector.insert %i2, %6[7] : i32 into vector<8xi32>
+
+  // Set up pass thru vector.
+  %u = constant -7.0: f32
+  %pass = vector.broadcast %u : f32 to vector<8xf32>
+
+  // Set up masks.
+  %t = constant 1: i1
+  %none = vector.constant_mask [0] : vector<8xi1>
+  %all = vector.constant_mask [8] : vector<8xi1>
+  %some = vector.constant_mask [4] : vector<8xi1>
+  %more = vector.insert %t, %some[7] : i1 into vector<8xi1>
+
+  //
+  // Gather tests.
+  //
+
+  %g1 = call @gather8(%A, %idx, %all)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>)
+    -> (vector<8xf32>)
+  vector.print %g1 : vector<8xf32>
+  // CHECK: ( 0, 6, 1, 3, 5, 4, 9, 2 )
+
+  %g2 = call @gather_pass_thru8(%A, %idx, %none, %pass)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+    -> (vector<8xf32>)
+  vector.print %g2 : vector<8xf32>
+  // CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7 )
+
+  %g3 = call @gather_pass_thru8(%A, %idx, %some, %pass)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+    -> (vector<8xf32>)
+  vector.print %g3 : vector<8xf32>
+  // CHECK: ( 0, 6, 1, 3, -7, -7, -7, -7 )
+
+  %g4 = call @gather_pass_thru8(%A, %idx, %more, %pass)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+    -> (vector<8xf32>)
+  vector.print %g4 : vector<8xf32>
+  // CHECK: ( 0, 6, 1, 3, -7, -7, -7, 2 )
+
+  %g5 = call @gather_pass_thru8(%A, %idx, %all, %pass)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+    -> (vector<8xf32>)
+  vector.print %g5 : vector<8xf32>
+  // CHECK: ( 0, 6, 1, 3, 5, 4, 9, 2 )
+
+  return
+}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
@@ -0,0 +1,135 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @scatter8(%base: memref<?xf32>,
+               %indices: vector<8xi32>,
+               %mask: vector<8xi1>, %value: vector<8xf32>) {
+  vector.scatter %base, %indices, %mask, %value
+    : vector<8xi32>, vector<8xi1>, vector<8xf32> into memref<?xf32>
+  return
+}
+
+func @printmem(%A: memref<?xf32>) {
+  %f = constant 0.0: f32
+  %0 = vector.broadcast %f : f32 to vector<8xf32>
+  %1 = constant 0: index
+  %2 = load %A[%1] : memref<?xf32>
+  %3 = vector.insert %2, %0[0] : f32 into vector<8xf32>
+  %4 = constant 1: index
+  %5 = load %A[%4] : memref<?xf32>
+  %6 = vector.insert %5, %3[1] : f32 into vector<8xf32>
+  %7 = constant 2: index
+  %8 = load %A[%7] : memref<?xf32>
+  %9 = vector.insert %8, %6[2] : f32 into vector<8xf32>
+  %10 = constant 3: index
+  %11 = load %A[%10] : memref<?xf32>
+  %12 = vector.insert %11, %9[3] : f32 into vector<8xf32>
+  %13 = constant 4: index
+  %14 = load %A[%13] : memref<?xf32>
+  %15 = vector.insert %14, %12[4] : f32 into vector<8xf32>
+  %16 = constant 5: index
+  %17 = load %A[%16] : memref<?xf32>
+  %18 = vector.insert %17, %15[5] : f32 into vector<8xf32>
+  %19 = constant 6: index
+  %20 = load %A[%19] : memref<?xf32>
+  %21 = vector.insert %20, %18[6] : f32 into vector<8xf32>
+  %22 = constant 7: index
+  %23 = load %A[%22] : memref<?xf32>
+  %24 = vector.insert %23, %21[7] : f32 into vector<8xf32>
+  vector.print %24 : vector<8xf32>
+  return
+}
+
+func @entry() {
+  // Set up memory.
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c8 = constant 8: index
+  %A = alloc(%c8) : memref<?xf32>
+  scf.for %i = %c0 to %c8 step %c1 {
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    store %fi, %A[%i] : memref<?xf32>
+  }
+
+  // Set up idx vector.
+  %i0 = constant 0: i32
+  %i1 = constant 1: i32
+  %i2 = constant 2: i32
+  %i3 = constant 3: i32
+  %i4 = constant 4: i32
+  %i5 = constant 5: i32
+  %i6 = constant 6: i32
+  %i7 = constant 7: i32
+  %0 = vector.broadcast %i7 : i32 to vector<8xi32>
+  %1 = vector.insert %i0, %0[1] : i32 into vector<8xi32>
+  %2 = vector.insert %i1, %1[2] : i32 into vector<8xi32>
+  %3 = vector.insert %i6, %2[3] : i32 into vector<8xi32>
+  %4 = vector.insert %i2, %3[4] : i32 into vector<8xi32>
+  %5 = vector.insert %i4, %4[5] : i32 into vector<8xi32>
+  %6 = vector.insert %i5, %5[6] : i32 into vector<8xi32>
+  %idx = vector.insert %i3, %6[7] : i32 into vector<8xi32>
+
+  // Set up value vector.
+  %f0 = constant 0.0: f32
+  %f1 = constant 1.0: f32
+  %f2 = constant 2.0: f32
+  %f3 = constant 3.0: f32
+  %f4 = constant 4.0: f32
+  %f5 = constant 5.0: f32
+  %f6 = constant 6.0: f32
+  %f7 = constant 7.0: f32
+  %7 = vector.broadcast %f0 : f32 to vector<8xf32>
+  %8 = vector.insert %f1, %7[1] : f32 into vector<8xf32>
+  %9 = vector.insert %f2, %8[2] : f32 into vector<8xf32>
+  %10 = vector.insert %f3, %9[3] : f32 into vector<8xf32>
+  %11 = vector.insert %f4, %10[4] : f32 into vector<8xf32>
+  %12 = vector.insert %f5, %11[5] : f32 into vector<8xf32>
+  %13 = vector.insert %f6, %12[6] : f32 into vector<8xf32>
+  %val = vector.insert %f7, %13[7] : f32 into vector<8xf32>
+
+  // Set up masks.
+  %t = constant 1: i1
+  %none = vector.constant_mask [0] : vector<8xi1>
+  %some = vector.constant_mask [4] : vector<8xi1>
+  %more = vector.insert %t, %some[7] : i1 into vector<8xi1>
+  %all = vector.constant_mask [8] : vector<8xi1>
+
+  //
+  // Scatter tests.
+  //
+
+  vector.print %idx : vector<8xi32>
+  // CHECK: ( 7, 0, 1, 6, 2, 4, 5, 3 )
+
+  call @printmem(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
+
+  call @scatter8(%A, %idx, %none, %val)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
+
+  call @printmem(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
+
+  call @scatter8(%A, %idx, %some, %val)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
+
+  call @printmem(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 1, 2, 2, 3, 4, 5, 3, 0 )
+
+  call @scatter8(%A, %idx, %more, %val)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
+
+  call @printmem(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 1, 2, 2, 7, 4, 5, 3, 0 )
+
+  call @scatter8(%A, %idx, %all, %val)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
+
+  call @printmem(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 1, 2, 4, 7, 5, 6, 3, 0 )
+
+  return
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -34,14 +34,6 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-template <typename T>
-static LLVM::LLVMType getPtrToElementType(T containerType,
-                                          LLVMTypeConverter &typeConverter) {
-  return typeConverter.convertType(containerType.getElementType())
-      .template cast<LLVM::LLVMType>()
-      .getPointerTo();
-}
-
 // Helper to reduce vector type by one rank at front.
 static VectorType reducedVectorTypeFront(VectorType tp) {
   assert((tp.getRank() > 1) && "unlowerable vector type");
@@ -124,11 +116,12 @@
   return res;
 }
 
-template <typename TransferOp>
-LogicalResult getVectorTransferAlignment(LLVMTypeConverter &typeConverter,
-                                         TransferOp xferOp, unsigned &align) {
+// Helper that returns data layout alignment of an operation with memref.
+template <typename T>
+LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
+                                 unsigned &align) {
   Type elementTy =
-      typeConverter.convertType(xferOp.getMemRefType().getElementType());
+      typeConverter.convertType(op.getMemRefType().getElementType());
   if (!elementTy)
     return failure();
 
@@ -138,13 +131,54 @@
   return success();
 }
 
+// Helper that returns vector of pointers given a base and an index vector.
+LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
+                             LLVMTypeConverter &typeConverter, Location loc,
+                             Value memref, Value indices, MemRefType memRefType,
+                             VectorType vType, Type iType, Value &ptrs) {
+  // Inspect stride and offset structure.
+  //
+  // TODO: flat memory only for now, generalize
+  //
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
+  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
+      offset != 0 || memRefType.getMemorySpace() != 0)
+    return failure();
+
+  // Base pointer.
+  MemRefDescriptor memRefDescriptor(memref);
+  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
+
+  // Create a vector of pointers from base and indices.
+  //
+  // TODO: this step serializes the address computations unfortunately,
+  //       ideally we would like to add splat(base) + index_vector
+  //       in SIMD form, but this does not match well with current
+  //       constraints of the standard and vector dialect....
+  //
+  int64_t size = vType.getDimSize(0);
+  auto pType = memRefDescriptor.getElementType();
+  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
+  auto idxType = typeConverter.convertType(iType);
+  ptrs = rewriter.create<LLVM::UndefOp>(loc, ptrsType);
+  for (int64_t i = 0; i < size; i++) {
+    Value off =
+        extractOne(rewriter, typeConverter, loc, indices, idxType, 1, i);
+    Value ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base, off);
+    ptrs = insertOne(rewriter, typeConverter, loc, ptrs, ptr, ptrsType, 1, i);
+  }
+  return success();
+}
+
 static LogicalResult
 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                  LLVMTypeConverter &typeConverter, Location loc,
                                  TransferReadOp xferOp, ArrayRef<Value> operands,
                                  Value dataPtr) {
   unsigned align;
-  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
     return failure();
   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
   return success();
 }
@@ -165,7 +199,7 @@
     return failure();
 
   unsigned align;
-  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
     return failure();
 
   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
@@ -180,7 +214,7 @@
                                  TransferWriteOp xferOp, ArrayRef<Value> operands,
                                  Value dataPtr) {
   unsigned align;
-  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
     return failure();
   auto adaptor = TransferWriteOpAdaptor(operands);
   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
@@ -194,7 +228,7 @@
                                  TransferWriteOp xferOp, ArrayRef<Value> operands,
                                  Value dataPtr, Value mask) {
   unsigned align;
-  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
     return failure();
 
   auto adaptor = TransferWriteOpAdaptor(operands);
@@ -259,6 +293,83 @@
   }
 };
 
+/// Conversion pattern for a vector.gather.
+class VectorGatherOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorGatherOpConversion(MLIRContext *context,
+                                    LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto gather = cast<vector::GatherOp>(op);
+    auto adaptor = vector::GatherOpAdaptor(operands);
+
+    // Resolve alignment.
+    unsigned align;
+    if (failed(getMemRefAlignment(typeConverter, gather, align)))
+      return failure();
+
+    // Get index ptrs.
+    VectorType vType = gather.getResultVectorType();
+    Type iType = gather.getIndicesVectorType().getElementType();
+    Value ptrs;
+    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
+                              adaptor.indices(), gather.getMemRefType(), vType,
+                              iType, ptrs)))
+      return failure();
+
+    // Replace with the gather intrinsic.
+    ValueRange v = (llvm::size(adaptor.pass_thru()) == 0) ? ValueRange({})
+                                                          : adaptor.pass_thru();
+    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+        gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), v,
+        rewriter.getI32IntegerAttr(align));
+    return success();
+  }
+};
+
+/// Conversion pattern for a vector.scatter.
+class VectorScatterOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorScatterOpConversion(MLIRContext *context,
+                                     LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto scatter = cast<vector::ScatterOp>(op);
+    auto adaptor = vector::ScatterOpAdaptor(operands);
+
+    // Resolve alignment.
+    unsigned align;
+    if (failed(getMemRefAlignment(typeConverter, scatter, align)))
+      return failure();
+
+    // Get index ptrs.
+    VectorType vType = scatter.getValueVectorType();
+    Type iType = scatter.getIndicesVectorType().getElementType();
+    Value ptrs;
+    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
+                              adaptor.indices(), scatter.getMemRefType(), vType,
+                              iType, ptrs)))
+      return failure();
+
+    // Replace with the scatter intrinsic.
+    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
+        scatter, adaptor.value(), ptrs, adaptor.mask(),
+        rewriter.getI32IntegerAttr(align));
+    return success();
+  }
+};
+
+/// Conversion pattern for all vector reductions.
 class VectorReductionOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorReductionOpConversion(MLIRContext *context,
@@ -1173,7 +1284,9 @@
                   VectorPrintOpConversion,
                   VectorTransferConversion<TransferReadOp>,
                   VectorTransferConversion<TransferWriteOp>,
-                  VectorTypeCastOpConversion>(ctx, converter);
+                  VectorTypeCastOpConversion,
+                  VectorGatherOpConversion,
+                  VectorScatterOpConversion>(ctx, converter);
   // clang-format on
 }
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1858,6 +1858,49 @@
   return SmallVector<int64_t, 4>{s.begin(), s.end()};
 }
 
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(GatherOp op) {
+  VectorType indicesVType = op.getIndicesVectorType();
+  VectorType maskVType = op.getMaskVectorType();
+  VectorType resVType = op.getResultVectorType();
+
+  if (resVType.getElementType() != op.getMemRefType().getElementType())
+    return op.emitOpError("base and result element type should match");
+
+  if (resVType.getDimSize(0) != indicesVType.getDimSize(0))
+    return op.emitOpError("expected result dim to match indices dim");
+  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
+    return op.emitOpError("expected result dim to match mask dim");
+  if (llvm::size(op.pass_thru()) != 0) {
+    VectorType passVType = op.getPassThruVectorType();
+    if (resVType != passVType)
+      return op.emitOpError("expected pass_thru of same type as result type");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ScatterOp op) {
+  VectorType indicesVType = op.getIndicesVectorType();
+  VectorType maskVType = op.getMaskVectorType();
+  VectorType valueVType = op.getValueVectorType();
+
+  if (valueVType.getElementType() != op.getMemRefType().getElementType())
+    return op.emitOpError("base and value element type should match");
+
+  if (valueVType.getDimSize(0) != indicesVType.getDimSize(0))
+    return op.emitOpError("expected value dim to match indices dim");
+  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+    return op.emitOpError("expected value dim to match mask dim");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -969,3 +969,21 @@
 // CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
 // CHECK-SAME: !llvm<"<16 x float>"> into !llvm<"<16 x float>">
 // CHECK: llvm.return %[[T]] : !llvm<"<16 x float>">
+
+func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+  %0 = vector.gather %arg0, %arg1, %arg2, %arg3 : (memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+  return %0 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @gather_op
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm<"<3 x float*>">, !llvm<"<3 x i1>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+// CHECK: llvm.return %[[G]] : !llvm<"<3 x float>">
+
+func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
+  vector.scatter %arg0, %arg1, %arg2, %arg3 : vector<3xi32>, vector<3xi1>, vector<3xf32> into memref<?xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_op
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : !llvm<"<3 x float>">, !llvm<"<3 x i1>"> into !llvm<"<3 x float*>">
+// CHECK: llvm.return
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1177,3 +1177,66 @@
   // expected-error@+1 {{expects operand to be a memref with no layout}}
   %0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref<vector<4x3xf32>>
 }
+
+// -----
+
+func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+  // expected-error@+1 {{'vector.gather' op base and result element type should match}}
+  %0 = vector.gather %base, %indices, %mask : (memref<?xf64>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
+}
+
+// -----
+
+func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+  // expected-error@+1 {{'vector.gather' op result #0 must be of ranks 1, but got 'vector<2x16xf32>'}}
+  %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<2x16xf32>
+}
+
+// -----
+
+func @gather_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>, %mask: vector<16xi1>) {
+  // expected-error@+1 {{'vector.gather' op expected result dim to match indices dim}}
+  %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<17xi32>, vector<16xi1>) -> vector<16xf32>
+}
+
+// -----
+
+func @gather_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<17xi1>) {
+  // expected-error@+1 {{'vector.gather' op expected result dim to match mask dim}}
+  %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<17xi1>) -> vector<16xf32>
+}
+
+// -----
+
+func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
+  // expected-error@+1 {{'vector.gather' op expected pass_thru of same type as result type}}
+  %0 = vector.gather %base, %indices, %mask, %pass_thru : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf64>) -> vector<16xf32>
+}
+
+// -----
+
+func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  // expected-error@+1 {{'vector.scatter' op base and value element type should match}}
+  vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf64>
+}
+
+// -----
+
+func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<2x16xf32>) {
+  // expected-error@+1 {{'vector.scatter' op operand #3 must be of ranks 1, but got 'vector<2x16xf32>'}}
+  vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<2x16xf32> into memref<?xf32>
+}
+
+// -----
+
+func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  // expected-error@+1 {{'vector.scatter' op expected value dim to match indices dim}}
+  vector.scatter %base, %indices, %mask, %value : vector<17xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+}
+
+// -----
+
+func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
+  // expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}}
+  vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref<?xf32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -368,3 +368,14 @@
 // CHECK: return %[[X]] : vector<16xi32>
   return %0 : vector<16xi32>
 }
+
+// CHECK-LABEL: @gather_and_scatter
+func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+  // CHECK: %[[X:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}} : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
+  %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
+  // CHECK: %[[Y:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}}, %[[X]] : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+  %1 = vector.gather %base, %indices, %mask, %0 : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+  // CHECK: vector.scatter %{{.*}}, %{{.*}}, %{{.*}}, %[[Y]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+  vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+  return
+}
diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir
--- a/mlir/test/Target/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/llvmir-intrinsics.mlir
@@ -206,6 +206,20 @@
   llvm.return
 }
 
+// CHECK-LABEL: @masked_gather_scatter_intrinsics
+llvm.func @masked_gather_scatter_intrinsics(%M: !llvm<"<7 x float*>">, %mask: !llvm<"<7 x i1>">) {
+  // CHECK: call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> undef)
+  %a = llvm.intr.masked.gather %M, %mask { alignment = 1: i32} :
+      (!llvm<"<7 x float*>">, !llvm<"<7 x i1>">) -> !llvm<"<7 x float>">
+  // CHECK: call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> %{{.*}})
+  %b = llvm.intr.masked.gather %M, %mask, %a { alignment = 1: i32} :
+      (!llvm<"<7 x float*>">, !llvm<"<7 x i1>">, !llvm<"<7 x float>">) -> !llvm<"<7 x float>">
+  // CHECK: call void @llvm.masked.scatter.v7f32.v7p0f32(<7 x float> %{{.*}}, <7 x float*> %{{.*}}, i32 1, <7 x i1> %{{.*}})
+  llvm.intr.masked.scatter %b, %M, %mask { alignment = 1: i32} :
+      !llvm<"<7 x float>">, !llvm<"<7 x i1>"> into !llvm<"<7 x float*>">
+  llvm.return
+}
+
 // CHECK-LABEL: @memcpy_test
 llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm<"i8*">, %arg3: !llvm<"i8*">) {
   // CHECK: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %{{.*}}, i8* %{{.*}}, i32 %{{.*}}, i1 %{{.*}})
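
For reference, a minimal sketch (not part of the patch) of how the two new ops compose into an indirect read-modify-write; the function name `@sparse_increment` and the constant are illustrative only, and the syntax assumes the op forms introduced above:

```mlir
// Gather f32 values at the given indices, add 1.0, and scatter the results
// back to the same locations (assumes the index vector has no repeated entries).
func @sparse_increment(%base: memref<?xf32>, %indices: vector<16xi32>) {
  %mask = vector.constant_mask [16] : vector<16xi1>
  %c1 = constant 1.0 : f32
  %ones = vector.broadcast %c1 : f32 to vector<16xf32>
  %x = vector.gather %base, %indices, %mask, %ones
      : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
  %y = addf %x, %ones : vector<16xf32>
  vector.scatter %base, %indices, %mask, %y
      : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
  return
}
```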