diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1037,6 +1037,16 @@
                    "type($value) `,` type($mask) `into` type($ptrs)";
 }
 
+/// Create a call to Masked Expand Load intrinsic.
+def LLVM_masked_expandload
+    : LLVM_IntrOp<"masked.expandload", [0], [], [], 1>,
+      Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+/// Create a call to Masked Compress Store intrinsic.
+def LLVM_masked_compressstore
+    : LLVM_IntrOp<"masked.compressstore", [], [0], [], 0>,
+      Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
 //
 // Atomic operations.
 //
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1158,7 +1158,7 @@
                Variadic<VectorOfRank<[1]>>:$pass_thru)>,
     Results<(outs VectorOfRank<[1]>:$result)> {
 
-  let summary = "gathers elements from memory into a vector as defined by an index vector";
+  let summary = "gathers elements from memory into a vector as defined by an index vector and mask";
 
   let description = [{
     The gather operation gathers elements from memory into a 1-D vector as
@@ -1186,7 +1186,6 @@
     %g = vector.gather %base, %indices, %mask, %pass_thru
         : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
     ```
-
   }];
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
@@ -1217,7 +1216,7 @@
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$value)> {
 
-  let summary = "scatters elements from a vector into memory as defined by an index vector";
+  let summary = "scatters elements from a vector into memory as defined by an index vector and mask";
 
   let description = [{
     The scatter operation scatters elements from a 1-D vector into memory as
@@ -1265,6 +1264,108 @@
                        "type($indices) `,` type($mask) `,` type($value) `into` type($base)";
 }
 
+def Vector_ExpandLoadOp :
+  Vector_Op<"expandload">,
+    Arguments<(ins AnyMemRef:$base,
+               VectorOfRankAndType<[1], [I1]>:$mask,
+               VectorOfRank<[1]>:$pass_thru)>,
+    Results<(outs VectorOfRank<[1]>:$result)> {
+
+  let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
+
+  let description = [{
+    The expand load reads elements from memory into a 1-D vector as defined
+    by a base and a 1-D mask vector. When the mask is set, the next element
+    is read from memory. Otherwise, the corresponding element is taken from
+    a 1-D pass-through vector. Informally the semantics are:
+    ```
+    index = base
+    result[0] := mask[0] ? MEM[index++] : pass_thru[0]
+    result[1] := mask[1] ? MEM[index++] : pass_thru[1]
+    etc.
+    ```
+    Note that the index increment is done conditionally.
+
+    The expand load can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+    hardware ISA support for an expand. The semantics of the operation
+    closely correspond to those of the `llvm.masked.expandload`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
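+
+    Purely as an illustration of the informal semantics above (this sketch is
+    not part of the lowering and is not normative), an equivalent scalar loop
+    could be written as:
+    ```c++
+    #include <vector>
+
+    // Illustrative only: expand-load semantics over a scalar buffer. The
+    // memory index advances only for lanes whose mask bit is set; all other
+    // lanes keep the pass-through value.
+    std::vector<float> expandLoad(const float *base,
+                                  const std::vector<bool> &mask,
+                                  const std::vector<float> &passThru) {
+      std::vector<float> result(passThru);
+      size_t index = 0;
+      for (size_t i = 0, e = mask.size(); i < e; ++i)
+        if (mask[i])
+          result[i] = base[index++];
+      return result;
+    }
+    ```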
+
+    Example:
+    ```mlir
+    %0 = vector.expandload %base, %mask, %pass_thru
+        : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+    VectorType getMaskVectorType() {
+      return mask().getType().cast<VectorType>();
+    }
+    VectorType getPassThruVectorType() {
+      return pass_thru().getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() {
+      return result().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
+                       "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
+}
+
+def Vector_CompressStoreOp :
+  Vector_Op<"compressstore">,
+    Arguments<(ins AnyMemRef:$base,
+               VectorOfRankAndType<[1], [I1]>:$mask,
+               VectorOfRank<[1]>:$value)> {
+
+  let summary = "writes elements selectively from a vector as defined by a mask";
+
+  let description = [{
+    The compress store operation writes elements from a 1-D vector into memory
+    as defined by a base and a 1-D mask vector. When the mask is set, the
+    corresponding element from the vector is written to the next memory
+    location. Otherwise, no action is taken for the element. Informally the
+    semantics are:
+    ```
+    index = base
+    if (mask[0]) MEM[index++] = value[0]
+    if (mask[1]) MEM[index++] = value[1]
+    etc.
+    ```
+    Note that the index increment is done conditionally.
+
+    The compress store can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+    hardware ISA support for a compress. The semantics of the operation
+    closely correspond to those of the `llvm.masked.compressstore`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
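+
+    Purely as an illustration of the informal semantics above (this sketch is
+    not part of the lowering and is not normative), an equivalent scalar loop
+    could be written as:
+    ```c++
+    #include <vector>
+
+    // Illustrative only: compress-store semantics over a scalar buffer.
+    // Enabled lanes are written out contiguously from the base address, so
+    // the memory index advances only for set mask bits.
+    void compressStore(float *base, const std::vector<bool> &mask,
+                       const std::vector<float> &value) {
+      size_t index = 0;
+      for (size_t i = 0, e = mask.size(); i < e; ++i)
+        if (mask[i])
+          base[index++] = value[i];
+    }
+    ```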
+
+    Example:
+    ```mlir
+    vector.compressstore %base, %mask, %value
+      : memref<?xf32>, vector<8xi1>, vector<8xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+    VectorType getMaskVectorType() {
+      return mask().getType().cast<VectorType>();
+    }
+    VectorType getValueVectorType() {
+      return value().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
+                       "type($base) `,` type($mask) `,` type($value)";
+}
+
 def Vector_ShapeCastOp :
   Vector_Op<"shape_cast", [NoSideEffect]>,
     Arguments<(ins AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$source)>,
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir
@@ -0,0 +1,90 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @compress16(%base: memref<?xf32>,
+                 %mask: vector<16xi1>, %value: vector<16xf32>) {
+  vector.compressstore %base, %mask, %value
+    : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+func @printmem16(%A: memref<?xf32>) {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c16 = constant 16: index
+  %z = constant 0.0: f32
+  %m = vector.broadcast %z : f32 to vector<16xf32>
+  %mem = scf.for %i = %c0 to %c16 step %c1
+    iter_args(%m_iter = %m) -> (vector<16xf32>) {
+    %c = load %A[%i] : memref<?xf32>
+    %i32 = index_cast %i : index to i32
+    %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32>
+    scf.yield %m_new : vector<16xf32>
+  }
+  vector.print %mem : vector<16xf32>
+  return
+}
+
+func @entry() {
+  // Set up memory.
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c16 = constant 16: index
+  %A = alloc(%c16) : memref<?xf32>
+  %z = constant 0.0: f32
+  %v = vector.broadcast %z : f32 to vector<16xf32>
+  %value = scf.for %i = %c0 to %c16 step %c1
+    iter_args(%v_iter = %v) -> (vector<16xf32>) {
+    store %z, %A[%i] : memref<?xf32>
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32>
+    scf.yield %v_new : vector<16xf32>
+  }
+
+  // Set up masks.
+  %f = constant 0: i1
+  %t = constant 1: i1
+  %none = vector.constant_mask [0] : vector<16xi1>
+  %all = vector.constant_mask [16] : vector<16xi1>
+  %some1 = vector.constant_mask [4] : vector<16xi1>
+  %0 = vector.insert %f, %some1[0] : i1 into vector<16xi1>
+  %1 = vector.insert %t, %0[7] : i1 into vector<16xi1>
+  %2 = vector.insert %t, %1[11] : i1 into vector<16xi1>
+  %3 = vector.insert %t, %2[13] : i1 into vector<16xi1>
+  %some2 = vector.insert %t, %3[15] : i1 into vector<16xi1>
+  %some3 = vector.insert %f, %some2[2] : i1 into vector<16xi1>
+
+  //
+  // Compress store tests.
+  //
+
+  call @compress16(%A, %none, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+
+  call @compress16(%A, %all, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK-NEXT: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  call @compress16(%A, %some3, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK-NEXT: ( 1, 3, 7, 11, 13, 15, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  call @compress16(%A, %some2, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK-NEXT: ( 1, 2, 3, 7, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  call @compress16(%A, %some1, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  return
+}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @expand16(%base: memref<?xf32>,
+               %mask: vector<16xi1>,
+               %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %e = vector.expandload %base, %mask, %pass_thru
+    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %e : vector<16xf32>
+}
+
+func @entry() {
+  // Set up memory.
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c16 = constant 16: index
+  %A = alloc(%c16) : memref<?xf32>
+  scf.for %i = %c0 to %c16 step %c1 {
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    store %fi, %A[%i] : memref<?xf32>
+  }
+
+  // Set up pass thru vector.
+  %u = constant -7.0: f32
+  %v = constant 7.7: f32
+  %pass = vector.broadcast %u : f32 to vector<16xf32>
+
+  // Set up masks.
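+  // (For reference: %none and %all are the empty and full masks, %some1
+  //  enables lanes 0-3, and %some2/%some3 enable the scattered lane sets
+  //  {1,2,3,7,11,13,15} and {1,3,7,11,13,15}; the same patterns are built
+  //  in test-compress.mlir above.)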
+  %f = constant 0: i1
+  %t = constant 1: i1
+  %none = vector.constant_mask [0] : vector<16xi1>
+  %all = vector.constant_mask [16] : vector<16xi1>
+  %some1 = vector.constant_mask [4] : vector<16xi1>
+  %0 = vector.insert %f, %some1[0] : i1 into vector<16xi1>
+  %1 = vector.insert %t, %0[7] : i1 into vector<16xi1>
+  %2 = vector.insert %t, %1[11] : i1 into vector<16xi1>
+  %3 = vector.insert %t, %2[13] : i1 into vector<16xi1>
+  %some2 = vector.insert %t, %3[15] : i1 into vector<16xi1>
+  %some3 = vector.insert %f, %some2[2] : i1 into vector<16xi1>
+
+  //
+  // Expanding load tests.
+  //
+
+  %e1 = call @expand16(%A, %none, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e1 : vector<16xf32>
+  // CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
+
+  %e2 = call @expand16(%A, %all, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e2 : vector<16xf32>
+  // CHECK-NEXT: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  %e3 = call @expand16(%A, %some1, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e3 : vector<16xf32>
+  // CHECK-NEXT: ( 0, 1, 2, 3, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
+
+  %e4 = call @expand16(%A, %some2, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e4 : vector<16xf32>
+  // CHECK-NEXT: ( -7, 0, 1, 2, -7, -7, -7, 3, -7, -7, -7, 4, -7, 5, -7, 6 )
+
+  %e5 = call @expand16(%A, %some3, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e5 : vector<16xf32>
+  // CHECK-NEXT: ( -7, 0, -7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, -7, 5 )
+
+  %4 = vector.insert %v, %pass[1] : f32 into vector<16xf32>
+  %5 = vector.insert %v, %4[2] : f32 into vector<16xf32>
+  %alt_pass = vector.insert %v, %5[14] : f32 into vector<16xf32>
+  %e6 = call @expand16(%A, %some3, %alt_pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e6 : vector<16xf32>
+  // CHECK-NEXT: ( -7, 0, 7.7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, 7.7, 5 )
+
+  return
+}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
--- a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
@@ -11,34 +11,20 @@
   return
 }
 
-func @printmem(%A: memref<?xf32>) {
-  %f = constant 0.0: f32
-  %0 = vector.broadcast %f : f32 to vector<8xf32>
-  %1 = constant 0: index
-  %2 = load %A[%1] : memref<?xf32>
-  %3 = vector.insert %2, %0[0] : f32 into vector<8xf32>
-  %4 = constant 1: index
-  %5 = load %A[%4] : memref<?xf32>
-  %6 = vector.insert %5, %3[1] : f32 into vector<8xf32>
-  %7 = constant 2: index
-  %8 = load %A[%7] : memref<?xf32>
-  %9 = vector.insert %8, %6[2] : f32 into vector<8xf32>
-  %10 = constant 3: index
-  %11 = load %A[%10] : memref<?xf32>
-  %12 = vector.insert %11, %9[3] : f32 into vector<8xf32>
-  %13 = constant 4: index
-  %14 = load %A[%13] : memref<?xf32>
-  %15 = vector.insert %14, %12[4] : f32 into vector<8xf32>
-  %16 = constant 5: index
-  %17 = load %A[%16] : memref<?xf32>
-  %18 = vector.insert %17, %15[5] : f32 into vector<8xf32>
-  %19 = constant 6: index
-  %20 = load %A[%19] : memref<?xf32>
-  %21 = vector.insert %20, %18[6] : f32 into vector<8xf32>
-  %22 = constant 7: index
-  %23 = load %A[%22] : memref<?xf32>
-  %24 = vector.insert %23, %21[7] : f32 into vector<8xf32>
-  vector.print %24 : vector<8xf32>
+func @printmem8(%A: memref<?xf32>) {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c8 = constant 8: index
+  %z = constant 0.0: f32
+  %m = vector.broadcast %z : f32 to vector<8xf32>
+  %mem = scf.for %i = %c0 to %c8 step %c1
+    iter_args(%m_iter = %m) -> (vector<8xf32>) {
+    %c = load %A[%i] : memref<?xf32>
+    %i32 = index_cast %i : index to i32
+    %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<8xf32>
+    scf.yield %m_new : vector<8xf32>
+  }
+  vector.print %mem : vector<8xf32>
   return
 }
 
@@ -104,31 +90,27 @@
   vector.print %idx : vector<8xi32>
   // CHECK: ( 7, 0, 1, 6, 2, 4, 5, 3 )
 
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
 
   call @scatter8(%A, %idx, %none, %val)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
 
   call @scatter8(%A, %idx, %some, %val)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 1, 2, 2, 3, 4, 5, 3, 0 )
 
   call @scatter8(%A, %idx, %more, %val)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 1, 2, 2, 7, 4, 5, 3, 0 )
 
   call @scatter8(%A, %idx, %all, %val)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 1, 2, 4, 7, 5, 6, 3, 0 )
 
   return
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -131,11 +131,9 @@
   return success();
 }
 
-// Helper that returns vector of pointers given a base and an index vector.
-LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
-                             LLVMTypeConverter &typeConverter, Location loc,
-                             Value memref, Value indices, MemRefType memRefType,
-                             VectorType vType, Type iType, Value &ptrs) {
+// Helper that returns the base address of a memref.
+LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
+                      Value memref, MemRefType memRefType, Value &base) {
   // Inspect stride and offset structure.
   //
   // TODO: flat memory only for now, generalize
@@ -146,13 +144,31 @@
   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
       offset != 0 || memRefType.getMemorySpace() != 0)
     return failure();
+  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
+  return success();
+}
+
+// Helper that returns a pointer given a memref base.
+LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
+                         Value memref, MemRefType memRefType, Value &ptr) {
+  Value base;
+  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+    return failure();
+  auto pType = MemRefDescriptor(memref).getElementType();
+  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
+  return success();
+}
 
-  // Create a vector of pointers from base and indices.
-  MemRefDescriptor memRefDescriptor(memref);
-  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
-  int64_t size = vType.getDimSize(0);
-  auto pType = memRefDescriptor.getElementType();
-  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
+// Helper that returns vector of pointers given a memref base and an index
+// vector.
+LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
+                             Value memref, Value indices, MemRefType memRefType,
+                             VectorType vType, Type iType, Value &ptrs) {
+  Value base;
+  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+    return failure();
+  auto pType = MemRefDescriptor(memref).getElementType();
+  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
   return success();
 }
@@ -302,9 +318,8 @@
     VectorType vType = gather.getResultVectorType();
     Type iType = gather.getIndicesVectorType().getElementType();
     Value ptrs;
-    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
-                              adaptor.indices(), gather.getMemRefType(), vType,
-                              iType, ptrs)))
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
+                              gather.getMemRefType(), vType, iType, ptrs)))
       return failure();
 
     // Replace with the gather intrinsic.
@@ -341,9 +356,8 @@
     VectorType vType = scatter.getValueVectorType();
     Type iType = scatter.getIndicesVectorType().getElementType();
     Value ptrs;
-    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
-                              adaptor.indices(), scatter.getMemRefType(), vType,
-                              iType, ptrs)))
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
+                              scatter.getMemRefType(), vType, iType, ptrs)))
       return failure();
 
     // Replace with the scatter intrinsic.
@@ -354,6 +368,60 @@
   }
 };
 
+/// Conversion pattern for a vector.expandload.
+class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorExpandLoadOpConversion(MLIRContext *context,
+                                        LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto expand = cast<vector::ExpandLoadOp>(op);
+    auto adaptor = vector::ExpandLoadOpAdaptor(operands);
+
+    Value ptr;
+    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
+                          expand.getMemRefType(), ptr)))
+      return failure();
+
+    auto vType = expand.getResultVectorType();
+    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
+        op, typeConverter.convertType(vType), ptr, adaptor.mask(),
+        adaptor.pass_thru());
+    return success();
+  }
+};
+
+/// Conversion pattern for a vector.compressstore.
+class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorCompressStoreOpConversion(MLIRContext *context,
+                                           LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
+                             context, typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto compress = cast<vector::CompressStoreOp>(op);
+    auto adaptor = vector::CompressStoreOpAdaptor(operands);
+
+    Value ptr;
+    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
+                          compress.getMemRefType(), ptr)))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
+        op, adaptor.value(), ptr, adaptor.mask());
+    return success();
+  }
+};
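+
+// Note on the two patterns above: unlike gather/scatter, expandload and
+// compressstore take no index vector, so the lowering only needs the
+// memref's aligned base pointer (obtained via getBasePtr, which currently
+// accepts only flat, unit-stride, zero-offset memrefs); the mask and
+// pass-through/value operands are forwarded to the LLVM intrinsics
+// unchanged.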
+
 /// Conversion pattern for all vector reductions.
 class VectorReductionOpConversion : public ConvertToLLVMPattern {
 public:
@@ -1271,7 +1339,9 @@
                   VectorTransferConversion<TransferWriteOp>,
                   VectorTypeCastOpConversion,
                   VectorGatherOpConversion,
-                  VectorScatterOpConversion>(ctx, converter);
+                  VectorScatterOpConversion,
+                  VectorExpandLoadOpConversion,
+                  VectorCompressStoreOpConversion>(ctx, converter);
   // clang-format on
 }
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1901,6 +1901,41 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ExpandLoadOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ExpandLoadOp op) {
+  VectorType maskVType = op.getMaskVectorType();
+  VectorType passVType = op.getPassThruVectorType();
+  VectorType resVType = op.getResultVectorType();
+
+  if (resVType.getElementType() != op.getMemRefType().getElementType())
+    return op.emitOpError("base and result element type should match");
+
+  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
+    return op.emitOpError("expected result dim to match mask dim");
+  if (resVType != passVType)
+    return op.emitOpError("expected pass_thru of same type as result type");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CompressStoreOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(CompressStoreOp op) {
+  VectorType maskVType = op.getMaskVectorType();
+  VectorType valueVType = op.getValueVectorType();
+
+  if (valueVType.getElementType() != op.getMemRefType().getElementType())
+    return op.emitOpError("base and value element type should match");
+
+  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+    return op.emitOpError("expected value dim to match mask dim");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -989,3 +989,23 @@
 // CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm<"float*">, !llvm<"<3 x i32>">) -> !llvm<"<3 x float*>">
 // CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm<"<3 x float>">, !llvm<"<3 x i1>"> into !llvm<"<3 x float*>">
 // CHECK: llvm.return
+
+func @expand_load_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> {
+  %0 = vector.expandload %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
+  return %0 : vector<11xf32>
+}
+
+// CHECK-LABEL: func @expand_load_op
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm<"float*">) -> !llvm<"float*">
+// CHECK: %[[E:.*]] = "llvm.intr.masked.expandload"(%[[P]], %{{.*}}, %{{.*}}) : (!llvm<"float*">, !llvm<"<11 x i1>">, !llvm<"<11 x float>">) -> !llvm<"<11 x float>">
+// CHECK: llvm.return %[[E]] : !llvm<"<11 x float>">
+
+func @compress_store_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) {
+  vector.compressstore %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
+  return
+}
+
+// CHECK-LABEL: func @compress_store_op
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm<"float*">) -> !llvm<"float*">
+// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %[[P]], %{{.*}}) : (!llvm<"<11 x float>">, !llvm<"float*">, !llvm<"<11 x i1>">) -> ()
+// CHECK: llvm.return
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1240,3 +1240,38 @@
   // expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}}
   vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref<?xf32>
 }
+
+// -----
+
+func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  // expected-error@+1 {{'vector.expandload' op base and result element type should match}}
+  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
+  // expected-error@+1 {{'vector.expandload' op expected result dim to match mask dim}}
+  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<17xf32>) {
+  // expected-error@+1 {{'vector.expandload' op expected pass_thru of same type as result type}}
+  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<17xf32> into vector<16xf32>
+}
+
+// -----
+
+func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  // expected-error@+1 {{'vector.compressstore' op base and value element type should match}}
+  vector.compressstore %base, %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
+func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
+  // expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}}
+  vector.compressstore %base, %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -379,3 +379,12 @@
   vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
   return
 }
+
+// CHECK-LABEL: @expand_and_compress
+func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
+  // CHECK: %[[X:.*]] = vector.expandload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.expandload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  // CHECK: vector.compressstore %{{.*}}, %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  return
+}
diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir
--- a/mlir/test/Target/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/llvmir-intrinsics.mlir
@@ -192,8 +192,8 @@
   llvm.return
 }
 
-// CHECK-LABEL: @masked_intrinsics
-llvm.func @masked_intrinsics(%A: !llvm<"<7 x float>*">, %mask: !llvm<"<7 x i1>">) {
+// CHECK-LABEL: @masked_load_store_intrinsics
+llvm.func @masked_load_store_intrinsics(%A: !llvm<"<7 x float>*">, %mask: !llvm<"<7 x i1>">) {
   // CHECK: call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> undef)
   %a = llvm.intr.masked.load %A, %mask { alignment = 1: i32} : (!llvm<"<7 x float>*">, !llvm<"<7 x i1>">) -> !llvm<"<7 x float>">
@@ -220,6 +220,17 @@
   llvm.return
 }
 
+// CHECK-LABEL: @masked_expand_compress_intrinsics
+llvm.func @masked_expand_compress_intrinsics(%ptr: !llvm<"float*">, %mask: !llvm<"<7 x i1>">, %passthru: !llvm<"<7 x float>">) {
+  // CHECK: call <7 x float> @llvm.masked.expandload.v7f32(float* %{{.*}}, <7 x i1> %{{.*}}, <7 x float> %{{.*}})
+  %0 = "llvm.intr.masked.expandload"(%ptr, %mask, %passthru)
+    : (!llvm<"float*">, !llvm<"<7 x i1>">, !llvm<"<7 x float>">) -> (!llvm<"<7 x float>">)
+  // CHECK: call void @llvm.masked.compressstore.v7f32(<7 x float> %{{.*}}, float* %{{.*}}, <7 x i1> %{{.*}})
+  "llvm.intr.masked.compressstore"(%0, %ptr, %mask)
+    : (!llvm<"<7 x float>">, !llvm<"float*">, !llvm<"<7 x i1>">) -> ()
+  llvm.return
+}
+
 // CHECK-LABEL: @memcpy_test
 llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm<"i8*">, %arg3: !llvm<"i8*">) {
   // CHECK: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %{{.*}}, i8* %{{.*}}, i32 %{{.*}}, i1 %{{.*}})