diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1760,10 +1760,10 @@
   Vector_Op<"gather">,
     Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$pass_thru)>,
-    Results<(outs VectorOfRank<[1]>:$result)> {
+               VectorOf<[AnyInteger, Index]>:$index_vec,
+               VectorOf<[I1]>:$mask,
+               AnyVector:$pass_thru)>,
+    Results<(outs AnyVector:$result)> {
 
   let summary = [{
     gathers elements from memory or ranked tensor into a vector as defined by an
@@ -1772,10 +1772,10 @@
 
   let description = [{
     The gather operation gathers elements from memory or ranked tensor into a
-    1-D vector as defined by a base with indices and an additional 1-D index
-    vector, but only if the corresponding bit is set in a 1-D mask vector.
-    Otherwise, the element is taken from a 1-D pass-through vector. Informally
-    the semantics are:
+    n-D vector as defined by a base with indices and an additional n-D index
+    vector (each index is a 1-D offset on the base), but only if the
+    corresponding bit is set in an n-D mask vector. Otherwise, the element is
+    taken from an n-D pass-through vector. Informally the semantics are:
     ```
     result[0] := mask[0] ? base[index[0]] : pass_thru[0]
     result[1] := mask[1] ? base[index[1]] : pass_thru[1]
@@ -1785,15 +1785,13 @@
 
     The gather operation can be used directly where applicable, or can be used
     during progressive lowering to bring other memory operations closer to
-    hardware ISA support for a gather. The semantics of the operation closely
-    correspond to those of the `llvm.masked.gather`
-    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-gather-intrinsics).
+    hardware ISA support for a gather.
 
     Examples:
 
     ```mlir
     %0 = vector.gather %base[%c0][%v], %mask, %pass_thru
-        : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+        : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
 
     %1 = vector.gather %base[%i, %j][%v], %mask, %pass_thru
         : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -82,14 +82,6 @@
     std::function<Value(Type, ValueRange)> createOperand,
     ConversionPatternRewriter &rewriter) {
   auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
-
-  SmallVector<Type> operand1DVectorTypes;
-  for (Value operand : op->getOperands()) {
-    auto operandNDVectorType = operand.getType().cast<VectorType>();
-    auto operandTypeInfo =
-        extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
-    operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
-  }
   auto resultTypeInfo =
       extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
   auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
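`LLVM::detail::handleMultidimensionalVectors` is the utility the lowering below relies on: the LLVM type converter represents an n-D vector as `!llvm.array<... x vector<...>>`, and the helper extracts each innermost 1-D vector, invokes the callback on it, and inserts the 1-D result back into the n-D aggregate. Roughly, one unrolled iteration over a `2x3` gather produces IR of the following shape (a sketch with illustrative value names, not from the patch; the exact pattern appears in the tests below):

```mlir
%i0 = llvm.extractvalue %idx[0] : !llvm.array<2 x vector<3xi32>>
%m0 = llvm.extractvalue %msk[0] : !llvm.array<2 x vector<3xi1>>
%s0 = llvm.extractvalue %pss[0] : !llvm.array<2 x vector<3xf32>>
// The 1-D op in the middle is created by the callback; %p0 is the vector of
// pointers built by getIndexedPtrs from the extracted index vector.
%g0 = llvm.intr.masked.gather %p0, %m0, %s0 {alignment = 4 : i32}
    : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
%r0 = llvm.insertvalue %g0, %acc[0] : !llvm.array<2 x vector<3xf32>>
```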
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -91,24 +91,28 @@
   return success();
 }
 
-// Add an index vector component to a base pointer. This almost always succeeds
-// unless the last stride is non-unit or the memory space is not zero.
-static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
-                                    Location loc, Value memref, Value base,
-                                    Value index, MemRefType memRefType,
-                                    VectorType vType, Value &ptrs) {
+// Check that the memref type is supported: the last stride must be unit and
+// the memory space must be zero.
+static LogicalResult isMemRefTypeSupported(MemRefType memRefType) {
   int64_t offset;
   SmallVector<int64_t> strides;
   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
   if (failed(successStrides) || strides.back() != 1 ||
       memRefType.getMemorySpaceAsInt() != 0)
     return failure();
-  auto pType = MemRefDescriptor(memref).getElementPtrType();
-  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
-  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
   return success();
 }
 
+// Add an index vector component to a base pointer.
+static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
+                            MemRefType memRefType, Value llvmMemref, Value base,
+                            Value index, uint64_t vLen) {
+  assert(succeeded(isMemRefTypeSupported(memRefType)) &&
+         "unsupported memref type");
+  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
+  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
+  return rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
+}
+
 // Casts a strided element pointer to a vector pointer. The vector pointer
 // will be in the same address space as the incoming memref type.
 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
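`getIndexedPtrs` relies on the vector-index form of GEP: a single `llvm.getelementptr` whose index operand is a 1-D vector yields one element pointer per lane, which is exactly the pointer-vector operand that the masked gather/scatter intrinsics consume. A minimal sketch, assuming a unit-stride `f32` memref and the typed-pointer LLVM dialect used in the tests below (value names are illustrative):

```mlir
%ptrs = llvm.getelementptr %base[%idxs]
    : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
```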
@@ -257,29 +261,53 @@
   LogicalResult
   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = gather->getLoc();
     MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
     assert(memRefType && "The base should be bufferized");
 
+    if (failed(isMemRefTypeSupported(memRefType)))
+      return failure();
+
+    auto loc = gather->getLoc();
+
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
-    // Resolve address.
-    Value ptrs;
-    VectorType vType = gather.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
-                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
-      return failure();
+    Value base = adaptor.getBase();
+
+    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
+    // Handle the simple case of 1-D vector.
+    if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
+      auto vType = gather.getVectorType();
+      // Resolve address.
+      Value ptrs = getIndexedPtrs(rewriter, loc, memRefType, base, ptr,
+                                  adaptor.getIndexVec(),
+                                  /*vLen=*/vType.getDimSize(0));
+      // Replace with the gather intrinsic.
+      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+      return success();
+    }
 
-    // Replace with the gather intrinsic.
-    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
-        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
-        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
-    return success();
+    auto callback = [align, memRefType, base, ptr, loc, &rewriter](
+                        Type llvm1DVectorTy, ValueRange vectorOperands) {
+      // Resolve address.
+      Value ptrs = getIndexedPtrs(
+          rewriter, loc, memRefType, base, ptr, /*index=*/vectorOperands[0],
+          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
+      // Create the gather intrinsic.
+      return rewriter.create<LLVM::masked_gather>(
+          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
+          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
+    };
+    ValueRange vectorOperands = {adaptor.getIndexVec(), adaptor.getMask(),
+                                 adaptor.getPassThru()};
+    return LLVM::detail::handleMultidimensionalVectors(
+        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
   }
 };
 
@@ -295,19 +323,21 @@
     auto loc = scatter->getLoc();
     MemRefType memRefType = scatter.getMemRefType();
 
+    if (failed(isMemRefTypeSupported(memRefType)))
+      return failure();
+
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.
-    Value ptrs;
     VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
-                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
-      return failure();
+    Value ptrs =
+        getIndexedPtrs(rewriter, loc, memRefType, adaptor.getBase(), ptr,
+                       adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
 
     // Replace with the scatter intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
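On the 1-D fast path above, the whole `vector.gather` is replaced by a single intrinsic call rather than going through `handleMultidimensionalVectors`; for a `vector<16xf32>` gather the result is one op of the following shape (a sketch with illustrative names, mirroring the pre-existing 1-D gather tests):

```mlir
%g = llvm.intr.masked.gather %ptrs, %mask, %pass_thru {alignment = 4 : i32}
    : (!llvm.vec<16 x ptr<f32>>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
```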
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -49,11 +49,11 @@
   Unknown = 2,
 };
 
-/// Helper method to classify a 1-D mask value. Currently, the method
+/// Helper method to classify a mask value. Currently, the method
 /// looks "under the hood" of a constant value with dense attributes
 /// and a constant mask operation (since the client may be called at
 /// various stages during progressive lowering).
-static MaskFormat get1DMaskFormat(Value mask) {
+static MaskFormat getMaskFormat(Value mask) {
   if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
     // Inspect constant dense values. We count up for bits that
     // are set, count down for bits that are cleared, and bail
@@ -77,12 +77,20 @@
     // dimension size, all bits are set. If the index is zero
     // or less, no bits are set.
     ArrayAttr masks = m.getMaskDimSizes();
-    assert(masks.size() == 1);
-    int64_t i = masks[0].cast<IntegerAttr>().getInt();
-    int64_t u = m.getType().getDimSize(0);
-    if (i >= u)
+    auto shape = m.getType().getShape();
+    bool allTrue = true;
+    bool allFalse = true;
+    for (auto pair : llvm::zip(masks, shape)) {
+      int64_t i = std::get<0>(pair).cast<IntegerAttr>().getInt();
+      int64_t u = std::get<1>(pair);
+      if (i < u)
+        allTrue = false;
+      if (i > 0)
+        allFalse = false;
+    }
+    if (allTrue)
       return MaskFormat::AllTrue;
-    if (i <= 0)
+    if (allFalse)
       return MaskFormat::AllFalse;
   }
   return MaskFormat::Unknown;
@@ -3980,7 +3988,7 @@
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(MaskedLoadOp load,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(load.getMask())) {
+    switch (getMaskFormat(load.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::LoadOp>(
           load, load.getType(), load.getBase(), load.getIndices());
@@ -4031,7 +4039,7 @@
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(MaskedStoreOp store,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(store.getMask())) {
+    switch (getMaskFormat(store.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
           store, store.getValueToStore(), store.getBase(), store.getIndices());
@@ -4074,9 +4082,9 @@
     return emitOpError("base and result element type should match");
   if (llvm::size(getIndices()) != baseType.getRank())
     return emitOpError("requires ") << baseType.getRank() << " indices";
-  if (resVType.getDimSize(0) != indVType.getDimSize(0))
+  if (resVType.getShape() != indVType.getShape())
     return emitOpError("expected result dim to match indices dim");
-  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
+  if (resVType.getShape() != maskVType.getShape())
     return emitOpError("expected result dim to match mask dim");
   if (resVType != getPassThruVectorType())
     return emitOpError("expected pass_thru of same type as result type");
@@ -4089,7 +4097,7 @@
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherOp gather,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(gather.getMask())) {
+    switch (getMaskFormat(gather.getMask())) {
     case MaskFormat::AllTrue:
       return failure(); // no unmasked equivalent
     case MaskFormat::AllFalse:
@@ -4135,7 +4143,7 @@
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ScatterOp scatter,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(scatter.getMask())) {
+    switch (getMaskFormat(scatter.getMask())) {
    case MaskFormat::AllTrue:
       return failure(); // no unmasked equivalent
     case MaskFormat::AllFalse:
@@ -4181,7 +4189,7 @@
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(expand.getMask())) {
+    switch (getMaskFormat(expand.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::LoadOp>(
           expand, expand.getType(), expand.getBase(), expand.getIndices());
@@ -4226,7 +4234,7 @@
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(CompressStoreOp compress,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(compress.getMask())) {
+    switch (getMaskFormat(compress.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
           compress, compress.getValueToStore(), compress.getBase(),
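The generalized `getMaskFormat` classifies a `vector.constant_mask` as all-true only when every mask dim size reaches the corresponding vector dim, and as all-false only when every one is zero or less; anything mixed is `Unknown` and the masked op is left in place. For a `vector<2x3xi1>` mask, the three cases look like this (illustrative snippet; the folder behavior is exercised by the tests below):

```mlir
%all  = vector.constant_mask [2, 3] : vector<2x3xi1>  // AllTrue: maskedload/maskedstore fold to vector.load/store
%none = vector.constant_mask [0, 0] : vector<2x3xi1>  // AllFalse: gather folds to its pass_thru
%mix  = vector.constant_mask [1, 2] : vector<2x3xi1>  // Unknown: the masked op is kept
```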
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1918,6 +1918,56 @@
 
 // -----
 
+func.func @gather_op_multi_dims(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = arith.constant 0 : index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %1 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_multi_dims
+// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
+// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
+// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
+// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
+// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
+
+// -----
+
+func.func @gather_op_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = arith.constant 0 : index
+  %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
+  %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %2 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_with_mask
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+
+// -----
+
+func.func @gather_op_with_zero_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = arith.constant 0 : index
+  %1 = vector.constant_mask [0, 0] : vector<2x3xi1>
+  %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %2 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_with_zero_mask
+// CHECK-SAME: (%{{.*}}: memref<?xf32>, %{{.*}}: vector<2x3xi32>, %[[S:.*]]: vector<2x3xf32>)
+// CHECK-NOT: %{{.*}} = llvm.intr.masked.gather
+// CHECK: return %[[S]] : vector<2x3xf32>
+
+// -----
+
 func.func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
   %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1305,7 +1305,7 @@
 func.func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                                 %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.gather' op result #0 must be of ranks 1, but got 'vector<2x16xf32>'}}
+  // expected-error@+1 {{'vector.gather' op expected result dim to match indices dim}}
   %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<2x16xf32>
 }
 
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -677,6 +677,14 @@
   return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: @gather_multi_dims
+func.func @gather_multi_dims(%base: tensor<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  return %0 : vector<2x16xf32>
+}
+
 // CHECK-LABEL: @expand_and_compress
 func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index